blob: 44616afa3f42f4f4701435c293a74e254cccf675 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
David Neto118188e2018-08-24 11:27:54 -040019#include "llvm/IR/Constants.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000023#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/Pass.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/raw_ostream.h"
27#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040028
David Neto118188e2018-08-24 11:27:54 -040029#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040030
David Neto482550a2018-03-24 05:21:07 -070031#include "clspv/Option.h"
32
David Neto22f144c2017-06-12 14:26:21 -040033using namespace llvm;
34
35#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
36
37namespace {
// Returns the zero-based index of the most significant set bit of |v|, i.e.
// floor(log2(v)). Returns 0 when |v| is 0 or 1.
//
// NOTE: despite its name this is NOT a count of leading zeros. Callers in this
// file subtract two results to obtain the left-shift distance between two
// single-bit masks, which is exactly the difference of their bit indices.
uint32_t clz(uint32_t v) {
  uint32_t index = 0;

  // Each halving of v drops one bit position; count how many are needed
  // before only the top bit (or nothing) is left.
  while (v > 1) {
    v >>= 1;
    ++index;
  }

  return index;
}
57
58Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
59 if (1 == elements) {
60 return Type::getInt1Ty(C);
61 } else {
62 return VectorType::get(Type::getInt1Ty(C), elements);
63 }
64}
65
66struct ReplaceOpenCLBuiltinPass final : public ModulePass {
67 static char ID;
68 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
69
70 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +000071 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040072 bool replaceRecip(Module &M);
73 bool replaceDivide(Module &M);
74 bool replaceExp10(Module &M);
75 bool replaceLog10(Module &M);
76 bool replaceBarrier(Module &M);
77 bool replaceMemFence(Module &M);
78 bool replaceRelational(Module &M);
79 bool replaceIsInfAndIsNan(Module &M);
80 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +000081 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +000082 bool replaceRotate(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +000083 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +000084 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +000085 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040086 bool replaceSignbit(Module &M);
87 bool replaceMadandMad24andMul24(Module &M);
88 bool replaceVloadHalf(Module &M);
89 bool replaceVloadHalf2(Module &M);
90 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -070091 bool replaceClspvVloadaHalf2(Module &M);
92 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040093 bool replaceVstoreHalf(Module &M);
94 bool replaceVstoreHalf2(Module &M);
95 bool replaceVstoreHalf4(Module &M);
96 bool replaceReadImageF(Module &M);
97 bool replaceAtomics(Module &M);
98 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -040099 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700100 bool replaceVload(Module &M);
101 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400102};
103}
104
105char ReplaceOpenCLBuiltinPass::ID = 0;
106static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
107 "Replace OpenCL Builtins Pass");
108
109namespace clspv {
110ModulePass *createReplaceOpenCLBuiltinPass() {
111 return new ReplaceOpenCLBuiltinPass();
112}
113}
114
115bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
116 bool Changed = false;
117
Kévin Petit2444e9b2018-11-09 14:14:37 +0000118 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400119 Changed |= replaceRecip(M);
120 Changed |= replaceDivide(M);
121 Changed |= replaceExp10(M);
122 Changed |= replaceLog10(M);
123 Changed |= replaceBarrier(M);
124 Changed |= replaceMemFence(M);
125 Changed |= replaceRelational(M);
126 Changed |= replaceIsInfAndIsNan(M);
127 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000128 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000129 Changed |= replaceRotate(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000130 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000131 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000132 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400133 Changed |= replaceSignbit(M);
134 Changed |= replaceMadandMad24andMul24(M);
135 Changed |= replaceVloadHalf(M);
136 Changed |= replaceVloadHalf2(M);
137 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700138 Changed |= replaceClspvVloadaHalf2(M);
139 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400140 Changed |= replaceVstoreHalf(M);
141 Changed |= replaceVstoreHalf2(M);
142 Changed |= replaceVstoreHalf4(M);
143 Changed |= replaceReadImageF(M);
144 Changed |= replaceAtomics(M);
145 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400146 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700147 Changed |= replaceVload(M);
148 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400149
150 return Changed;
151}
152
Kévin Petit2444e9b2018-11-09 14:14:37 +0000153bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
154 bool Changed = false;
155
156 const char *Names[] = {
157 "_Z3abst",
158 "_Z3absDv2_t",
159 "_Z3absDv3_t",
160 "_Z3absDv4_t",
161 "_Z3absj",
162 "_Z3absDv2_j",
163 "_Z3absDv3_j",
164 "_Z3absDv4_j",
165 "_Z3absm",
166 "_Z3absDv2_m",
167 "_Z3absDv3_m",
168 "_Z3absDv4_m",
169 };
170
171 for (auto Name : Names) {
172 // If we find a function with the matching name.
173 if (auto F = M.getFunction(Name)) {
174 SmallVector<Instruction *, 4> ToRemoves;
175
176 // Walk the users of the function.
177 for (auto &U : F->uses()) {
178 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
179 // Abs has one arg.
180 auto Arg = CI->getOperand(0);
181
182 // Use the argument unchanged, we know it's unsigned
183 CI->replaceAllUsesWith(Arg);
184
185 // Lastly, remember to remove the user.
186 ToRemoves.push_back(CI);
187 }
188 }
189
190 Changed = !ToRemoves.empty();
191
192 // And cleanup the calls we don't use anymore.
193 for (auto V : ToRemoves) {
194 V->eraseFromParent();
195 }
196
197 // And remove the function we don't need either too.
198 F->eraseFromParent();
199 }
200 }
201
202 return Changed;
203}
204
David Neto22f144c2017-06-12 14:26:21 -0400205bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
206 bool Changed = false;
207
208 const char *Names[] = {
209 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
210 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
211 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
212 };
213
214 for (auto Name : Names) {
215 // If we find a function with the matching name.
216 if (auto F = M.getFunction(Name)) {
217 SmallVector<Instruction *, 4> ToRemoves;
218
219 // Walk the users of the function.
220 for (auto &U : F->uses()) {
221 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
222 // Recip has one arg.
223 auto Arg = CI->getOperand(0);
224
225 auto Div = BinaryOperator::Create(
226 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
227 CI);
228
229 CI->replaceAllUsesWith(Div);
230
231 // Lastly, remember to remove the user.
232 ToRemoves.push_back(CI);
233 }
234 }
235
236 Changed = !ToRemoves.empty();
237
238 // And cleanup the calls we don't use anymore.
239 for (auto V : ToRemoves) {
240 V->eraseFromParent();
241 }
242
243 // And remove the function we don't need either too.
244 F->eraseFromParent();
245 }
246 }
247
248 return Changed;
249}
250
251bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
252 bool Changed = false;
253
254 const char *Names[] = {
255 "_Z11half_divideff", "_Z13native_divideff",
256 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
257 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
258 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
259 };
260
261 for (auto Name : Names) {
262 // If we find a function with the matching name.
263 if (auto F = M.getFunction(Name)) {
264 SmallVector<Instruction *, 4> ToRemoves;
265
266 // Walk the users of the function.
267 for (auto &U : F->uses()) {
268 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
269 auto Div = BinaryOperator::Create(
270 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
271
272 CI->replaceAllUsesWith(Div);
273
274 // Lastly, remember to remove the user.
275 ToRemoves.push_back(CI);
276 }
277 }
278
279 Changed = !ToRemoves.empty();
280
281 // And cleanup the calls we don't use anymore.
282 for (auto V : ToRemoves) {
283 V->eraseFromParent();
284 }
285
286 // And remove the function we don't need either too.
287 F->eraseFromParent();
288 }
289 }
290
291 return Changed;
292}
293
294bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
295 bool Changed = false;
296
297 const std::map<const char *, const char *> Map = {
298 {"_Z5exp10f", "_Z3expf"},
299 {"_Z10half_exp10f", "_Z8half_expf"},
300 {"_Z12native_exp10f", "_Z10native_expf"},
301 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
302 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
303 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
304 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
305 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
306 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
307 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
308 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
309 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
310
311 for (auto Pair : Map) {
312 // If we find a function with the matching name.
313 if (auto F = M.getFunction(Pair.first)) {
314 SmallVector<Instruction *, 4> ToRemoves;
315
316 // Walk the users of the function.
317 for (auto &U : F->uses()) {
318 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
319 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
320
321 auto Arg = CI->getOperand(0);
322
323 // Constant of the natural log of 10 (ln(10)).
324 const double Ln10 =
325 2.302585092994045684017991454684364207601101488628772976033;
326
327 auto Mul = BinaryOperator::Create(
328 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
329 CI);
330
331 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
332
333 CI->replaceAllUsesWith(NewCI);
334
335 // Lastly, remember to remove the user.
336 ToRemoves.push_back(CI);
337 }
338 }
339
340 Changed = !ToRemoves.empty();
341
342 // And cleanup the calls we don't use anymore.
343 for (auto V : ToRemoves) {
344 V->eraseFromParent();
345 }
346
347 // And remove the function we don't need either too.
348 F->eraseFromParent();
349 }
350 }
351
352 return Changed;
353}
354
355bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
356 bool Changed = false;
357
358 const std::map<const char *, const char *> Map = {
359 {"_Z5log10f", "_Z3logf"},
360 {"_Z10half_log10f", "_Z8half_logf"},
361 {"_Z12native_log10f", "_Z10native_logf"},
362 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
363 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
364 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
365 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
366 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
367 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
368 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
369 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
370 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
371
372 for (auto Pair : Map) {
373 // If we find a function with the matching name.
374 if (auto F = M.getFunction(Pair.first)) {
375 SmallVector<Instruction *, 4> ToRemoves;
376
377 // Walk the users of the function.
378 for (auto &U : F->uses()) {
379 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
380 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
381
382 auto Arg = CI->getOperand(0);
383
384 // Constant of the reciprocal of the natural log of 10 (ln(10)).
385 const double Ln10 =
386 0.434294481903251827651128918916605082294397005803666566114;
387
388 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
389
390 auto Mul = BinaryOperator::Create(
391 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
392 "", CI);
393
394 CI->replaceAllUsesWith(Mul);
395
396 // Lastly, remember to remove the user.
397 ToRemoves.push_back(CI);
398 }
399 }
400
401 Changed = !ToRemoves.empty();
402
403 // And cleanup the calls we don't use anymore.
404 for (auto V : ToRemoves) {
405 V->eraseFromParent();
406 }
407
408 // And remove the function we don't need either too.
409 F->eraseFromParent();
410 }
411 }
412
413 return Changed;
414}
415
416bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
417 bool Changed = false;
418
419 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
420
421 const std::map<const char *, const char *> Map = {
422 {"_Z7barrierj", "__spirv_control_barrier"}};
423
424 for (auto Pair : Map) {
425 // If we find a function with the matching name.
426 if (auto F = M.getFunction(Pair.first)) {
427 SmallVector<Instruction *, 4> ToRemoves;
428
429 // Walk the users of the function.
430 for (auto &U : F->uses()) {
431 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
432 auto FType = F->getFunctionType();
433 SmallVector<Type *, 3> Params;
434 for (unsigned i = 0; i < 3; i++) {
435 Params.push_back(FType->getParamType(0));
436 }
437 auto NewFType =
438 FunctionType::get(FType->getReturnType(), Params, false);
439 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
440
441 auto Arg = CI->getOperand(0);
442
443 // We need to map the OpenCL constants to the SPIR-V equivalents.
444 const auto LocalMemFence =
445 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
446 const auto GlobalMemFence =
447 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
448 const auto ConstantSequentiallyConsistent = ConstantInt::get(
449 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
450 const auto ConstantScopeDevice =
451 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
452 const auto ConstantScopeWorkgroup =
453 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
454
455 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
456 const auto LocalMemFenceMask = BinaryOperator::Create(
457 Instruction::And, LocalMemFence, Arg, "", CI);
458 const auto WorkgroupShiftAmount =
459 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
460 clz(CLK_LOCAL_MEM_FENCE);
461 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
462 Instruction::Shl, LocalMemFenceMask,
463 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
464
465 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
466 const auto GlobalMemFenceMask = BinaryOperator::Create(
467 Instruction::And, GlobalMemFence, Arg, "", CI);
468 const auto UniformShiftAmount =
469 clz(spv::MemorySemanticsUniformMemoryMask) -
470 clz(CLK_GLOBAL_MEM_FENCE);
471 const auto MemorySemanticsUniform = BinaryOperator::Create(
472 Instruction::Shl, GlobalMemFenceMask,
473 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
474
475 // And combine the above together, also adding in
476 // MemorySemanticsSequentiallyConsistentMask.
477 auto MemorySemantics =
478 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
479 ConstantSequentiallyConsistent, "", CI);
480 MemorySemantics = BinaryOperator::Create(
481 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
482
483 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
484 // Device Scope, otherwise Workgroup Scope.
485 const auto Cmp =
486 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
487 GlobalMemFenceMask, GlobalMemFence, "", CI);
488 const auto MemoryScope = SelectInst::Create(
489 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
490
491 // Lastly, the Execution Scope is always Workgroup Scope.
492 const auto ExecutionScope = ConstantScopeWorkgroup;
493
494 auto NewCI = CallInst::Create(
495 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
496
497 CI->replaceAllUsesWith(NewCI);
498
499 // Lastly, remember to remove the user.
500 ToRemoves.push_back(CI);
501 }
502 }
503
504 Changed = !ToRemoves.empty();
505
506 // And cleanup the calls we don't use anymore.
507 for (auto V : ToRemoves) {
508 V->eraseFromParent();
509 }
510
511 // And remove the function we don't need either too.
512 F->eraseFromParent();
513 }
514 }
515
516 return Changed;
517}
518
519bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
520 bool Changed = false;
521
522 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
523
Neil Henning39672102017-09-29 14:33:13 +0100524 using Tuple = std::tuple<const char *, unsigned>;
525 const std::map<const char *, Tuple> Map = {
526 {"_Z9mem_fencej",
527 Tuple("__spirv_memory_barrier",
528 spv::MemorySemanticsSequentiallyConsistentMask)},
529 {"_Z14read_mem_fencej",
530 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
531 {"_Z15write_mem_fencej",
532 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400533
534 for (auto Pair : Map) {
535 // If we find a function with the matching name.
536 if (auto F = M.getFunction(Pair.first)) {
537 SmallVector<Instruction *, 4> ToRemoves;
538
539 // Walk the users of the function.
540 for (auto &U : F->uses()) {
541 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
542 auto FType = F->getFunctionType();
543 SmallVector<Type *, 2> Params;
544 for (unsigned i = 0; i < 2; i++) {
545 Params.push_back(FType->getParamType(0));
546 }
547 auto NewFType =
548 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100549 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400550
551 auto Arg = CI->getOperand(0);
552
553 // We need to map the OpenCL constants to the SPIR-V equivalents.
554 const auto LocalMemFence =
555 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
556 const auto GlobalMemFence =
557 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
558 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100559 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400560 const auto ConstantScopeDevice =
561 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
562
563 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
564 const auto LocalMemFenceMask = BinaryOperator::Create(
565 Instruction::And, LocalMemFence, Arg, "", CI);
566 const auto WorkgroupShiftAmount =
567 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
568 clz(CLK_LOCAL_MEM_FENCE);
569 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
570 Instruction::Shl, LocalMemFenceMask,
571 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
572
573 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
574 const auto GlobalMemFenceMask = BinaryOperator::Create(
575 Instruction::And, GlobalMemFence, Arg, "", CI);
576 const auto UniformShiftAmount =
577 clz(spv::MemorySemanticsUniformMemoryMask) -
578 clz(CLK_GLOBAL_MEM_FENCE);
579 const auto MemorySemanticsUniform = BinaryOperator::Create(
580 Instruction::Shl, GlobalMemFenceMask,
581 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
582
583 // And combine the above together, also adding in
584 // MemorySemanticsSequentiallyConsistentMask.
585 auto MemorySemantics =
586 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
587 ConstantMemorySemantics, "", CI);
588 MemorySemantics = BinaryOperator::Create(
589 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
590
591 // Memory Scope is always device.
592 const auto MemoryScope = ConstantScopeDevice;
593
594 auto NewCI =
595 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
596
597 CI->replaceAllUsesWith(NewCI);
598
599 // Lastly, remember to remove the user.
600 ToRemoves.push_back(CI);
601 }
602 }
603
604 Changed = !ToRemoves.empty();
605
606 // And cleanup the calls we don't use anymore.
607 for (auto V : ToRemoves) {
608 V->eraseFromParent();
609 }
610
611 // And remove the function we don't need either too.
612 F->eraseFromParent();
613 }
614 }
615
616 return Changed;
617}
618
619bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
620 bool Changed = false;
621
622 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
623 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
624 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
625 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
626 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
627 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
628 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
629 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
630 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
631 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
632 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
633 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
634 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
635 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
636 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
637 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
638 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
639 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
640 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
641 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
642 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
643 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
644 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
645 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
646 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
647 };
648
649 for (auto Pair : Map) {
650 // If we find a function with the matching name.
651 if (auto F = M.getFunction(Pair.first)) {
652 SmallVector<Instruction *, 4> ToRemoves;
653
654 // Walk the users of the function.
655 for (auto &U : F->uses()) {
656 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
657 // The predicate to use in the CmpInst.
658 auto Predicate = Pair.second.first;
659
660 // The value to return for true.
661 auto TrueValue =
662 ConstantInt::getSigned(CI->getType(), Pair.second.second);
663
664 // The value to return for false.
665 auto FalseValue = Constant::getNullValue(CI->getType());
666
667 auto Arg1 = CI->getOperand(0);
668 auto Arg2 = CI->getOperand(1);
669
670 const auto Cmp =
671 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
672
673 const auto Select =
674 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
675
676 CI->replaceAllUsesWith(Select);
677
678 // Lastly, remember to remove the user.
679 ToRemoves.push_back(CI);
680 }
681 }
682
683 Changed = !ToRemoves.empty();
684
685 // And cleanup the calls we don't use anymore.
686 for (auto V : ToRemoves) {
687 V->eraseFromParent();
688 }
689
690 // And remove the function we don't need either too.
691 F->eraseFromParent();
692 }
693 }
694
695 return Changed;
696}
697
698bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
699 bool Changed = false;
700
701 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
702 {"_Z5isinff", {"__spirv_isinff", 1}},
703 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
704 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
705 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
706 {"_Z5isnanf", {"__spirv_isnanf", 1}},
707 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
708 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
709 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
710 };
711
712 for (auto Pair : Map) {
713 // If we find a function with the matching name.
714 if (auto F = M.getFunction(Pair.first)) {
715 SmallVector<Instruction *, 4> ToRemoves;
716
717 // Walk the users of the function.
718 for (auto &U : F->uses()) {
719 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
720 const auto CITy = CI->getType();
721
722 // The fake SPIR-V intrinsic to generate.
723 auto SPIRVIntrinsic = Pair.second.first;
724
725 // The value to return for true.
726 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
727
728 // The value to return for false.
729 auto FalseValue = Constant::getNullValue(CITy);
730
731 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
732 M.getContext(),
733 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
734
735 auto NewFType =
736 FunctionType::get(CorrespondingBoolTy,
737 F->getFunctionType()->getParamType(0), false);
738
739 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
740
741 auto Arg = CI->getOperand(0);
742
743 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
744
745 const auto Select =
746 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
747
748 CI->replaceAllUsesWith(Select);
749
750 // Lastly, remember to remove the user.
751 ToRemoves.push_back(CI);
752 }
753 }
754
755 Changed = !ToRemoves.empty();
756
757 // And cleanup the calls we don't use anymore.
758 for (auto V : ToRemoves) {
759 V->eraseFromParent();
760 }
761
762 // And remove the function we don't need either too.
763 F->eraseFromParent();
764 }
765 }
766
767 return Changed;
768}
769
770bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
771 bool Changed = false;
772
773 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000774 // all
alan-bakerb39c8262019-03-08 14:03:37 -0500775 {"_Z3allc", ""},
776 {"_Z3allDv2_c", "__spirv_allDv2_c"},
777 {"_Z3allDv3_c", "__spirv_allDv3_c"},
778 {"_Z3allDv4_c", "__spirv_allDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000779 {"_Z3alls", ""},
780 {"_Z3allDv2_s", "__spirv_allDv2_s"},
781 {"_Z3allDv3_s", "__spirv_allDv3_s"},
782 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400783 {"_Z3alli", ""},
784 {"_Z3allDv2_i", "__spirv_allDv2_i"},
785 {"_Z3allDv3_i", "__spirv_allDv3_i"},
786 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000787 {"_Z3alll", ""},
788 {"_Z3allDv2_l", "__spirv_allDv2_l"},
789 {"_Z3allDv3_l", "__spirv_allDv3_l"},
790 {"_Z3allDv4_l", "__spirv_allDv4_l"},
791
792 // any
alan-bakerb39c8262019-03-08 14:03:37 -0500793 {"_Z3anyc", ""},
794 {"_Z3anyDv2_c", "__spirv_anyDv2_c"},
795 {"_Z3anyDv3_c", "__spirv_anyDv3_c"},
796 {"_Z3anyDv4_c", "__spirv_anyDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000797 {"_Z3anys", ""},
798 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
799 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
800 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400801 {"_Z3anyi", ""},
802 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
803 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
804 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000805 {"_Z3anyl", ""},
806 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
807 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
808 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400809 };
810
811 for (auto Pair : Map) {
812 // If we find a function with the matching name.
813 if (auto F = M.getFunction(Pair.first)) {
814 SmallVector<Instruction *, 4> ToRemoves;
815
816 // Walk the users of the function.
817 for (auto &U : F->uses()) {
818 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
819 // The fake SPIR-V intrinsic to generate.
820 auto SPIRVIntrinsic = Pair.second;
821
822 auto Arg = CI->getOperand(0);
823
824 Value *V;
825
Kévin Petitfd27cca2018-10-31 13:00:17 +0000826 // If the argument is a 32-bit int, just use a shift
827 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
828 V = BinaryOperator::Create(Instruction::LShr, Arg,
829 ConstantInt::get(Arg->getType(), 31), "",
830 CI);
831 } else {
David Neto22f144c2017-06-12 14:26:21 -0400832 // The value for zero to compare against.
833 const auto ZeroValue = Constant::getNullValue(Arg->getType());
834
David Neto22f144c2017-06-12 14:26:21 -0400835 // The value to return for true.
836 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
837
838 // The value to return for false.
839 const auto FalseValue = Constant::getNullValue(CI->getType());
840
Kévin Petitfd27cca2018-10-31 13:00:17 +0000841 const auto Cmp = CmpInst::Create(
842 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
843
844 Value* SelectSource;
845
846 // If we have a function to call, call it!
847 if (0 < strlen(SPIRVIntrinsic)) {
848
849 const auto NewFType = FunctionType::get(
850 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
851
852 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
853
854 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
855
856 SelectSource = NewCI;
857
858 } else {
859 SelectSource = Cmp;
860 }
861
862 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400863 }
864
865 CI->replaceAllUsesWith(V);
866
867 // Lastly, remember to remove the user.
868 ToRemoves.push_back(CI);
869 }
870 }
871
872 Changed = !ToRemoves.empty();
873
874 // And cleanup the calls we don't use anymore.
875 for (auto V : ToRemoves) {
876 V->eraseFromParent();
877 }
878
879 // And remove the function we don't need either too.
880 F->eraseFromParent();
881 }
882 }
883
884 return Changed;
885}
886
Kévin Petitbf0036c2019-03-06 13:57:10 +0000887bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
888 bool Changed = false;
889
890 for (auto const &SymVal : M.getValueSymbolTable()) {
891 // Skip symbols whose name doesn't match
892 if (!SymVal.getKey().startswith("_Z8upsample")) {
893 continue;
894 }
895 // Is there a function going by that name?
896 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
897
898 SmallVector<Instruction *, 4> ToRemoves;
899
900 // Walk the users of the function.
901 for (auto &U : F->uses()) {
902 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
903
904 // Get arguments
905 auto HiValue = CI->getOperand(0);
906 auto LoValue = CI->getOperand(1);
907
908 // Don't touch overloads that aren't in OpenCL C
909 auto HiType = HiValue->getType();
910 auto LoType = LoValue->getType();
911
912 if (HiType != LoType) {
913 continue;
914 }
915
916 if (!HiType->isIntOrIntVectorTy()) {
917 continue;
918 }
919
920 if (HiType->getScalarSizeInBits() * 2 !=
921 CI->getType()->getScalarSizeInBits()) {
922 continue;
923 }
924
925 if ((HiType->getScalarSizeInBits() != 8) &&
926 (HiType->getScalarSizeInBits() != 16) &&
927 (HiType->getScalarSizeInBits() != 32)) {
928 continue;
929 }
930
931 if (HiType->isVectorTy()) {
932 if ((HiType->getVectorNumElements() != 2) &&
933 (HiType->getVectorNumElements() != 3) &&
934 (HiType->getVectorNumElements() != 4) &&
935 (HiType->getVectorNumElements() != 8) &&
936 (HiType->getVectorNumElements() != 16)) {
937 continue;
938 }
939 }
940
941 // Convert both operands to the result type
942 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
943 "", CI);
944 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
945 "", CI);
946
947 // Shift high operand
948 auto ShiftAmount = ConstantInt::get(CI->getType(),
949 HiType->getScalarSizeInBits());
950 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
951 ShiftAmount, "", CI);
952
953 // OR both results
954 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
955 "", CI);
956
957 // Replace call with the expression
958 CI->replaceAllUsesWith(V);
959
960 // Lastly, remember to remove the user.
961 ToRemoves.push_back(CI);
962 }
963 }
964
965 Changed = !ToRemoves.empty();
966
967 // And cleanup the calls we don't use anymore.
968 for (auto V : ToRemoves) {
969 V->eraseFromParent();
970 }
971
972 // And remove the function we don't need either too.
973 F->eraseFromParent();
974 }
975 }
976
977 return Changed;
978}
979
Kévin Petitd44eef52019-03-08 13:22:14 +0000980bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
981 bool Changed = false;
982
983 for (auto const &SymVal : M.getValueSymbolTable()) {
984 // Skip symbols whose name doesn't match
985 if (!SymVal.getKey().startswith("_Z6rotate")) {
986 continue;
987 }
988 // Is there a function going by that name?
989 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
990
991 SmallVector<Instruction *, 4> ToRemoves;
992
993 // Walk the users of the function.
994 for (auto &U : F->uses()) {
995 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
996
997 // Get arguments
998 auto SrcValue = CI->getOperand(0);
999 auto RotAmount = CI->getOperand(1);
1000
1001 // Don't touch overloads that aren't in OpenCL C
1002 auto SrcType = SrcValue->getType();
1003 auto RotType = RotAmount->getType();
1004
1005 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1006 continue;
1007 }
1008
1009 if (!SrcType->isIntOrIntVectorTy()) {
1010 continue;
1011 }
1012
1013 if ((SrcType->getScalarSizeInBits() != 8) &&
1014 (SrcType->getScalarSizeInBits() != 16) &&
1015 (SrcType->getScalarSizeInBits() != 32) &&
1016 (SrcType->getScalarSizeInBits() != 64)) {
1017 continue;
1018 }
1019
1020 if (SrcType->isVectorTy()) {
1021 if ((SrcType->getVectorNumElements() != 2) &&
1022 (SrcType->getVectorNumElements() != 3) &&
1023 (SrcType->getVectorNumElements() != 4) &&
1024 (SrcType->getVectorNumElements() != 8) &&
1025 (SrcType->getVectorNumElements() != 16)) {
1026 continue;
1027 }
1028 }
1029
1030 // The approach used is to shift the top bits down, the bottom bits up
1031 // and OR the two shifted values.
1032
1033 // The rotation amount is to be treated modulo the element size.
1034 // Since SPIR-V shift ops don't support this, let's apply the
1035 // modulo ahead of shifting. The element size is always a power of
1036 // two so we can just AND with a mask.
1037 auto ModMask = ConstantInt::get(SrcType,
1038 SrcType->getScalarSizeInBits() - 1);
1039 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1040 ModMask, "", CI);
1041
1042 // Let's calc the amount by which to shift top bits down
1043 auto ScalarSize = ConstantInt::get(SrcType,
1044 SrcType->getScalarSizeInBits());
1045 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1046 RotAmount, "", CI);
1047
1048 // Now shift the bottom bits up and the top bits down
1049 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1050 RotAmount, "", CI);
1051 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1052 DownAmount, "", CI);
1053
1054 // Finally OR the two shifted values
1055 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1056 HiRotated, "", CI);
1057
1058 // Replace call with the expression
1059 CI->replaceAllUsesWith(V);
1060
1061 // Lastly, remember to remove the user.
1062 ToRemoves.push_back(CI);
1063 }
1064 }
1065
1066 Changed = !ToRemoves.empty();
1067
1068 // And cleanup the calls we don't use anymore.
1069 for (auto V : ToRemoves) {
1070 V->eraseFromParent();
1071 }
1072
1073 // And remove the function we don't need either too.
1074 F->eraseFromParent();
1075 }
1076 }
1077
1078 return Changed;
1079}
1080
Kévin Petitf5b78a22018-10-25 14:32:17 +00001081bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1082 bool Changed = false;
1083
1084 for (auto const &SymVal : M.getValueSymbolTable()) {
1085 // Skip symbols whose name doesn't match
1086 if (!SymVal.getKey().startswith("_Z6select")) {
1087 continue;
1088 }
1089 // Is there a function going by that name?
1090 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1091
1092 SmallVector<Instruction *, 4> ToRemoves;
1093
1094 // Walk the users of the function.
1095 for (auto &U : F->uses()) {
1096 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1097
1098 // Get arguments
1099 auto FalseValue = CI->getOperand(0);
1100 auto TrueValue = CI->getOperand(1);
1101 auto PredicateValue = CI->getOperand(2);
1102
1103 // Don't touch overloads that aren't in OpenCL C
1104 auto FalseType = FalseValue->getType();
1105 auto TrueType = TrueValue->getType();
1106 auto PredicateType = PredicateValue->getType();
1107
1108 if (FalseType != TrueType) {
1109 continue;
1110 }
1111
1112 if (!PredicateType->isIntOrIntVectorTy()) {
1113 continue;
1114 }
1115
1116 if (!FalseType->isIntOrIntVectorTy() &&
1117 !FalseType->getScalarType()->isFloatingPointTy()) {
1118 continue;
1119 }
1120
1121 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1122 continue;
1123 }
1124
1125 if (FalseType->getScalarSizeInBits() !=
1126 PredicateType->getScalarSizeInBits()) {
1127 continue;
1128 }
1129
1130 if (FalseType->isVectorTy()) {
1131 if (FalseType->getVectorNumElements() !=
1132 PredicateType->getVectorNumElements()) {
1133 continue;
1134 }
1135
1136 if ((FalseType->getVectorNumElements() != 2) &&
1137 (FalseType->getVectorNumElements() != 3) &&
1138 (FalseType->getVectorNumElements() != 4) &&
1139 (FalseType->getVectorNumElements() != 8) &&
1140 (FalseType->getVectorNumElements() != 16)) {
1141 continue;
1142 }
1143 }
1144
1145 // Create constant
1146 const auto ZeroValue = Constant::getNullValue(PredicateType);
1147
1148 // Scalar and vector are to be treated differently
1149 CmpInst::Predicate Pred;
1150 if (PredicateType->isVectorTy()) {
1151 Pred = CmpInst::ICMP_SLT;
1152 } else {
1153 Pred = CmpInst::ICMP_NE;
1154 }
1155
1156 // Create comparison instruction
1157 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1158 ZeroValue, "", CI);
1159
1160 // Create select
1161 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1162
1163 // Replace call with the selection
1164 CI->replaceAllUsesWith(V);
1165
1166 // Lastly, remember to remove the user.
1167 ToRemoves.push_back(CI);
1168 }
1169 }
1170
1171 Changed = !ToRemoves.empty();
1172
1173 // And cleanup the calls we don't use anymore.
1174 for (auto V : ToRemoves) {
1175 V->eraseFromParent();
1176 }
1177
1178 // And remove the function we don't need either too.
1179 F->eraseFromParent();
1180 }
1181 }
1182
1183 return Changed;
1184}
1185
Kévin Petite7d0cce2018-10-31 12:38:56 +00001186bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1187 bool Changed = false;
1188
1189 for (auto const &SymVal : M.getValueSymbolTable()) {
1190 // Skip symbols whose name doesn't match
1191 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1192 continue;
1193 }
1194 // Is there a function going by that name?
1195 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1196
1197 SmallVector<Instruction *, 4> ToRemoves;
1198
1199 // Walk the users of the function.
1200 for (auto &U : F->uses()) {
1201 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1202
1203 if (CI->getNumOperands() != 4) {
1204 continue;
1205 }
1206
1207 // Get arguments
1208 auto FalseValue = CI->getOperand(0);
1209 auto TrueValue = CI->getOperand(1);
1210 auto PredicateValue = CI->getOperand(2);
1211
1212 // Don't touch overloads that aren't in OpenCL C
1213 auto FalseType = FalseValue->getType();
1214 auto TrueType = TrueValue->getType();
1215 auto PredicateType = PredicateValue->getType();
1216
1217 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1218 continue;
1219 }
1220
1221 if (TrueType->isVectorTy()) {
1222 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1223 !TrueType->getScalarType()->isIntegerTy()) {
1224 continue;
1225 }
1226 if ((TrueType->getVectorNumElements() != 2) &&
1227 (TrueType->getVectorNumElements() != 3) &&
1228 (TrueType->getVectorNumElements() != 4) &&
1229 (TrueType->getVectorNumElements() != 8) &&
1230 (TrueType->getVectorNumElements() != 16)) {
1231 continue;
1232 }
1233 }
1234
1235 // Remember the type of the operands
1236 auto OpType = TrueType;
1237
1238 // The actual bit selection will always be done on an integer type,
1239 // declare it here
1240 Type *BitType;
1241
1242 // If the operands are float, then bitcast them to int
1243 if (OpType->getScalarType()->isFloatingPointTy()) {
1244
1245 // First create the new type
1246 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1247 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1248 if (OpType->isVectorTy()) {
1249 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1250 }
1251
1252 // Then bitcast all operands
1253 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1254 BitType, "", CI);
1255 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1256 BitType, "", CI);
1257 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1258
1259 } else {
1260 // The operands have an integer type, use it directly
1261 BitType = OpType;
1262 }
1263
1264 // All the operands are now always integers
1265 // implement as (c & b) | (~c & a)
1266
1267 // Create our negated predicate value
1268 auto AllOnes = Constant::getAllOnesValue(BitType);
1269 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1270 PredicateValue,
1271 AllOnes, "", CI);
1272
1273 // Then put everything together
1274 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1275 NotPredicateValue,
1276 FalseValue, "", CI);
1277 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1278 PredicateValue,
1279 TrueValue, "", CI);
1280
1281 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1282 BitsTrue, "", CI);
1283
1284 // If we were dealing with a floating point type, we must bitcast
1285 // the result back to that
1286 if (OpType->getScalarType()->isFloatingPointTy()) {
1287 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1288 }
1289
1290 // Replace call with our new code
1291 CI->replaceAllUsesWith(V);
1292
1293 // Lastly, remember to remove the user.
1294 ToRemoves.push_back(CI);
1295 }
1296 }
1297
1298 Changed = !ToRemoves.empty();
1299
1300 // And cleanup the calls we don't use anymore.
1301 for (auto V : ToRemoves) {
1302 V->eraseFromParent();
1303 }
1304
1305 // And remove the function we don't need either too.
1306 F->eraseFromParent();
1307 }
1308 }
1309
1310 return Changed;
1311}
1312
Kévin Petit6b0a9532018-10-30 20:00:39 +00001313bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1314 bool Changed = false;
1315
1316 const std::map<const char *, const char *> Map = {
1317 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1318 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1319 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1320 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1321 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1322 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1323 };
1324
1325 for (auto Pair : Map) {
1326 // If we find a function with the matching name.
1327 if (auto F = M.getFunction(Pair.first)) {
1328 SmallVector<Instruction *, 4> ToRemoves;
1329
1330 // Walk the users of the function.
1331 for (auto &U : F->uses()) {
1332 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1333
1334 auto ReplacementFn = Pair.second;
1335
1336 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1337 Value *VectorArg;
1338
1339 // First figure out which function we're dealing with
1340 if (F->getName().startswith("_Z10smoothstep")) {
1341 ArgsToSplat.push_back(CI->getOperand(1));
1342 VectorArg = CI->getOperand(2);
1343 } else {
1344 VectorArg = CI->getOperand(1);
1345 }
1346
1347 // Splat arguments that need to be
1348 SmallVector<Value*, 2> SplatArgs;
1349 auto VecType = VectorArg->getType();
1350
1351 for (auto arg : ArgsToSplat) {
1352 Value* NewVectorArg = UndefValue::get(VecType);
1353 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1354 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1355 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1356 }
1357 SplatArgs.push_back(NewVectorArg);
1358 }
1359
1360 // Replace the call with the vector/vector flavour
1361 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1362 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1363
1364 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1365
1366 SmallVector<Value*, 3> NewArgs;
1367 for (auto arg : SplatArgs) {
1368 NewArgs.push_back(arg);
1369 }
1370 NewArgs.push_back(VectorArg);
1371
1372 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1373
1374 CI->replaceAllUsesWith(NewCI);
1375
1376 // Lastly, remember to remove the user.
1377 ToRemoves.push_back(CI);
1378 }
1379 }
1380
1381 Changed = !ToRemoves.empty();
1382
1383 // And cleanup the calls we don't use anymore.
1384 for (auto V : ToRemoves) {
1385 V->eraseFromParent();
1386 }
1387
1388 // And remove the function we don't need either too.
1389 F->eraseFromParent();
1390 }
1391 }
1392
1393 return Changed;
1394}
1395
David Neto22f144c2017-06-12 14:26:21 -04001396bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1397 bool Changed = false;
1398
1399 const std::map<const char *, Instruction::BinaryOps> Map = {
1400 {"_Z7signbitf", Instruction::LShr},
1401 {"_Z7signbitDv2_f", Instruction::AShr},
1402 {"_Z7signbitDv3_f", Instruction::AShr},
1403 {"_Z7signbitDv4_f", Instruction::AShr},
1404 };
1405
1406 for (auto Pair : Map) {
1407 // If we find a function with the matching name.
1408 if (auto F = M.getFunction(Pair.first)) {
1409 SmallVector<Instruction *, 4> ToRemoves;
1410
1411 // Walk the users of the function.
1412 for (auto &U : F->uses()) {
1413 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1414 auto Arg = CI->getOperand(0);
1415
1416 auto Bitcast =
1417 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1418
1419 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1420 ConstantInt::get(CI->getType(), 31),
1421 "", CI);
1422
1423 CI->replaceAllUsesWith(Shr);
1424
1425 // Lastly, remember to remove the user.
1426 ToRemoves.push_back(CI);
1427 }
1428 }
1429
1430 Changed = !ToRemoves.empty();
1431
1432 // And cleanup the calls we don't use anymore.
1433 for (auto V : ToRemoves) {
1434 V->eraseFromParent();
1435 }
1436
1437 // And remove the function we don't need either too.
1438 F->eraseFromParent();
1439 }
1440 }
1441
1442 return Changed;
1443}
1444
1445bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1446 bool Changed = false;
1447
1448 const std::map<const char *,
1449 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1450 Map = {
1451 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1452 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1453 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1454 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1455 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1456 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1457 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1458 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1459 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1460 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1461 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1462 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1463 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1464 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1465 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1466 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1467 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1468 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1469 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1470 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1471 };
1472
1473 for (auto Pair : Map) {
1474 // If we find a function with the matching name.
1475 if (auto F = M.getFunction(Pair.first)) {
1476 SmallVector<Instruction *, 4> ToRemoves;
1477
1478 // Walk the users of the function.
1479 for (auto &U : F->uses()) {
1480 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1481 // The multiply instruction to use.
1482 auto MulInst = Pair.second.first;
1483
1484 // The add instruction to use.
1485 auto AddInst = Pair.second.second;
1486
1487 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1488
1489 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1490 CI->getArgOperand(1), "", CI);
1491
1492 if (Instruction::BinaryOpsEnd != AddInst) {
1493 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1494 CI);
1495 }
1496
1497 CI->replaceAllUsesWith(I);
1498
1499 // Lastly, remember to remove the user.
1500 ToRemoves.push_back(CI);
1501 }
1502 }
1503
1504 Changed = !ToRemoves.empty();
1505
1506 // And cleanup the calls we don't use anymore.
1507 for (auto V : ToRemoves) {
1508 V->eraseFromParent();
1509 }
1510
1511 // And remove the function we don't need either too.
1512 F->eraseFromParent();
1513 }
1514 }
1515
1516 return Changed;
1517}
1518
Derek Chowcfd368b2017-10-19 20:58:45 -07001519bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1520 bool Changed = false;
1521
1522 struct VectorStoreOps {
1523 const char* name;
1524 int n;
1525 Type* (*get_scalar_type_function)(LLVMContext&);
1526 } vector_store_ops[] = {
1527 // TODO(derekjchow): Expand this list.
1528 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1529 };
1530
David Neto544fffc2017-11-16 18:35:14 -05001531 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001532 auto Name = Op.name;
1533 auto N = Op.n;
1534 auto TypeFn = Op.get_scalar_type_function;
1535 if (auto F = M.getFunction(Name)) {
1536 SmallVector<Instruction *, 4> ToRemoves;
1537
1538 // Walk the users of the function.
1539 for (auto &U : F->uses()) {
1540 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1541 // The value argument from vstoren.
1542 auto Arg0 = CI->getOperand(0);
1543
1544 // The index argument from vstoren.
1545 auto Arg1 = CI->getOperand(1);
1546
1547 // The pointer argument from vstoren.
1548 auto Arg2 = CI->getOperand(2);
1549
1550 // Get types.
1551 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1552 auto ScalarNPointerTy = PointerType::get(
1553 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1554
1555 // Cast to scalarn
1556 auto Cast = CastInst::CreatePointerCast(
1557 Arg2, ScalarNPointerTy, "", CI);
1558 // Index to correct address
1559 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1560 // Store
1561 auto Store = new StoreInst(Arg0, Index, CI);
1562
1563 CI->replaceAllUsesWith(Store);
1564 ToRemoves.push_back(CI);
1565 }
1566 }
1567
1568 Changed = !ToRemoves.empty();
1569
1570 // And cleanup the calls we don't use anymore.
1571 for (auto V : ToRemoves) {
1572 V->eraseFromParent();
1573 }
1574
1575 // And remove the function we don't need either too.
1576 F->eraseFromParent();
1577 }
1578 }
1579
1580 return Changed;
1581}
1582
1583bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1584 bool Changed = false;
1585
1586 struct VectorLoadOps {
1587 const char* name;
1588 int n;
1589 Type* (*get_scalar_type_function)(LLVMContext&);
1590 } vector_load_ops[] = {
1591 // TODO(derekjchow): Expand this list.
1592 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1593 };
1594
David Neto544fffc2017-11-16 18:35:14 -05001595 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001596 auto Name = Op.name;
1597 auto N = Op.n;
1598 auto TypeFn = Op.get_scalar_type_function;
1599 // If we find a function with the matching name.
1600 if (auto F = M.getFunction(Name)) {
1601 SmallVector<Instruction *, 4> ToRemoves;
1602
1603 // Walk the users of the function.
1604 for (auto &U : F->uses()) {
1605 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1606 // The index argument from vloadn.
1607 auto Arg0 = CI->getOperand(0);
1608
1609 // The pointer argument from vloadn.
1610 auto Arg1 = CI->getOperand(1);
1611
1612 // Get types.
1613 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1614 auto ScalarNPointerTy = PointerType::get(
1615 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1616
1617 // Cast to scalarn
1618 auto Cast = CastInst::CreatePointerCast(
1619 Arg1, ScalarNPointerTy, "", CI);
1620 // Index to correct address
1621 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1622 // Load
1623 auto Load = new LoadInst(Index, "", CI);
1624
1625 CI->replaceAllUsesWith(Load);
1626 ToRemoves.push_back(CI);
1627 }
1628 }
1629
1630 Changed = !ToRemoves.empty();
1631
1632 // And cleanup the calls we don't use anymore.
1633 for (auto V : ToRemoves) {
1634 V->eraseFromParent();
1635 }
1636
1637 // And remove the function we don't need either too.
1638 F->eraseFromParent();
1639
1640 }
1641 }
1642
1643 return Changed;
1644}
1645
David Neto22f144c2017-06-12 14:26:21 -04001646bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1647 bool Changed = false;
1648
1649 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
1650 "_Z10vload_halfjPU3AS2KDh"};
1651
1652 for (auto Name : Map) {
1653 // If we find a function with the matching name.
1654 if (auto F = M.getFunction(Name)) {
1655 SmallVector<Instruction *, 4> ToRemoves;
1656
1657 // Walk the users of the function.
1658 for (auto &U : F->uses()) {
1659 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1660 // The index argument from vload_half.
1661 auto Arg0 = CI->getOperand(0);
1662
1663 // The pointer argument from vload_half.
1664 auto Arg1 = CI->getOperand(1);
1665
David Neto22f144c2017-06-12 14:26:21 -04001666 auto IntTy = Type::getInt32Ty(M.getContext());
1667 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001668 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1669
David Neto22f144c2017-06-12 14:26:21 -04001670 // Our intrinsic to unpack a float2 from an int.
1671 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1672
1673 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1674
David Neto482550a2018-03-24 05:21:07 -07001675 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04001676 auto ShortTy = Type::getInt16Ty(M.getContext());
1677 auto ShortPointerTy = PointerType::get(
1678 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001679
David Netoac825b82017-05-30 12:49:01 -04001680 // Cast the half* pointer to short*.
1681 auto Cast =
1682 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001683
David Netoac825b82017-05-30 12:49:01 -04001684 // Index into the correct address of the casted pointer.
1685 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1686
1687 // Load from the short* we casted to.
1688 auto Load = new LoadInst(Index, "", CI);
1689
1690 // ZExt the short -> int.
1691 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1692
1693 // Get our float2.
1694 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1695
1696 // Extract out the bottom element which is our float result.
1697 auto Extract = ExtractElementInst::Create(
1698 Call, ConstantInt::get(IntTy, 0), "", CI);
1699
1700 CI->replaceAllUsesWith(Extract);
1701 } else {
1702 // Assume the pointer argument points to storage aligned to 32bits
1703 // or more.
1704 // TODO(dneto): Do more analysis to make sure this is true?
1705 //
1706 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1707 // with:
1708 //
1709 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1710 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1711 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1712 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1713 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1714 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1715 // x float> %converted, %index_is_odd32
1716
1717 auto IntPointerTy = PointerType::get(
1718 IntTy, Arg1->getType()->getPointerAddressSpace());
1719
David Neto973e6a82017-05-30 13:48:18 -04001720 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04001721 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04001722 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04001723 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1724
1725 auto One = ConstantInt::get(IntTy, 1);
1726 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1727 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1728
1729 // Index into the correct address of the casted pointer.
1730 auto Ptr =
1731 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1732
1733 // Load from the int* we casted to.
1734 auto Load = new LoadInst(Ptr, "", CI);
1735
1736 // Get our float2.
1737 auto Call = CallInst::Create(NewF, Load, "", CI);
1738
1739 // Extract out the float result, where the element number is
1740 // determined by whether the original index was even or odd.
1741 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1742
1743 CI->replaceAllUsesWith(Extract);
1744 }
David Neto22f144c2017-06-12 14:26:21 -04001745
1746 // Lastly, remember to remove the user.
1747 ToRemoves.push_back(CI);
1748 }
1749 }
1750
1751 Changed = !ToRemoves.empty();
1752
1753 // And cleanup the calls we don't use anymore.
1754 for (auto V : ToRemoves) {
1755 V->eraseFromParent();
1756 }
1757
1758 // And remove the function we don't need either too.
1759 F->eraseFromParent();
1760 }
1761 }
1762
1763 return Changed;
1764}
1765
1766bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
1767 bool Changed = false;
1768
David Neto556c7e62018-06-08 13:45:55 -07001769 const std::vector<const char *> Map = {
1770 "_Z11vload_half2jPU3AS1KDh",
1771 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
1772 "_Z11vload_half2jPU3AS2KDh",
1773 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
1774 };
David Neto22f144c2017-06-12 14:26:21 -04001775
1776 for (auto Name : Map) {
1777 // If we find a function with the matching name.
1778 if (auto F = M.getFunction(Name)) {
1779 SmallVector<Instruction *, 4> ToRemoves;
1780
1781 // Walk the users of the function.
1782 for (auto &U : F->uses()) {
1783 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1784 // The index argument from vload_half.
1785 auto Arg0 = CI->getOperand(0);
1786
1787 // The pointer argument from vload_half.
1788 auto Arg1 = CI->getOperand(1);
1789
1790 auto IntTy = Type::getInt32Ty(M.getContext());
1791 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1792 auto NewPointerTy = PointerType::get(
1793 IntTy, Arg1->getType()->getPointerAddressSpace());
1794 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1795
1796 // Cast the half* pointer to int*.
1797 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1798
1799 // Index into the correct address of the casted pointer.
1800 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
1801
1802 // Load from the int* we casted to.
1803 auto Load = new LoadInst(Index, "", CI);
1804
1805 // Our intrinsic to unpack a float2 from an int.
1806 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1807
1808 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1809
1810 // Get our float2.
1811 auto Call = CallInst::Create(NewF, Load, "", CI);
1812
1813 CI->replaceAllUsesWith(Call);
1814
1815 // Lastly, remember to remove the user.
1816 ToRemoves.push_back(CI);
1817 }
1818 }
1819
1820 Changed = !ToRemoves.empty();
1821
1822 // And cleanup the calls we don't use anymore.
1823 for (auto V : ToRemoves) {
1824 V->eraseFromParent();
1825 }
1826
1827 // And remove the function we don't need either too.
1828 F->eraseFromParent();
1829 }
1830 }
1831
1832 return Changed;
1833}
1834
1835bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
1836 bool Changed = false;
1837
David Neto556c7e62018-06-08 13:45:55 -07001838 const std::vector<const char *> Map = {
1839 "_Z11vload_half4jPU3AS1KDh",
1840 "_Z12vloada_half4jPU3AS1KDh",
1841 "_Z11vload_half4jPU3AS2KDh",
1842 "_Z12vloada_half4jPU3AS2KDh",
1843 };
David Neto22f144c2017-06-12 14:26:21 -04001844
1845 for (auto Name : Map) {
1846 // If we find a function with the matching name.
1847 if (auto F = M.getFunction(Name)) {
1848 SmallVector<Instruction *, 4> ToRemoves;
1849
1850 // Walk the users of the function.
1851 for (auto &U : F->uses()) {
1852 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1853 // The index argument from vload_half.
1854 auto Arg0 = CI->getOperand(0);
1855
1856 // The pointer argument from vload_half.
1857 auto Arg1 = CI->getOperand(1);
1858
1859 auto IntTy = Type::getInt32Ty(M.getContext());
1860 auto Int2Ty = VectorType::get(IntTy, 2);
1861 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1862 auto NewPointerTy = PointerType::get(
1863 Int2Ty, Arg1->getType()->getPointerAddressSpace());
1864 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1865
1866 // Cast the half* pointer to int2*.
1867 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1868
1869 // Index into the correct address of the casted pointer.
1870 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
1871
1872 // Load from the int2* we casted to.
1873 auto Load = new LoadInst(Index, "", CI);
1874
1875 // Extract each element from the loaded int2.
1876 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1877 "", CI);
1878 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1879 "", CI);
1880
1881 // Our intrinsic to unpack a float2 from an int.
1882 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1883
1884 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1885
1886 // Get the lower (x & y) components of our final float4.
1887 auto Lo = CallInst::Create(NewF, X, "", CI);
1888
1889 // Get the higher (z & w) components of our final float4.
1890 auto Hi = CallInst::Create(NewF, Y, "", CI);
1891
1892 Constant *ShuffleMask[4] = {
1893 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1894 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1895
1896 // Combine our two float2's into one float4.
1897 auto Combine = new ShuffleVectorInst(
1898 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1899
1900 CI->replaceAllUsesWith(Combine);
1901
1902 // Lastly, remember to remove the user.
1903 ToRemoves.push_back(CI);
1904 }
1905 }
1906
1907 Changed = !ToRemoves.empty();
1908
1909 // And cleanup the calls we don't use anymore.
1910 for (auto V : ToRemoves) {
1911 V->eraseFromParent();
1912 }
1913
1914 // And remove the function we don't need either too.
1915 F->eraseFromParent();
1916 }
1917 }
1918
1919 return Changed;
1920}
1921
David Neto6ad93232018-06-07 15:42:58 -07001922bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
1923 bool Changed = false;
1924
1925 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1926 //
1927 // %u = load i32 %ptr
1928 // %fxy = call <2 x float> Unpack2xHalf(u)
1929 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1930 const std::vector<const char *> Map = {
1931 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
1932 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
1933 "_Z20__clspv_vloada_half2jPKj", // private
1934 };
1935
1936 for (auto Name : Map) {
1937 // If we find a function with the matching name.
1938 if (auto F = M.getFunction(Name)) {
1939 SmallVector<Instruction *, 4> ToRemoves;
1940
1941 // Walk the users of the function.
1942 for (auto &U : F->uses()) {
1943 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
1944 auto Index = CI->getOperand(0);
1945 auto Ptr = CI->getOperand(1);
1946
1947 auto IntTy = Type::getInt32Ty(M.getContext());
1948 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1949 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1950
1951 auto IndexedPtr =
1952 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
1953 auto Load = new LoadInst(IndexedPtr, "", CI);
1954
1955 // Our intrinsic to unpack a float2 from an int.
1956 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1957
1958 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1959
1960 // Get our final float2.
1961 auto Result = CallInst::Create(NewF, Load, "", CI);
1962
1963 CI->replaceAllUsesWith(Result);
1964
1965 // Lastly, remember to remove the user.
1966 ToRemoves.push_back(CI);
1967 }
1968 }
1969
1970 Changed = true;
1971
1972 // And cleanup the calls we don't use anymore.
1973 for (auto V : ToRemoves) {
1974 V->eraseFromParent();
1975 }
1976
1977 // And remove the function we don't need either too.
1978 F->eraseFromParent();
1979 }
1980 }
1981
1982 return Changed;
1983}
1984
1985bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
1986 bool Changed = false;
1987
1988 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1989 //
1990 // %u2 = load <2 x i32> %ptr
1991 // %u2xy = extractelement %u2, 0
1992 // %u2zw = extractelement %u2, 1
1993 // %fxy = call <2 x float> Unpack2xHalf(uint)
1994 // %fzw = call <2 x float> Unpack2xHalf(uint)
1995 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1996 const std::vector<const char *> Map = {
1997 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
1998 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
1999 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2000 };
2001
2002 for (auto Name : Map) {
2003 // If we find a function with the matching name.
2004 if (auto F = M.getFunction(Name)) {
2005 SmallVector<Instruction *, 4> ToRemoves;
2006
2007 // Walk the users of the function.
2008 for (auto &U : F->uses()) {
2009 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2010 auto Index = CI->getOperand(0);
2011 auto Ptr = CI->getOperand(1);
2012
2013 auto IntTy = Type::getInt32Ty(M.getContext());
2014 auto Int2Ty = VectorType::get(IntTy, 2);
2015 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2016 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2017
2018 auto IndexedPtr =
2019 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
2020 auto Load = new LoadInst(IndexedPtr, "", CI);
2021
2022 // Extract each element from the loaded int2.
2023 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2024 "", CI);
2025 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2026 "", CI);
2027
2028 // Our intrinsic to unpack a float2 from an int.
2029 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2030
2031 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2032
2033 // Get the lower (x & y) components of our final float4.
2034 auto Lo = CallInst::Create(NewF, X, "", CI);
2035
2036 // Get the higher (z & w) components of our final float4.
2037 auto Hi = CallInst::Create(NewF, Y, "", CI);
2038
2039 Constant *ShuffleMask[4] = {
2040 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2041 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2042
2043 // Combine our two float2's into one float4.
2044 auto Combine = new ShuffleVectorInst(
2045 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2046
2047 CI->replaceAllUsesWith(Combine);
2048
2049 // Lastly, remember to remove the user.
2050 ToRemoves.push_back(CI);
2051 }
2052 }
2053
2054 Changed = true;
2055
2056 // And cleanup the calls we don't use anymore.
2057 for (auto V : ToRemoves) {
2058 V->eraseFromParent();
2059 }
2060
2061 // And remove the function we don't need either too.
2062 F->eraseFromParent();
2063 }
2064 }
2065
2066 return Changed;
2067}
2068
David Neto22f144c2017-06-12 14:26:21 -04002069bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
2070 bool Changed = false;
2071
2072 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
2073 "_Z15vstore_half_rtefjPU3AS1Dh",
2074 "_Z15vstore_half_rtzfjPU3AS1Dh"};
2075
2076 for (auto Name : Map) {
2077 // If we find a function with the matching name.
2078 if (auto F = M.getFunction(Name)) {
2079 SmallVector<Instruction *, 4> ToRemoves;
2080
2081 // Walk the users of the function.
2082 for (auto &U : F->uses()) {
2083 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2084 // The value to store.
2085 auto Arg0 = CI->getOperand(0);
2086
2087 // The index argument from vstore_half.
2088 auto Arg1 = CI->getOperand(1);
2089
2090 // The pointer argument from vstore_half.
2091 auto Arg2 = CI->getOperand(2);
2092
David Neto22f144c2017-06-12 14:26:21 -04002093 auto IntTy = Type::getInt32Ty(M.getContext());
2094 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002095 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04002096 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002097
2098 // Our intrinsic to pack a float2 to an int.
2099 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2100
2101 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2102
2103 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04002104 auto TempVec =
2105 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
2106 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002107
2108 // Pack the float2 -> half2 (in an int).
2109 auto X = CallInst::Create(NewF, TempVec, "", CI);
2110
David Neto482550a2018-03-24 05:21:07 -07002111 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002112 auto ShortTy = Type::getInt16Ty(M.getContext());
2113 auto ShortPointerTy = PointerType::get(
2114 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002115
David Neto17852de2017-05-29 17:29:31 -04002116 // Truncate our i32 to an i16.
2117 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002118
David Neto17852de2017-05-29 17:29:31 -04002119 // Cast the half* pointer to short*.
2120 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002121
David Neto17852de2017-05-29 17:29:31 -04002122 // Index into the correct address of the casted pointer.
2123 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002124
David Neto17852de2017-05-29 17:29:31 -04002125 // Store to the int* we casted to.
2126 auto Store = new StoreInst(Trunc, Index, CI);
2127
2128 CI->replaceAllUsesWith(Store);
2129 } else {
2130 // We can only write to 32-bit aligned words.
2131 //
2132 // Assuming base is aligned to 32-bits, replace the equivalent of
2133 // vstore_half(value, index, base)
2134 // with:
2135 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2136 // uint32_t write_to_upper_half = index & 1u;
2137 // uint32_t shift = write_to_upper_half << 4;
2138 //
2139 // // Pack the float value as a half number in bottom 16 bits
2140 // // of an i32.
2141 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2142 //
2143 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2144 // ^ ((packed & 0xffff) << shift)
2145 // // We only need relaxed consistency, but OpenCL 1.2 only has
2146 // // sequentially consistent atomics.
2147 // // TODO(dneto): Use relaxed consistency.
2148 // atomic_xor(target_ptr, xor_value)
2149 auto IntPointerTy = PointerType::get(
2150 IntTy, Arg2->getType()->getPointerAddressSpace());
2151
2152 auto Four = ConstantInt::get(IntTy, 4);
2153 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2154
2155 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2156 // Compute index / 2
2157 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2158 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2159 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2160 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2161 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2162 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2163 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2164
2165 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2166 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2167 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2168
2169 // Generate the call to atomi_xor.
2170 SmallVector<Type *, 5> ParamTypes;
2171 // The pointer type.
2172 ParamTypes.push_back(IntPointerTy);
2173 // The Types for memory scope, semantics, and value.
2174 ParamTypes.push_back(IntTy);
2175 ParamTypes.push_back(IntTy);
2176 ParamTypes.push_back(IntTy);
2177 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2178 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2179
2180 const auto ConstantScopeDevice =
2181 ConstantInt::get(IntTy, spv::ScopeDevice);
2182 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2183 // (SPIR-V Workgroup).
2184 const auto AddrSpaceSemanticsBits =
2185 IntPointerTy->getPointerAddressSpace() == 1
2186 ? spv::MemorySemanticsUniformMemoryMask
2187 : spv::MemorySemanticsWorkgroupMemoryMask;
2188
2189 // We're using relaxed consistency here.
2190 const auto ConstantMemorySemantics =
2191 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2192 AddrSpaceSemanticsBits);
2193
2194 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2195 ConstantMemorySemantics, ValueToXor};
2196 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2197 }
David Neto22f144c2017-06-12 14:26:21 -04002198
2199 // Lastly, remember to remove the user.
2200 ToRemoves.push_back(CI);
2201 }
2202 }
2203
2204 Changed = !ToRemoves.empty();
2205
2206 // And cleanup the calls we don't use anymore.
2207 for (auto V : ToRemoves) {
2208 V->eraseFromParent();
2209 }
2210
2211 // And remove the function we don't need either too.
2212 F->eraseFromParent();
2213 }
2214 }
2215
2216 return Changed;
2217}
2218
2219bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2220 bool Changed = false;
2221
David Netoe2871522018-06-08 11:09:54 -07002222 const std::vector<const char *> Map = {
2223 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2224 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2225 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2226 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2227 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2228 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2229 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2230 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2231 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2232 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2233 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2234 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2235 };
David Neto22f144c2017-06-12 14:26:21 -04002236
2237 for (auto Name : Map) {
2238 // If we find a function with the matching name.
2239 if (auto F = M.getFunction(Name)) {
2240 SmallVector<Instruction *, 4> ToRemoves;
2241
2242 // Walk the users of the function.
2243 for (auto &U : F->uses()) {
2244 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2245 // The value to store.
2246 auto Arg0 = CI->getOperand(0);
2247
2248 // The index argument from vstore_half.
2249 auto Arg1 = CI->getOperand(1);
2250
2251 // The pointer argument from vstore_half.
2252 auto Arg2 = CI->getOperand(2);
2253
2254 auto IntTy = Type::getInt32Ty(M.getContext());
2255 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2256 auto NewPointerTy = PointerType::get(
2257 IntTy, Arg2->getType()->getPointerAddressSpace());
2258 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2259
2260 // Our intrinsic to pack a float2 to an int.
2261 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2262
2263 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2264
2265 // Turn the packed x & y into the final packing.
2266 auto X = CallInst::Create(NewF, Arg0, "", CI);
2267
2268 // Cast the half* pointer to int*.
2269 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2270
2271 // Index into the correct address of the casted pointer.
2272 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2273
2274 // Store to the int* we casted to.
2275 auto Store = new StoreInst(X, Index, CI);
2276
2277 CI->replaceAllUsesWith(Store);
2278
2279 // Lastly, remember to remove the user.
2280 ToRemoves.push_back(CI);
2281 }
2282 }
2283
2284 Changed = !ToRemoves.empty();
2285
2286 // And cleanup the calls we don't use anymore.
2287 for (auto V : ToRemoves) {
2288 V->eraseFromParent();
2289 }
2290
2291 // And remove the function we don't need either too.
2292 F->eraseFromParent();
2293 }
2294 }
2295
2296 return Changed;
2297}
2298
2299bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2300 bool Changed = false;
2301
David Netoe2871522018-06-08 11:09:54 -07002302 const std::vector<const char *> Map = {
2303 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2304 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2305 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2306 "_Z13vstorea_half4Dv4_fjPDh", // private
2307 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2308 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2309 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2310 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2311 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2312 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2313 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2314 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2315 };
David Neto22f144c2017-06-12 14:26:21 -04002316
2317 for (auto Name : Map) {
2318 // If we find a function with the matching name.
2319 if (auto F = M.getFunction(Name)) {
2320 SmallVector<Instruction *, 4> ToRemoves;
2321
2322 // Walk the users of the function.
2323 for (auto &U : F->uses()) {
2324 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2325 // The value to store.
2326 auto Arg0 = CI->getOperand(0);
2327
2328 // The index argument from vstore_half.
2329 auto Arg1 = CI->getOperand(1);
2330
2331 // The pointer argument from vstore_half.
2332 auto Arg2 = CI->getOperand(2);
2333
2334 auto IntTy = Type::getInt32Ty(M.getContext());
2335 auto Int2Ty = VectorType::get(IntTy, 2);
2336 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2337 auto NewPointerTy = PointerType::get(
2338 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2339 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2340
2341 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2342 ConstantInt::get(IntTy, 1)};
2343
2344 // Extract out the x & y components of our to store value.
2345 auto Lo =
2346 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2347 ConstantVector::get(LoShuffleMask), "", CI);
2348
2349 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2350 ConstantInt::get(IntTy, 3)};
2351
2352 // Extract out the z & w components of our to store value.
2353 auto Hi =
2354 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2355 ConstantVector::get(HiShuffleMask), "", CI);
2356
2357 // Our intrinsic to pack a float2 to an int.
2358 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2359
2360 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2361
2362 // Turn the packed x & y into the final component of our int2.
2363 auto X = CallInst::Create(NewF, Lo, "", CI);
2364
2365 // Turn the packed z & w into the final component of our int2.
2366 auto Y = CallInst::Create(NewF, Hi, "", CI);
2367
2368 auto Combine = InsertElementInst::Create(
2369 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2370 Combine = InsertElementInst::Create(
2371 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2372
2373 // Cast the half* pointer to int2*.
2374 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2375
2376 // Index into the correct address of the casted pointer.
2377 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2378
2379 // Store to the int2* we casted to.
2380 auto Store = new StoreInst(Combine, Index, CI);
2381
2382 CI->replaceAllUsesWith(Store);
2383
2384 // Lastly, remember to remove the user.
2385 ToRemoves.push_back(CI);
2386 }
2387 }
2388
2389 Changed = !ToRemoves.empty();
2390
2391 // And cleanup the calls we don't use anymore.
2392 for (auto V : ToRemoves) {
2393 V->eraseFromParent();
2394 }
2395
2396 // And remove the function we don't need either too.
2397 F->eraseFromParent();
2398 }
2399 }
2400
2401 return Changed;
2402}
2403
2404bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2405 bool Changed = false;
2406
2407 const std::map<const char *, const char*> Map = {
2408 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2409 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2410 };
2411
2412 for (auto Pair : Map) {
2413 // If we find a function with the matching name.
2414 if (auto F = M.getFunction(Pair.first)) {
2415 SmallVector<Instruction *, 4> ToRemoves;
2416
2417 // Walk the users of the function.
2418 for (auto &U : F->uses()) {
2419 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2420 // The image.
2421 auto Arg0 = CI->getOperand(0);
2422
2423 // The sampler.
2424 auto Arg1 = CI->getOperand(1);
2425
2426 // The coordinate (integer type that we can't handle).
2427 auto Arg2 = CI->getOperand(2);
2428
2429 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2430
2431 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2432
2433 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2434
2435 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2436
2437 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2438
2439 CI->replaceAllUsesWith(NewCI);
2440
2441 // Lastly, remember to remove the user.
2442 ToRemoves.push_back(CI);
2443 }
2444 }
2445
2446 Changed = !ToRemoves.empty();
2447
2448 // And cleanup the calls we don't use anymore.
2449 for (auto V : ToRemoves) {
2450 V->eraseFromParent();
2451 }
2452
2453 // And remove the function we don't need either too.
2454 F->eraseFromParent();
2455 }
2456 }
2457
2458 return Changed;
2459}
2460
2461bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2462 bool Changed = false;
2463
2464 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002465 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2466 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2467 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2468 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2469 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2470 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002471 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2472 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2473 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2474 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2475 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002476 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002477
2478 for (auto Pair : Map) {
2479 // If we find a function with the matching name.
2480 if (auto F = M.getFunction(Pair.first)) {
2481 SmallVector<Instruction *, 4> ToRemoves;
2482
2483 // Walk the users of the function.
2484 for (auto &U : F->uses()) {
2485 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2486 auto FType = F->getFunctionType();
2487 SmallVector<Type *, 5> ParamTypes;
2488
2489 // The pointer type.
2490 ParamTypes.push_back(FType->getParamType(0));
2491
2492 auto IntTy = Type::getInt32Ty(M.getContext());
2493
2494 // The memory scope type.
2495 ParamTypes.push_back(IntTy);
2496
2497 // The memory semantics type.
2498 ParamTypes.push_back(IntTy);
2499
2500 if (2 < CI->getNumArgOperands()) {
2501 // The unequal memory semantics type.
2502 ParamTypes.push_back(IntTy);
2503
2504 // The value type.
2505 ParamTypes.push_back(FType->getParamType(2));
2506
2507 // The comparator type.
2508 ParamTypes.push_back(FType->getParamType(1));
2509 } else if (1 < CI->getNumArgOperands()) {
2510 // The value type.
2511 ParamTypes.push_back(FType->getParamType(1));
2512 }
2513
2514 auto NewFType =
2515 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2516 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2517
2518 // We need to map the OpenCL constants to the SPIR-V equivalents.
2519 const auto ConstantScopeDevice =
2520 ConstantInt::get(IntTy, spv::ScopeDevice);
2521 const auto ConstantMemorySemantics = ConstantInt::get(
2522 IntTy, spv::MemorySemanticsUniformMemoryMask |
2523 spv::MemorySemanticsSequentiallyConsistentMask);
2524
2525 SmallVector<Value *, 5> Params;
2526
2527 // The pointer.
2528 Params.push_back(CI->getArgOperand(0));
2529
2530 // The memory scope.
2531 Params.push_back(ConstantScopeDevice);
2532
2533 // The memory semantics.
2534 Params.push_back(ConstantMemorySemantics);
2535
2536 if (2 < CI->getNumArgOperands()) {
2537 // The unequal memory semantics.
2538 Params.push_back(ConstantMemorySemantics);
2539
2540 // The value.
2541 Params.push_back(CI->getArgOperand(2));
2542
2543 // The comparator.
2544 Params.push_back(CI->getArgOperand(1));
2545 } else if (1 < CI->getNumArgOperands()) {
2546 // The value.
2547 Params.push_back(CI->getArgOperand(1));
2548 }
2549
2550 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2551
2552 CI->replaceAllUsesWith(NewCI);
2553
2554 // Lastly, remember to remove the user.
2555 ToRemoves.push_back(CI);
2556 }
2557 }
2558
2559 Changed = !ToRemoves.empty();
2560
2561 // And cleanup the calls we don't use anymore.
2562 for (auto V : ToRemoves) {
2563 V->eraseFromParent();
2564 }
2565
2566 // And remove the function we don't need either too.
2567 F->eraseFromParent();
2568 }
2569 }
2570
Neil Henning39672102017-09-29 14:33:13 +01002571 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002572 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2573 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2574 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2575 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2576 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2577 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2578 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2579 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2580 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2581 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2582 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2583 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2584 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2585 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2586 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2587 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002588 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2589 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2590 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2591 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2592 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2593 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2594 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2595 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2596 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2597 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2598 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2599 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2600 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2601 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2602 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2603 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2604
2605 for (auto Pair : Map2) {
2606 // If we find a function with the matching name.
2607 if (auto F = M.getFunction(Pair.first)) {
2608 SmallVector<Instruction *, 4> ToRemoves;
2609
2610 // Walk the users of the function.
2611 for (auto &U : F->uses()) {
2612 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2613 auto AtomicOp = new AtomicRMWInst(
2614 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2615 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2616
2617 CI->replaceAllUsesWith(AtomicOp);
2618
2619 // Lastly, remember to remove the user.
2620 ToRemoves.push_back(CI);
2621 }
2622 }
2623
2624 Changed = !ToRemoves.empty();
2625
2626 // And cleanup the calls we don't use anymore.
2627 for (auto V : ToRemoves) {
2628 V->eraseFromParent();
2629 }
2630
2631 // And remove the function we don't need either too.
2632 F->eraseFromParent();
2633 }
2634 }
2635
David Neto22f144c2017-06-12 14:26:21 -04002636 return Changed;
2637}
2638
2639bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2640 bool Changed = false;
2641
2642 // If we find a function with the matching name.
2643 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2644 SmallVector<Instruction *, 4> ToRemoves;
2645
2646 auto IntTy = Type::getInt32Ty(M.getContext());
2647 auto FloatTy = Type::getFloatTy(M.getContext());
2648
2649 Constant *DownShuffleMask[3] = {
2650 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2651 ConstantInt::get(IntTy, 2)};
2652
2653 Constant *UpShuffleMask[4] = {
2654 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2655 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2656
2657 Constant *FloatVec[3] = {
2658 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
2659 };
2660
2661 // Walk the users of the function.
2662 for (auto &U : F->uses()) {
2663 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2664 auto Vec4Ty = CI->getArgOperand(0)->getType();
2665 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2666 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2667 auto Vec3Ty = Arg0->getType();
2668
2669 auto NewFType =
2670 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
2671
2672 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
2673
2674 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
2675
2676 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
2677
2678 CI->replaceAllUsesWith(Result);
2679
2680 // Lastly, remember to remove the user.
2681 ToRemoves.push_back(CI);
2682 }
2683 }
2684
2685 Changed = !ToRemoves.empty();
2686
2687 // And cleanup the calls we don't use anymore.
2688 for (auto V : ToRemoves) {
2689 V->eraseFromParent();
2690 }
2691
2692 // And remove the function we don't need either too.
2693 F->eraseFromParent();
2694 }
2695
2696 return Changed;
2697}
David Neto62653202017-10-16 19:05:18 -04002698
2699bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2700 bool Changed = false;
2701
2702 // OpenCL's float result = fract(float x, float* ptr)
2703 //
2704 // In the LLVM domain:
2705 //
2706 // %floor_result = call spir_func float @floor(float %x)
2707 // store float %floor_result, float * %ptr
2708 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2709 // %result = call spir_func float
2710 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2711 //
2712 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2713 // and clspv.fract occur in the SPIR-V generator pass:
2714 //
2715 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2716 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2717 // ...
2718 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2719 // OpStore %ptr %floor_result
2720 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2721 // %fract_result = OpExtInst %float
2722 // %glsl_ext Fmin %fract_intermediate %just_under_1
2723
2724
2725 using std::string;
2726
2727 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2728 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
2729 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
2730 auto make_quad = [](const char *a, const char *b, const char *c,
2731 const char *d) {
2732 return std::tuple<const char *, const char *, const char *, const char *>(
2733 a, b, c, d);
2734 };
2735 const std::vector<QuadType> Functions = {
2736 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
2737 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
2738 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
2739 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
2740 };
2741
2742 for (auto& quad : Functions) {
2743 const StringRef fract_name(std::get<0>(quad));
2744
2745 // If we find a function with the matching name.
2746 if (auto F = M.getFunction(fract_name)) {
2747 if (F->use_begin() == F->use_end())
2748 continue;
2749
2750 // We have some uses.
2751 Changed = true;
2752
2753 auto& Context = M.getContext();
2754
2755 const StringRef floor_name(std::get<1>(quad));
2756 const StringRef fmin_name(std::get<2>(quad));
2757 const StringRef clspv_fract_name(std::get<3>(quad));
2758
2759 // This is either float or a float vector. All the float-like
2760 // types are this type.
2761 auto result_ty = F->getReturnType();
2762
2763 Function* fmin_fn = M.getFunction(fmin_name);
2764 if (!fmin_fn) {
2765 // Make the fmin function.
2766 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
2767 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002768 fmin_fn->addFnAttr(Attribute::ReadNone);
2769 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2770 }
2771
2772 Function* floor_fn = M.getFunction(floor_name);
2773 if (!floor_fn) {
2774 // Make the floor function.
2775 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2776 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002777 floor_fn->addFnAttr(Attribute::ReadNone);
2778 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2779 }
2780
2781 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
2782 if (!clspv_fract_fn) {
2783 // Make the clspv_fract function.
2784 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2785 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002786 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2787 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2788 }
2789
2790 // Number of significant significand bits, whether represented or not.
2791 unsigned num_significand_bits;
2792 switch (result_ty->getScalarType()->getTypeID()) {
2793 case Type::HalfTyID:
2794 num_significand_bits = 11;
2795 break;
2796 case Type::FloatTyID:
2797 num_significand_bits = 24;
2798 break;
2799 case Type::DoubleTyID:
2800 num_significand_bits = 53;
2801 break;
2802 default:
2803 assert(false && "Unhandled float type when processing fract builtin");
2804 break;
2805 }
2806 // Beware that the disassembler displays this value as
2807 // OpConstant %float 1
2808 // which is not quite right.
2809 const double kJustUnderOneScalar =
2810 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2811
2812 Constant *just_under_one =
2813 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2814 if (result_ty->isVectorTy()) {
2815 just_under_one = ConstantVector::getSplat(
2816 result_ty->getVectorNumElements(), just_under_one);
2817 }
2818
2819 IRBuilder<> Builder(Context);
2820
2821 SmallVector<Instruction *, 4> ToRemoves;
2822
2823 // Walk the users of the function.
2824 for (auto &U : F->uses()) {
2825 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2826
2827 Builder.SetInsertPoint(CI);
2828 auto arg = CI->getArgOperand(0);
2829 auto ptr = CI->getArgOperand(1);
2830
2831 // Compute floor result and store it.
2832 auto floor = Builder.CreateCall(floor_fn, {arg});
2833 Builder.CreateStore(floor, ptr);
2834
2835 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2836 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2837
2838 CI->replaceAllUsesWith(fract_result);
2839
2840 // Lastly, remember to remove the user.
2841 ToRemoves.push_back(CI);
2842 }
2843 }
2844
2845 // And cleanup the calls we don't use anymore.
2846 for (auto V : ToRemoves) {
2847 V->eraseFromParent();
2848 }
2849
2850 // And remove the function we don't need either too.
2851 F->eraseFromParent();
2852 }
2853 }
2854
2855 return Changed;
2856}