blob: 5b3e0120735ed605df16011c51d2fba8ff498185 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000024#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040025#include "llvm/Pass.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040029
David Neto118188e2018-08-24 11:27:54 -040030#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040031
David Neto482550a2018-03-24 05:21:07 -070032#include "clspv/Option.h"
33
David Neto22f144c2017-06-12 14:26:21 -040034using namespace llvm;
35
36#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
37
38namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000039
40struct ArgTypeInfo {
41 enum class SignedNess {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000042 None,
Kévin Petit8a560882019-03-21 15:24:34 +000043 Unsigned,
44 Signed
45 };
46 SignedNess signedness;
47};
48
49struct FunctionInfo {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000050 StringRef name;
Kévin Petit8a560882019-03-21 15:24:34 +000051 std::vector<ArgTypeInfo> argTypeInfos;
52};
53
54bool getFunctionInfoFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
55 if (!name.consume_front("_Z")) {
56 return false;
57 }
58 size_t nameLen;
59 if (name.consumeInteger(10, nameLen)) {
60 return false;
61 }
62
Kévin Petit9d1a9d12019-03-25 15:23:46 +000063 finfo->name = name.take_front(nameLen);
Kévin Petit8a560882019-03-21 15:24:34 +000064 name = name.drop_front(nameLen);
65
66 ArgTypeInfo prev_ti;
67
68 while (name.size() != 0) {
69
70 ArgTypeInfo ti;
71
72 // Try parsing a vector prefix
73 if (name.consume_front("Dv")) {
74 int numElems;
75 if (name.consumeInteger(10, numElems)) {
76 return false;
77 }
78
79 if (!name.consume_front("_")) {
80 return false;
81 }
82 }
83
84 // Parse the base type
85 char typeCode = name.front();
86 name = name.drop_front(1);
87 switch(typeCode) {
88 case 'c': // char
89 case 'a': // signed char
90 case 's': // short
91 case 'i': // int
92 case 'l': // long
93 ti.signedness = ArgTypeInfo::SignedNess::Signed;
94 break;
95 case 'h': // unsigned char
96 case 't': // unsigned short
97 case 'j': // unsigned int
98 case 'm': // unsigned long
99 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
100 break;
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000101 case 'f':
102 ti.signedness = ArgTypeInfo::SignedNess::None;
103 break;
Kévin Petit8a560882019-03-21 15:24:34 +0000104 case 'S':
105 ti = prev_ti;
106 if (!name.consume_front("_")) {
107 return false;
108 }
109 break;
110 default:
111 return false;
112 }
113
114 finfo->argTypeInfos.push_back(ti);
115
116 prev_ti = ti;
117 }
118
119 return true;
120};
121
122void getFunctionInfoFromMangledName(StringRef name, FunctionInfo *finfo) {
123 if (!getFunctionInfoFromMangledNameCheck(name, finfo)) {
124 llvm_unreachable("Can't parse mangled function name!");
125 }
126}
127
David Neto22f144c2017-06-12 14:26:21 -0400128uint32_t clz(uint32_t v) {
129 uint32_t r;
130 uint32_t shift;
131
132 r = (v > 0xFFFF) << 4;
133 v >>= r;
134 shift = (v > 0xFF) << 3;
135 v >>= shift;
136 r |= shift;
137 shift = (v > 0xF) << 2;
138 v >>= shift;
139 r |= shift;
140 shift = (v > 0x3) << 1;
141 v >>= shift;
142 r |= shift;
143 r |= (v >> 1);
144
145 return r;
146}
147
148Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
149 if (1 == elements) {
150 return Type::getInt1Ty(C);
151 } else {
152 return VectorType::get(Type::getInt1Ty(C), elements);
153 }
154}
155
156struct ReplaceOpenCLBuiltinPass final : public ModulePass {
157 static char ID;
158 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
159
160 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000161 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400162 bool replaceRecip(Module &M);
163 bool replaceDivide(Module &M);
164 bool replaceExp10(Module &M);
165 bool replaceLog10(Module &M);
166 bool replaceBarrier(Module &M);
167 bool replaceMemFence(Module &M);
168 bool replaceRelational(Module &M);
169 bool replaceIsInfAndIsNan(Module &M);
170 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000171 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000172 bool replaceRotate(Module &M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000173 bool replaceConvert(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000174 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000175 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000176 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000177 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400178 bool replaceSignbit(Module &M);
179 bool replaceMadandMad24andMul24(Module &M);
180 bool replaceVloadHalf(Module &M);
181 bool replaceVloadHalf2(Module &M);
182 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700183 bool replaceClspvVloadaHalf2(Module &M);
184 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400185 bool replaceVstoreHalf(Module &M);
186 bool replaceVstoreHalf2(Module &M);
187 bool replaceVstoreHalf4(Module &M);
188 bool replaceReadImageF(Module &M);
189 bool replaceAtomics(Module &M);
190 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400191 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700192 bool replaceVload(Module &M);
193 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400194};
195}
196
197char ReplaceOpenCLBuiltinPass::ID = 0;
198static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
199 "Replace OpenCL Builtins Pass");
200
201namespace clspv {
202ModulePass *createReplaceOpenCLBuiltinPass() {
203 return new ReplaceOpenCLBuiltinPass();
204}
205}
206
207bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
208 bool Changed = false;
209
Kévin Petit2444e9b2018-11-09 14:14:37 +0000210 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400211 Changed |= replaceRecip(M);
212 Changed |= replaceDivide(M);
213 Changed |= replaceExp10(M);
214 Changed |= replaceLog10(M);
215 Changed |= replaceBarrier(M);
216 Changed |= replaceMemFence(M);
217 Changed |= replaceRelational(M);
218 Changed |= replaceIsInfAndIsNan(M);
219 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000220 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000221 Changed |= replaceRotate(M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000222 Changed |= replaceConvert(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000223 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000224 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000225 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000226 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400227 Changed |= replaceSignbit(M);
228 Changed |= replaceMadandMad24andMul24(M);
229 Changed |= replaceVloadHalf(M);
230 Changed |= replaceVloadHalf2(M);
231 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700232 Changed |= replaceClspvVloadaHalf2(M);
233 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400234 Changed |= replaceVstoreHalf(M);
235 Changed |= replaceVstoreHalf2(M);
236 Changed |= replaceVstoreHalf4(M);
237 Changed |= replaceReadImageF(M);
238 Changed |= replaceAtomics(M);
239 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400240 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700241 Changed |= replaceVload(M);
242 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400243
244 return Changed;
245}
246
Kévin Petit2444e9b2018-11-09 14:14:37 +0000247bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
248 bool Changed = false;
249
250 const char *Names[] = {
Kévin Petit5ace14c2019-04-01 16:29:53 +0100251 "_Z3absh",
252 "_Z3absDv2_h",
253 "_Z3absDv3_h",
254 "_Z3absDv4_h",
Kévin Petit2444e9b2018-11-09 14:14:37 +0000255 "_Z3abst",
256 "_Z3absDv2_t",
257 "_Z3absDv3_t",
258 "_Z3absDv4_t",
259 "_Z3absj",
260 "_Z3absDv2_j",
261 "_Z3absDv3_j",
262 "_Z3absDv4_j",
263 "_Z3absm",
264 "_Z3absDv2_m",
265 "_Z3absDv3_m",
266 "_Z3absDv4_m",
267 };
268
269 for (auto Name : Names) {
270 // If we find a function with the matching name.
271 if (auto F = M.getFunction(Name)) {
272 SmallVector<Instruction *, 4> ToRemoves;
273
274 // Walk the users of the function.
275 for (auto &U : F->uses()) {
276 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
277 // Abs has one arg.
278 auto Arg = CI->getOperand(0);
279
280 // Use the argument unchanged, we know it's unsigned
281 CI->replaceAllUsesWith(Arg);
282
283 // Lastly, remember to remove the user.
284 ToRemoves.push_back(CI);
285 }
286 }
287
288 Changed = !ToRemoves.empty();
289
290 // And cleanup the calls we don't use anymore.
291 for (auto V : ToRemoves) {
292 V->eraseFromParent();
293 }
294
295 // And remove the function we don't need either too.
296 F->eraseFromParent();
297 }
298 }
299
300 return Changed;
301}
302
David Neto22f144c2017-06-12 14:26:21 -0400303bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
304 bool Changed = false;
305
306 const char *Names[] = {
307 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
308 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
309 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
310 };
311
312 for (auto Name : Names) {
313 // If we find a function with the matching name.
314 if (auto F = M.getFunction(Name)) {
315 SmallVector<Instruction *, 4> ToRemoves;
316
317 // Walk the users of the function.
318 for (auto &U : F->uses()) {
319 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
320 // Recip has one arg.
321 auto Arg = CI->getOperand(0);
322
323 auto Div = BinaryOperator::Create(
324 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
325 CI);
326
327 CI->replaceAllUsesWith(Div);
328
329 // Lastly, remember to remove the user.
330 ToRemoves.push_back(CI);
331 }
332 }
333
334 Changed = !ToRemoves.empty();
335
336 // And cleanup the calls we don't use anymore.
337 for (auto V : ToRemoves) {
338 V->eraseFromParent();
339 }
340
341 // And remove the function we don't need either too.
342 F->eraseFromParent();
343 }
344 }
345
346 return Changed;
347}
348
349bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
350 bool Changed = false;
351
352 const char *Names[] = {
353 "_Z11half_divideff", "_Z13native_divideff",
354 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
355 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
356 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
357 };
358
359 for (auto Name : Names) {
360 // If we find a function with the matching name.
361 if (auto F = M.getFunction(Name)) {
362 SmallVector<Instruction *, 4> ToRemoves;
363
364 // Walk the users of the function.
365 for (auto &U : F->uses()) {
366 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
367 auto Div = BinaryOperator::Create(
368 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
369
370 CI->replaceAllUsesWith(Div);
371
372 // Lastly, remember to remove the user.
373 ToRemoves.push_back(CI);
374 }
375 }
376
377 Changed = !ToRemoves.empty();
378
379 // And cleanup the calls we don't use anymore.
380 for (auto V : ToRemoves) {
381 V->eraseFromParent();
382 }
383
384 // And remove the function we don't need either too.
385 F->eraseFromParent();
386 }
387 }
388
389 return Changed;
390}
391
392bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
393 bool Changed = false;
394
395 const std::map<const char *, const char *> Map = {
396 {"_Z5exp10f", "_Z3expf"},
397 {"_Z10half_exp10f", "_Z8half_expf"},
398 {"_Z12native_exp10f", "_Z10native_expf"},
399 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
400 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
401 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
402 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
403 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
404 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
405 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
406 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
407 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
408
409 for (auto Pair : Map) {
410 // If we find a function with the matching name.
411 if (auto F = M.getFunction(Pair.first)) {
412 SmallVector<Instruction *, 4> ToRemoves;
413
414 // Walk the users of the function.
415 for (auto &U : F->uses()) {
416 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
417 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
418
419 auto Arg = CI->getOperand(0);
420
421 // Constant of the natural log of 10 (ln(10)).
422 const double Ln10 =
423 2.302585092994045684017991454684364207601101488628772976033;
424
425 auto Mul = BinaryOperator::Create(
426 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
427 CI);
428
429 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
430
431 CI->replaceAllUsesWith(NewCI);
432
433 // Lastly, remember to remove the user.
434 ToRemoves.push_back(CI);
435 }
436 }
437
438 Changed = !ToRemoves.empty();
439
440 // And cleanup the calls we don't use anymore.
441 for (auto V : ToRemoves) {
442 V->eraseFromParent();
443 }
444
445 // And remove the function we don't need either too.
446 F->eraseFromParent();
447 }
448 }
449
450 return Changed;
451}
452
453bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
454 bool Changed = false;
455
456 const std::map<const char *, const char *> Map = {
457 {"_Z5log10f", "_Z3logf"},
458 {"_Z10half_log10f", "_Z8half_logf"},
459 {"_Z12native_log10f", "_Z10native_logf"},
460 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
461 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
462 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
463 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
464 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
465 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
466 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
467 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
468 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
469
470 for (auto Pair : Map) {
471 // If we find a function with the matching name.
472 if (auto F = M.getFunction(Pair.first)) {
473 SmallVector<Instruction *, 4> ToRemoves;
474
475 // Walk the users of the function.
476 for (auto &U : F->uses()) {
477 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
478 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
479
480 auto Arg = CI->getOperand(0);
481
482 // Constant of the reciprocal of the natural log of 10 (ln(10)).
483 const double Ln10 =
484 0.434294481903251827651128918916605082294397005803666566114;
485
486 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
487
488 auto Mul = BinaryOperator::Create(
489 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
490 "", CI);
491
492 CI->replaceAllUsesWith(Mul);
493
494 // Lastly, remember to remove the user.
495 ToRemoves.push_back(CI);
496 }
497 }
498
499 Changed = !ToRemoves.empty();
500
501 // And cleanup the calls we don't use anymore.
502 for (auto V : ToRemoves) {
503 V->eraseFromParent();
504 }
505
506 // And remove the function we don't need either too.
507 F->eraseFromParent();
508 }
509 }
510
511 return Changed;
512}
513
514bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
515 bool Changed = false;
516
517 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
518
519 const std::map<const char *, const char *> Map = {
520 {"_Z7barrierj", "__spirv_control_barrier"}};
521
522 for (auto Pair : Map) {
523 // If we find a function with the matching name.
524 if (auto F = M.getFunction(Pair.first)) {
525 SmallVector<Instruction *, 4> ToRemoves;
526
527 // Walk the users of the function.
528 for (auto &U : F->uses()) {
529 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
530 auto FType = F->getFunctionType();
531 SmallVector<Type *, 3> Params;
532 for (unsigned i = 0; i < 3; i++) {
533 Params.push_back(FType->getParamType(0));
534 }
535 auto NewFType =
536 FunctionType::get(FType->getReturnType(), Params, false);
537 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
538
539 auto Arg = CI->getOperand(0);
540
541 // We need to map the OpenCL constants to the SPIR-V equivalents.
542 const auto LocalMemFence =
543 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
544 const auto GlobalMemFence =
545 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
546 const auto ConstantSequentiallyConsistent = ConstantInt::get(
547 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
548 const auto ConstantScopeDevice =
549 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
550 const auto ConstantScopeWorkgroup =
551 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
552
553 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
554 const auto LocalMemFenceMask = BinaryOperator::Create(
555 Instruction::And, LocalMemFence, Arg, "", CI);
556 const auto WorkgroupShiftAmount =
557 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
558 clz(CLK_LOCAL_MEM_FENCE);
559 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
560 Instruction::Shl, LocalMemFenceMask,
561 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
562
563 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
564 const auto GlobalMemFenceMask = BinaryOperator::Create(
565 Instruction::And, GlobalMemFence, Arg, "", CI);
566 const auto UniformShiftAmount =
567 clz(spv::MemorySemanticsUniformMemoryMask) -
568 clz(CLK_GLOBAL_MEM_FENCE);
569 const auto MemorySemanticsUniform = BinaryOperator::Create(
570 Instruction::Shl, GlobalMemFenceMask,
571 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
572
573 // And combine the above together, also adding in
574 // MemorySemanticsSequentiallyConsistentMask.
575 auto MemorySemantics =
576 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
577 ConstantSequentiallyConsistent, "", CI);
578 MemorySemantics = BinaryOperator::Create(
579 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
580
581 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
582 // Device Scope, otherwise Workgroup Scope.
583 const auto Cmp =
584 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
585 GlobalMemFenceMask, GlobalMemFence, "", CI);
586 const auto MemoryScope = SelectInst::Create(
587 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
588
589 // Lastly, the Execution Scope is always Workgroup Scope.
590 const auto ExecutionScope = ConstantScopeWorkgroup;
591
592 auto NewCI = CallInst::Create(
593 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
594
595 CI->replaceAllUsesWith(NewCI);
596
597 // Lastly, remember to remove the user.
598 ToRemoves.push_back(CI);
599 }
600 }
601
602 Changed = !ToRemoves.empty();
603
604 // And cleanup the calls we don't use anymore.
605 for (auto V : ToRemoves) {
606 V->eraseFromParent();
607 }
608
609 // And remove the function we don't need either too.
610 F->eraseFromParent();
611 }
612 }
613
614 return Changed;
615}
616
617bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
618 bool Changed = false;
619
620 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
621
Neil Henning39672102017-09-29 14:33:13 +0100622 using Tuple = std::tuple<const char *, unsigned>;
623 const std::map<const char *, Tuple> Map = {
624 {"_Z9mem_fencej",
625 Tuple("__spirv_memory_barrier",
626 spv::MemorySemanticsSequentiallyConsistentMask)},
627 {"_Z14read_mem_fencej",
628 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
629 {"_Z15write_mem_fencej",
630 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400631
632 for (auto Pair : Map) {
633 // If we find a function with the matching name.
634 if (auto F = M.getFunction(Pair.first)) {
635 SmallVector<Instruction *, 4> ToRemoves;
636
637 // Walk the users of the function.
638 for (auto &U : F->uses()) {
639 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
640 auto FType = F->getFunctionType();
641 SmallVector<Type *, 2> Params;
642 for (unsigned i = 0; i < 2; i++) {
643 Params.push_back(FType->getParamType(0));
644 }
645 auto NewFType =
646 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100647 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400648
649 auto Arg = CI->getOperand(0);
650
651 // We need to map the OpenCL constants to the SPIR-V equivalents.
652 const auto LocalMemFence =
653 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
654 const auto GlobalMemFence =
655 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
656 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100657 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400658 const auto ConstantScopeDevice =
659 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
660
661 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
662 const auto LocalMemFenceMask = BinaryOperator::Create(
663 Instruction::And, LocalMemFence, Arg, "", CI);
664 const auto WorkgroupShiftAmount =
665 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
666 clz(CLK_LOCAL_MEM_FENCE);
667 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
668 Instruction::Shl, LocalMemFenceMask,
669 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
670
671 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
672 const auto GlobalMemFenceMask = BinaryOperator::Create(
673 Instruction::And, GlobalMemFence, Arg, "", CI);
674 const auto UniformShiftAmount =
675 clz(spv::MemorySemanticsUniformMemoryMask) -
676 clz(CLK_GLOBAL_MEM_FENCE);
677 const auto MemorySemanticsUniform = BinaryOperator::Create(
678 Instruction::Shl, GlobalMemFenceMask,
679 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
680
681 // And combine the above together, also adding in
682 // MemorySemanticsSequentiallyConsistentMask.
683 auto MemorySemantics =
684 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
685 ConstantMemorySemantics, "", CI);
686 MemorySemantics = BinaryOperator::Create(
687 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
688
689 // Memory Scope is always device.
690 const auto MemoryScope = ConstantScopeDevice;
691
692 auto NewCI =
693 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
694
695 CI->replaceAllUsesWith(NewCI);
696
697 // Lastly, remember to remove the user.
698 ToRemoves.push_back(CI);
699 }
700 }
701
702 Changed = !ToRemoves.empty();
703
704 // And cleanup the calls we don't use anymore.
705 for (auto V : ToRemoves) {
706 V->eraseFromParent();
707 }
708
709 // And remove the function we don't need either too.
710 F->eraseFromParent();
711 }
712 }
713
714 return Changed;
715}
716
717bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
718 bool Changed = false;
719
720 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
721 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
722 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
723 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
724 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
725 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
726 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
727 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
728 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
729 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
730 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
731 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
732 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
733 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
734 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
735 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
736 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
737 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
738 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
739 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
740 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
741 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
742 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
743 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
744 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
745 };
746
747 for (auto Pair : Map) {
748 // If we find a function with the matching name.
749 if (auto F = M.getFunction(Pair.first)) {
750 SmallVector<Instruction *, 4> ToRemoves;
751
752 // Walk the users of the function.
753 for (auto &U : F->uses()) {
754 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
755 // The predicate to use in the CmpInst.
756 auto Predicate = Pair.second.first;
757
758 // The value to return for true.
759 auto TrueValue =
760 ConstantInt::getSigned(CI->getType(), Pair.second.second);
761
762 // The value to return for false.
763 auto FalseValue = Constant::getNullValue(CI->getType());
764
765 auto Arg1 = CI->getOperand(0);
766 auto Arg2 = CI->getOperand(1);
767
768 const auto Cmp =
769 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
770
771 const auto Select =
772 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
773
774 CI->replaceAllUsesWith(Select);
775
776 // Lastly, remember to remove the user.
777 ToRemoves.push_back(CI);
778 }
779 }
780
781 Changed = !ToRemoves.empty();
782
783 // And cleanup the calls we don't use anymore.
784 for (auto V : ToRemoves) {
785 V->eraseFromParent();
786 }
787
788 // And remove the function we don't need either too.
789 F->eraseFromParent();
790 }
791 }
792
793 return Changed;
794}
795
796bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
797 bool Changed = false;
798
799 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
800 {"_Z5isinff", {"__spirv_isinff", 1}},
801 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
802 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
803 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
804 {"_Z5isnanf", {"__spirv_isnanf", 1}},
805 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
806 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
807 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
808 };
809
810 for (auto Pair : Map) {
811 // If we find a function with the matching name.
812 if (auto F = M.getFunction(Pair.first)) {
813 SmallVector<Instruction *, 4> ToRemoves;
814
815 // Walk the users of the function.
816 for (auto &U : F->uses()) {
817 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
818 const auto CITy = CI->getType();
819
820 // The fake SPIR-V intrinsic to generate.
821 auto SPIRVIntrinsic = Pair.second.first;
822
823 // The value to return for true.
824 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
825
826 // The value to return for false.
827 auto FalseValue = Constant::getNullValue(CITy);
828
829 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
830 M.getContext(),
831 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
832
833 auto NewFType =
834 FunctionType::get(CorrespondingBoolTy,
835 F->getFunctionType()->getParamType(0), false);
836
837 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
838
839 auto Arg = CI->getOperand(0);
840
841 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
842
843 const auto Select =
844 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
845
846 CI->replaceAllUsesWith(Select);
847
848 // Lastly, remember to remove the user.
849 ToRemoves.push_back(CI);
850 }
851 }
852
853 Changed = !ToRemoves.empty();
854
855 // And cleanup the calls we don't use anymore.
856 for (auto V : ToRemoves) {
857 V->eraseFromParent();
858 }
859
860 // And remove the function we don't need either too.
861 F->eraseFromParent();
862 }
863 }
864
865 return Changed;
866}
867
868bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
869 bool Changed = false;
870
871 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000872 // all
alan-bakerb39c8262019-03-08 14:03:37 -0500873 {"_Z3allc", ""},
874 {"_Z3allDv2_c", "__spirv_allDv2_c"},
875 {"_Z3allDv3_c", "__spirv_allDv3_c"},
876 {"_Z3allDv4_c", "__spirv_allDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000877 {"_Z3alls", ""},
878 {"_Z3allDv2_s", "__spirv_allDv2_s"},
879 {"_Z3allDv3_s", "__spirv_allDv3_s"},
880 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400881 {"_Z3alli", ""},
882 {"_Z3allDv2_i", "__spirv_allDv2_i"},
883 {"_Z3allDv3_i", "__spirv_allDv3_i"},
884 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000885 {"_Z3alll", ""},
886 {"_Z3allDv2_l", "__spirv_allDv2_l"},
887 {"_Z3allDv3_l", "__spirv_allDv3_l"},
888 {"_Z3allDv4_l", "__spirv_allDv4_l"},
889
890 // any
alan-bakerb39c8262019-03-08 14:03:37 -0500891 {"_Z3anyc", ""},
892 {"_Z3anyDv2_c", "__spirv_anyDv2_c"},
893 {"_Z3anyDv3_c", "__spirv_anyDv3_c"},
894 {"_Z3anyDv4_c", "__spirv_anyDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000895 {"_Z3anys", ""},
896 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
897 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
898 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400899 {"_Z3anyi", ""},
900 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
901 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
902 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000903 {"_Z3anyl", ""},
904 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
905 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
906 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400907 };
908
909 for (auto Pair : Map) {
910 // If we find a function with the matching name.
911 if (auto F = M.getFunction(Pair.first)) {
912 SmallVector<Instruction *, 4> ToRemoves;
913
914 // Walk the users of the function.
915 for (auto &U : F->uses()) {
916 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
917 // The fake SPIR-V intrinsic to generate.
918 auto SPIRVIntrinsic = Pair.second;
919
920 auto Arg = CI->getOperand(0);
921
922 Value *V;
923
Kévin Petitfd27cca2018-10-31 13:00:17 +0000924 // If the argument is a 32-bit int, just use a shift
925 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
926 V = BinaryOperator::Create(Instruction::LShr, Arg,
927 ConstantInt::get(Arg->getType(), 31), "",
928 CI);
929 } else {
David Neto22f144c2017-06-12 14:26:21 -0400930 // The value for zero to compare against.
931 const auto ZeroValue = Constant::getNullValue(Arg->getType());
932
David Neto22f144c2017-06-12 14:26:21 -0400933 // The value to return for true.
934 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
935
936 // The value to return for false.
937 const auto FalseValue = Constant::getNullValue(CI->getType());
938
Kévin Petitfd27cca2018-10-31 13:00:17 +0000939 const auto Cmp = CmpInst::Create(
940 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
941
942 Value* SelectSource;
943
944 // If we have a function to call, call it!
945 if (0 < strlen(SPIRVIntrinsic)) {
946
947 const auto NewFType = FunctionType::get(
948 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
949
950 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
951
952 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
953
954 SelectSource = NewCI;
955
956 } else {
957 SelectSource = Cmp;
958 }
959
960 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400961 }
962
963 CI->replaceAllUsesWith(V);
964
965 // Lastly, remember to remove the user.
966 ToRemoves.push_back(CI);
967 }
968 }
969
970 Changed = !ToRemoves.empty();
971
972 // And cleanup the calls we don't use anymore.
973 for (auto V : ToRemoves) {
974 V->eraseFromParent();
975 }
976
977 // And remove the function we don't need either too.
978 F->eraseFromParent();
979 }
980 }
981
982 return Changed;
983}
984
Kévin Petitbf0036c2019-03-06 13:57:10 +0000985bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
986 bool Changed = false;
987
988 for (auto const &SymVal : M.getValueSymbolTable()) {
989 // Skip symbols whose name doesn't match
990 if (!SymVal.getKey().startswith("_Z8upsample")) {
991 continue;
992 }
993 // Is there a function going by that name?
994 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
995
996 SmallVector<Instruction *, 4> ToRemoves;
997
998 // Walk the users of the function.
999 for (auto &U : F->uses()) {
1000 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1001
1002 // Get arguments
1003 auto HiValue = CI->getOperand(0);
1004 auto LoValue = CI->getOperand(1);
1005
1006 // Don't touch overloads that aren't in OpenCL C
1007 auto HiType = HiValue->getType();
1008 auto LoType = LoValue->getType();
1009
1010 if (HiType != LoType) {
1011 continue;
1012 }
1013
1014 if (!HiType->isIntOrIntVectorTy()) {
1015 continue;
1016 }
1017
1018 if (HiType->getScalarSizeInBits() * 2 !=
1019 CI->getType()->getScalarSizeInBits()) {
1020 continue;
1021 }
1022
1023 if ((HiType->getScalarSizeInBits() != 8) &&
1024 (HiType->getScalarSizeInBits() != 16) &&
1025 (HiType->getScalarSizeInBits() != 32)) {
1026 continue;
1027 }
1028
1029 if (HiType->isVectorTy()) {
1030 if ((HiType->getVectorNumElements() != 2) &&
1031 (HiType->getVectorNumElements() != 3) &&
1032 (HiType->getVectorNumElements() != 4) &&
1033 (HiType->getVectorNumElements() != 8) &&
1034 (HiType->getVectorNumElements() != 16)) {
1035 continue;
1036 }
1037 }
1038
1039 // Convert both operands to the result type
1040 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
1041 "", CI);
1042 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
1043 "", CI);
1044
1045 // Shift high operand
1046 auto ShiftAmount = ConstantInt::get(CI->getType(),
1047 HiType->getScalarSizeInBits());
1048 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1049 ShiftAmount, "", CI);
1050
1051 // OR both results
1052 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1053 "", CI);
1054
1055 // Replace call with the expression
1056 CI->replaceAllUsesWith(V);
1057
1058 // Lastly, remember to remove the user.
1059 ToRemoves.push_back(CI);
1060 }
1061 }
1062
1063 Changed = !ToRemoves.empty();
1064
1065 // And cleanup the calls we don't use anymore.
1066 for (auto V : ToRemoves) {
1067 V->eraseFromParent();
1068 }
1069
1070 // And remove the function we don't need either too.
1071 F->eraseFromParent();
1072 }
1073 }
1074
1075 return Changed;
1076}
1077
Kévin Petitd44eef52019-03-08 13:22:14 +00001078bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1079 bool Changed = false;
1080
1081 for (auto const &SymVal : M.getValueSymbolTable()) {
1082 // Skip symbols whose name doesn't match
1083 if (!SymVal.getKey().startswith("_Z6rotate")) {
1084 continue;
1085 }
1086 // Is there a function going by that name?
1087 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1088
1089 SmallVector<Instruction *, 4> ToRemoves;
1090
1091 // Walk the users of the function.
1092 for (auto &U : F->uses()) {
1093 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1094
1095 // Get arguments
1096 auto SrcValue = CI->getOperand(0);
1097 auto RotAmount = CI->getOperand(1);
1098
1099 // Don't touch overloads that aren't in OpenCL C
1100 auto SrcType = SrcValue->getType();
1101 auto RotType = RotAmount->getType();
1102
1103 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1104 continue;
1105 }
1106
1107 if (!SrcType->isIntOrIntVectorTy()) {
1108 continue;
1109 }
1110
1111 if ((SrcType->getScalarSizeInBits() != 8) &&
1112 (SrcType->getScalarSizeInBits() != 16) &&
1113 (SrcType->getScalarSizeInBits() != 32) &&
1114 (SrcType->getScalarSizeInBits() != 64)) {
1115 continue;
1116 }
1117
1118 if (SrcType->isVectorTy()) {
1119 if ((SrcType->getVectorNumElements() != 2) &&
1120 (SrcType->getVectorNumElements() != 3) &&
1121 (SrcType->getVectorNumElements() != 4) &&
1122 (SrcType->getVectorNumElements() != 8) &&
1123 (SrcType->getVectorNumElements() != 16)) {
1124 continue;
1125 }
1126 }
1127
1128 // The approach used is to shift the top bits down, the bottom bits up
1129 // and OR the two shifted values.
1130
1131 // The rotation amount is to be treated modulo the element size.
1132 // Since SPIR-V shift ops don't support this, let's apply the
1133 // modulo ahead of shifting. The element size is always a power of
1134 // two so we can just AND with a mask.
1135 auto ModMask = ConstantInt::get(SrcType,
1136 SrcType->getScalarSizeInBits() - 1);
1137 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1138 ModMask, "", CI);
1139
1140 // Let's calc the amount by which to shift top bits down
1141 auto ScalarSize = ConstantInt::get(SrcType,
1142 SrcType->getScalarSizeInBits());
1143 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1144 RotAmount, "", CI);
1145
1146 // Now shift the bottom bits up and the top bits down
1147 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1148 RotAmount, "", CI);
1149 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1150 DownAmount, "", CI);
1151
1152 // Finally OR the two shifted values
1153 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1154 HiRotated, "", CI);
1155
1156 // Replace call with the expression
1157 CI->replaceAllUsesWith(V);
1158
1159 // Lastly, remember to remove the user.
1160 ToRemoves.push_back(CI);
1161 }
1162 }
1163
1164 Changed = !ToRemoves.empty();
1165
1166 // And cleanup the calls we don't use anymore.
1167 for (auto V : ToRemoves) {
1168 V->eraseFromParent();
1169 }
1170
1171 // And remove the function we don't need either too.
1172 F->eraseFromParent();
1173 }
1174 }
1175
1176 return Changed;
1177}
1178
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001179bool ReplaceOpenCLBuiltinPass::replaceConvert(Module &M) {
1180 bool Changed = false;
1181
1182 for (auto const &SymVal : M.getValueSymbolTable()) {
1183
1184 // Skip symbols whose name obviously doesn't match
1185 if (!SymVal.getKey().contains("convert_")) {
1186 continue;
1187 }
1188
1189 // Is there a function going by that name?
1190 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1191
1192 // Get info from the mangled name
1193 FunctionInfo finfo;
1194 bool parsed = getFunctionInfoFromMangledNameCheck(F->getName(), &finfo);
1195
1196 // All functions of interest are handled by our mangled name parser
1197 if (!parsed) {
1198 continue;
1199 }
1200
1201 // Move on if this isn't a call to convert_
1202 if (!finfo.name.startswith("convert_")) {
1203 continue;
1204 }
1205
1206 // Extract the destination type from the function name
1207 StringRef DstTypeName = finfo.name;
1208 DstTypeName.consume_front("convert_");
1209
1210 auto DstSignedNess = StringSwitch<ArgTypeInfo::SignedNess>(DstTypeName)
1211 .StartsWith("char", ArgTypeInfo::SignedNess::Signed)
1212 .StartsWith("short", ArgTypeInfo::SignedNess::Signed)
1213 .StartsWith("int", ArgTypeInfo::SignedNess::Signed)
1214 .StartsWith("long", ArgTypeInfo::SignedNess::Signed)
1215 .StartsWith("uchar", ArgTypeInfo::SignedNess::Unsigned)
1216 .StartsWith("ushort", ArgTypeInfo::SignedNess::Unsigned)
1217 .StartsWith("uint", ArgTypeInfo::SignedNess::Unsigned)
1218 .StartsWith("ulong", ArgTypeInfo::SignedNess::Unsigned)
1219 .Default(ArgTypeInfo::SignedNess::None);
1220
1221 auto SrcSignedNess = finfo.argTypeInfos[0].signedness;
1222
1223 bool DstIsSigned = DstSignedNess == ArgTypeInfo::SignedNess::Signed;
1224 bool SrcIsSigned = SrcSignedNess == ArgTypeInfo::SignedNess::Signed;
1225
1226 SmallVector<Instruction *, 4> ToRemoves;
1227
1228 // Walk the users of the function.
1229 for (auto &U : F->uses()) {
1230 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1231
1232 // Get arguments
1233 auto SrcValue = CI->getOperand(0);
1234
1235 // Don't touch overloads that aren't in OpenCL C
1236 auto SrcType = SrcValue->getType();
1237 auto DstType = CI->getType();
1238
1239 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1240 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1241 continue;
1242 }
1243
1244 if (SrcType->isVectorTy()) {
1245
1246 if (SrcType->getVectorNumElements() !=
1247 DstType->getVectorNumElements()) {
1248 continue;
1249 }
1250
1251 if ((SrcType->getVectorNumElements() != 2) &&
1252 (SrcType->getVectorNumElements() != 3) &&
1253 (SrcType->getVectorNumElements() != 4) &&
1254 (SrcType->getVectorNumElements() != 8) &&
1255 (SrcType->getVectorNumElements() != 16)) {
1256 continue;
1257 }
1258 }
1259
1260 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1261 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1262
1263 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1264 bool DstIsInt = DstType->isIntOrIntVectorTy();
1265
1266 Value *V;
1267 if (SrcIsFloat && DstIsFloat) {
1268 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1269 } else if (SrcIsFloat && DstIsInt) {
1270 if (DstIsSigned) {
1271 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "", CI);
1272 } else {
1273 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "", CI);
1274 }
1275 } else if (SrcIsInt && DstIsFloat) {
1276 if (SrcIsSigned) {
1277 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "", CI);
1278 } else {
1279 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "", CI);
1280 }
1281 } else if (SrcIsInt && DstIsInt) {
1282 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "", CI);
1283 } else {
1284 // Not something we're supposed to handle, just move on
1285 continue;
1286 }
1287
1288 // Replace call with the expression
1289 CI->replaceAllUsesWith(V);
1290
1291 // Lastly, remember to remove the user.
1292 ToRemoves.push_back(CI);
1293 }
1294 }
1295
1296 Changed = !ToRemoves.empty();
1297
1298 // And cleanup the calls we don't use anymore.
1299 for (auto V : ToRemoves) {
1300 V->eraseFromParent();
1301 }
1302
1303 // And remove the function we don't need either too.
1304 F->eraseFromParent();
1305 }
1306 }
1307
1308 return Changed;
1309}
1310
Kévin Petit8a560882019-03-21 15:24:34 +00001311bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1312 bool Changed = false;
1313
1314 for (auto const &SymVal : M.getValueSymbolTable()) {
1315
1316 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1317 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1318
1319 // Skip symbols whose name doesn't match
1320 if (!isMad && !isMul) {
1321 continue;
1322 }
1323
1324 // Is there a function going by that name?
1325 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1326
1327 SmallVector<Instruction *, 4> ToRemoves;
1328
1329 // Walk the users of the function.
1330 for (auto &U : F->uses()) {
1331 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1332
1333 // Get arguments
1334 auto AValue = CI->getOperand(0);
1335 auto BValue = CI->getOperand(1);
1336 auto CValue = CI->getOperand(2);
1337
1338 // Don't touch overloads that aren't in OpenCL C
1339 auto AType = AValue->getType();
1340 auto BType = BValue->getType();
1341 auto CType = CValue->getType();
1342
1343 if ((AType != BType) || (CI->getType() != AType) ||
1344 (isMad && (AType != CType))) {
1345 continue;
1346 }
1347
1348 if (!AType->isIntOrIntVectorTy()) {
1349 continue;
1350 }
1351
1352 if ((AType->getScalarSizeInBits() != 8) &&
1353 (AType->getScalarSizeInBits() != 16) &&
1354 (AType->getScalarSizeInBits() != 32) &&
1355 (AType->getScalarSizeInBits() != 64)) {
1356 continue;
1357 }
1358
1359 if (AType->isVectorTy()) {
1360 if ((AType->getVectorNumElements() != 2) &&
1361 (AType->getVectorNumElements() != 3) &&
1362 (AType->getVectorNumElements() != 4) &&
1363 (AType->getVectorNumElements() != 8) &&
1364 (AType->getVectorNumElements() != 16)) {
1365 continue;
1366 }
1367 }
1368
1369 // Create struct type for the return type of our SPIR-V intrinsic
1370 SmallVector<Type*, 2> TwoValueType = {
1371 AType,
1372 AType
1373 };
1374
1375 auto ExMulRetType = StructType::create(TwoValueType);
1376
1377 // And a function type
1378 auto NewFType = FunctionType::get(ExMulRetType, TwoValueType, false);
1379
1380 // Get infos from the mangled OpenCL built-in function name
1381 FunctionInfo finfo;
1382 getFunctionInfoFromMangledName(F->getName(), &finfo);
1383
1384 // Use it to select the appropriate signed/unsigned SPIR-V intrinsic
1385 StringRef intrinsic;
1386 if (finfo.argTypeInfos[0].signedness == ArgTypeInfo::SignedNess::Signed) {
1387 intrinsic = "spirv.smul_extended";
1388 } else {
1389 intrinsic = "spirv.umul_extended";
1390 }
1391
1392 // Add the intrinsic function to the module
1393 auto NewF = M.getOrInsertFunction(intrinsic, NewFType);
1394
1395 // Call it
1396 SmallVector<Value*, 4> NewFArgs = {
1397 AValue,
1398 BValue,
1399 };
1400
1401 auto Call = CallInst::Create(NewF, NewFArgs, "", CI);
1402
1403 // Get the high part of the result
1404 unsigned Idxs[] = {1};
1405 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1406
1407 // If we're handling a mad_hi, add the third argument to the result
1408 if (isMad) {
1409 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1410 }
1411
1412 // Replace call with the expression
1413 CI->replaceAllUsesWith(V);
1414
1415 // Lastly, remember to remove the user.
1416 ToRemoves.push_back(CI);
1417 }
1418 }
1419
1420 Changed = !ToRemoves.empty();
1421
1422 // And cleanup the calls we don't use anymore.
1423 for (auto V : ToRemoves) {
1424 V->eraseFromParent();
1425 }
1426
1427 // And remove the function we don't need either too.
1428 F->eraseFromParent();
1429 }
1430 }
1431
1432 return Changed;
1433}
1434
Kévin Petitf5b78a22018-10-25 14:32:17 +00001435bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1436 bool Changed = false;
1437
1438 for (auto const &SymVal : M.getValueSymbolTable()) {
1439 // Skip symbols whose name doesn't match
1440 if (!SymVal.getKey().startswith("_Z6select")) {
1441 continue;
1442 }
1443 // Is there a function going by that name?
1444 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1445
1446 SmallVector<Instruction *, 4> ToRemoves;
1447
1448 // Walk the users of the function.
1449 for (auto &U : F->uses()) {
1450 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1451
1452 // Get arguments
1453 auto FalseValue = CI->getOperand(0);
1454 auto TrueValue = CI->getOperand(1);
1455 auto PredicateValue = CI->getOperand(2);
1456
1457 // Don't touch overloads that aren't in OpenCL C
1458 auto FalseType = FalseValue->getType();
1459 auto TrueType = TrueValue->getType();
1460 auto PredicateType = PredicateValue->getType();
1461
1462 if (FalseType != TrueType) {
1463 continue;
1464 }
1465
1466 if (!PredicateType->isIntOrIntVectorTy()) {
1467 continue;
1468 }
1469
1470 if (!FalseType->isIntOrIntVectorTy() &&
1471 !FalseType->getScalarType()->isFloatingPointTy()) {
1472 continue;
1473 }
1474
1475 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1476 continue;
1477 }
1478
1479 if (FalseType->getScalarSizeInBits() !=
1480 PredicateType->getScalarSizeInBits()) {
1481 continue;
1482 }
1483
1484 if (FalseType->isVectorTy()) {
1485 if (FalseType->getVectorNumElements() !=
1486 PredicateType->getVectorNumElements()) {
1487 continue;
1488 }
1489
1490 if ((FalseType->getVectorNumElements() != 2) &&
1491 (FalseType->getVectorNumElements() != 3) &&
1492 (FalseType->getVectorNumElements() != 4) &&
1493 (FalseType->getVectorNumElements() != 8) &&
1494 (FalseType->getVectorNumElements() != 16)) {
1495 continue;
1496 }
1497 }
1498
1499 // Create constant
1500 const auto ZeroValue = Constant::getNullValue(PredicateType);
1501
1502 // Scalar and vector are to be treated differently
1503 CmpInst::Predicate Pred;
1504 if (PredicateType->isVectorTy()) {
1505 Pred = CmpInst::ICMP_SLT;
1506 } else {
1507 Pred = CmpInst::ICMP_NE;
1508 }
1509
1510 // Create comparison instruction
1511 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1512 ZeroValue, "", CI);
1513
1514 // Create select
1515 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1516
1517 // Replace call with the selection
1518 CI->replaceAllUsesWith(V);
1519
1520 // Lastly, remember to remove the user.
1521 ToRemoves.push_back(CI);
1522 }
1523 }
1524
1525 Changed = !ToRemoves.empty();
1526
1527 // And cleanup the calls we don't use anymore.
1528 for (auto V : ToRemoves) {
1529 V->eraseFromParent();
1530 }
1531
1532 // And remove the function we don't need either too.
1533 F->eraseFromParent();
1534 }
1535 }
1536
1537 return Changed;
1538}
1539
Kévin Petite7d0cce2018-10-31 12:38:56 +00001540bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1541 bool Changed = false;
1542
1543 for (auto const &SymVal : M.getValueSymbolTable()) {
1544 // Skip symbols whose name doesn't match
1545 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1546 continue;
1547 }
1548 // Is there a function going by that name?
1549 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1550
1551 SmallVector<Instruction *, 4> ToRemoves;
1552
1553 // Walk the users of the function.
1554 for (auto &U : F->uses()) {
1555 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1556
1557 if (CI->getNumOperands() != 4) {
1558 continue;
1559 }
1560
1561 // Get arguments
1562 auto FalseValue = CI->getOperand(0);
1563 auto TrueValue = CI->getOperand(1);
1564 auto PredicateValue = CI->getOperand(2);
1565
1566 // Don't touch overloads that aren't in OpenCL C
1567 auto FalseType = FalseValue->getType();
1568 auto TrueType = TrueValue->getType();
1569 auto PredicateType = PredicateValue->getType();
1570
1571 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1572 continue;
1573 }
1574
1575 if (TrueType->isVectorTy()) {
1576 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1577 !TrueType->getScalarType()->isIntegerTy()) {
1578 continue;
1579 }
1580 if ((TrueType->getVectorNumElements() != 2) &&
1581 (TrueType->getVectorNumElements() != 3) &&
1582 (TrueType->getVectorNumElements() != 4) &&
1583 (TrueType->getVectorNumElements() != 8) &&
1584 (TrueType->getVectorNumElements() != 16)) {
1585 continue;
1586 }
1587 }
1588
1589 // Remember the type of the operands
1590 auto OpType = TrueType;
1591
1592 // The actual bit selection will always be done on an integer type,
1593 // declare it here
1594 Type *BitType;
1595
1596 // If the operands are float, then bitcast them to int
1597 if (OpType->getScalarType()->isFloatingPointTy()) {
1598
1599 // First create the new type
1600 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1601 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1602 if (OpType->isVectorTy()) {
1603 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1604 }
1605
1606 // Then bitcast all operands
1607 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1608 BitType, "", CI);
1609 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1610 BitType, "", CI);
1611 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1612
1613 } else {
1614 // The operands have an integer type, use it directly
1615 BitType = OpType;
1616 }
1617
1618 // All the operands are now always integers
1619 // implement as (c & b) | (~c & a)
1620
1621 // Create our negated predicate value
1622 auto AllOnes = Constant::getAllOnesValue(BitType);
1623 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1624 PredicateValue,
1625 AllOnes, "", CI);
1626
1627 // Then put everything together
1628 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1629 NotPredicateValue,
1630 FalseValue, "", CI);
1631 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1632 PredicateValue,
1633 TrueValue, "", CI);
1634
1635 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1636 BitsTrue, "", CI);
1637
1638 // If we were dealing with a floating point type, we must bitcast
1639 // the result back to that
1640 if (OpType->getScalarType()->isFloatingPointTy()) {
1641 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1642 }
1643
1644 // Replace call with our new code
1645 CI->replaceAllUsesWith(V);
1646
1647 // Lastly, remember to remove the user.
1648 ToRemoves.push_back(CI);
1649 }
1650 }
1651
1652 Changed = !ToRemoves.empty();
1653
1654 // And cleanup the calls we don't use anymore.
1655 for (auto V : ToRemoves) {
1656 V->eraseFromParent();
1657 }
1658
1659 // And remove the function we don't need either too.
1660 F->eraseFromParent();
1661 }
1662 }
1663
1664 return Changed;
1665}
1666
Kévin Petit6b0a9532018-10-30 20:00:39 +00001667bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1668 bool Changed = false;
1669
1670 const std::map<const char *, const char *> Map = {
1671 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1672 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1673 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1674 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1675 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1676 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1677 };
1678
1679 for (auto Pair : Map) {
1680 // If we find a function with the matching name.
1681 if (auto F = M.getFunction(Pair.first)) {
1682 SmallVector<Instruction *, 4> ToRemoves;
1683
1684 // Walk the users of the function.
1685 for (auto &U : F->uses()) {
1686 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1687
1688 auto ReplacementFn = Pair.second;
1689
1690 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1691 Value *VectorArg;
1692
1693 // First figure out which function we're dealing with
1694 if (F->getName().startswith("_Z10smoothstep")) {
1695 ArgsToSplat.push_back(CI->getOperand(1));
1696 VectorArg = CI->getOperand(2);
1697 } else {
1698 VectorArg = CI->getOperand(1);
1699 }
1700
1701 // Splat arguments that need to be
1702 SmallVector<Value*, 2> SplatArgs;
1703 auto VecType = VectorArg->getType();
1704
1705 for (auto arg : ArgsToSplat) {
1706 Value* NewVectorArg = UndefValue::get(VecType);
1707 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1708 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1709 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1710 }
1711 SplatArgs.push_back(NewVectorArg);
1712 }
1713
1714 // Replace the call with the vector/vector flavour
1715 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1716 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1717
1718 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1719
1720 SmallVector<Value*, 3> NewArgs;
1721 for (auto arg : SplatArgs) {
1722 NewArgs.push_back(arg);
1723 }
1724 NewArgs.push_back(VectorArg);
1725
1726 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1727
1728 CI->replaceAllUsesWith(NewCI);
1729
1730 // Lastly, remember to remove the user.
1731 ToRemoves.push_back(CI);
1732 }
1733 }
1734
1735 Changed = !ToRemoves.empty();
1736
1737 // And cleanup the calls we don't use anymore.
1738 for (auto V : ToRemoves) {
1739 V->eraseFromParent();
1740 }
1741
1742 // And remove the function we don't need either too.
1743 F->eraseFromParent();
1744 }
1745 }
1746
1747 return Changed;
1748}
1749
David Neto22f144c2017-06-12 14:26:21 -04001750bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1751 bool Changed = false;
1752
1753 const std::map<const char *, Instruction::BinaryOps> Map = {
1754 {"_Z7signbitf", Instruction::LShr},
1755 {"_Z7signbitDv2_f", Instruction::AShr},
1756 {"_Z7signbitDv3_f", Instruction::AShr},
1757 {"_Z7signbitDv4_f", Instruction::AShr},
1758 };
1759
1760 for (auto Pair : Map) {
1761 // If we find a function with the matching name.
1762 if (auto F = M.getFunction(Pair.first)) {
1763 SmallVector<Instruction *, 4> ToRemoves;
1764
1765 // Walk the users of the function.
1766 for (auto &U : F->uses()) {
1767 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1768 auto Arg = CI->getOperand(0);
1769
1770 auto Bitcast =
1771 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1772
1773 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1774 ConstantInt::get(CI->getType(), 31),
1775 "", CI);
1776
1777 CI->replaceAllUsesWith(Shr);
1778
1779 // Lastly, remember to remove the user.
1780 ToRemoves.push_back(CI);
1781 }
1782 }
1783
1784 Changed = !ToRemoves.empty();
1785
1786 // And cleanup the calls we don't use anymore.
1787 for (auto V : ToRemoves) {
1788 V->eraseFromParent();
1789 }
1790
1791 // And remove the function we don't need either too.
1792 F->eraseFromParent();
1793 }
1794 }
1795
1796 return Changed;
1797}
1798
1799bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1800 bool Changed = false;
1801
1802 const std::map<const char *,
1803 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1804 Map = {
1805 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1806 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1807 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1808 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1809 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1810 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1811 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1812 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1813 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1814 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1815 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1816 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1817 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1818 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1819 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1820 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1821 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1822 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1823 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1824 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1825 };
1826
1827 for (auto Pair : Map) {
1828 // If we find a function with the matching name.
1829 if (auto F = M.getFunction(Pair.first)) {
1830 SmallVector<Instruction *, 4> ToRemoves;
1831
1832 // Walk the users of the function.
1833 for (auto &U : F->uses()) {
1834 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1835 // The multiply instruction to use.
1836 auto MulInst = Pair.second.first;
1837
1838 // The add instruction to use.
1839 auto AddInst = Pair.second.second;
1840
1841 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1842
1843 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1844 CI->getArgOperand(1), "", CI);
1845
1846 if (Instruction::BinaryOpsEnd != AddInst) {
1847 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1848 CI);
1849 }
1850
1851 CI->replaceAllUsesWith(I);
1852
1853 // Lastly, remember to remove the user.
1854 ToRemoves.push_back(CI);
1855 }
1856 }
1857
1858 Changed = !ToRemoves.empty();
1859
1860 // And cleanup the calls we don't use anymore.
1861 for (auto V : ToRemoves) {
1862 V->eraseFromParent();
1863 }
1864
1865 // And remove the function we don't need either too.
1866 F->eraseFromParent();
1867 }
1868 }
1869
1870 return Changed;
1871}
1872
Derek Chowcfd368b2017-10-19 20:58:45 -07001873bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1874 bool Changed = false;
1875
1876 struct VectorStoreOps {
1877 const char* name;
1878 int n;
1879 Type* (*get_scalar_type_function)(LLVMContext&);
1880 } vector_store_ops[] = {
1881 // TODO(derekjchow): Expand this list.
1882 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1883 };
1884
David Neto544fffc2017-11-16 18:35:14 -05001885 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001886 auto Name = Op.name;
1887 auto N = Op.n;
1888 auto TypeFn = Op.get_scalar_type_function;
1889 if (auto F = M.getFunction(Name)) {
1890 SmallVector<Instruction *, 4> ToRemoves;
1891
1892 // Walk the users of the function.
1893 for (auto &U : F->uses()) {
1894 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1895 // The value argument from vstoren.
1896 auto Arg0 = CI->getOperand(0);
1897
1898 // The index argument from vstoren.
1899 auto Arg1 = CI->getOperand(1);
1900
1901 // The pointer argument from vstoren.
1902 auto Arg2 = CI->getOperand(2);
1903
1904 // Get types.
1905 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1906 auto ScalarNPointerTy = PointerType::get(
1907 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1908
1909 // Cast to scalarn
1910 auto Cast = CastInst::CreatePointerCast(
1911 Arg2, ScalarNPointerTy, "", CI);
1912 // Index to correct address
1913 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1914 // Store
1915 auto Store = new StoreInst(Arg0, Index, CI);
1916
1917 CI->replaceAllUsesWith(Store);
1918 ToRemoves.push_back(CI);
1919 }
1920 }
1921
1922 Changed = !ToRemoves.empty();
1923
1924 // And cleanup the calls we don't use anymore.
1925 for (auto V : ToRemoves) {
1926 V->eraseFromParent();
1927 }
1928
1929 // And remove the function we don't need either too.
1930 F->eraseFromParent();
1931 }
1932 }
1933
1934 return Changed;
1935}
1936
1937bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1938 bool Changed = false;
1939
1940 struct VectorLoadOps {
1941 const char* name;
1942 int n;
1943 Type* (*get_scalar_type_function)(LLVMContext&);
1944 } vector_load_ops[] = {
1945 // TODO(derekjchow): Expand this list.
1946 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1947 };
1948
David Neto544fffc2017-11-16 18:35:14 -05001949 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001950 auto Name = Op.name;
1951 auto N = Op.n;
1952 auto TypeFn = Op.get_scalar_type_function;
1953 // If we find a function with the matching name.
1954 if (auto F = M.getFunction(Name)) {
1955 SmallVector<Instruction *, 4> ToRemoves;
1956
1957 // Walk the users of the function.
1958 for (auto &U : F->uses()) {
1959 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1960 // The index argument from vloadn.
1961 auto Arg0 = CI->getOperand(0);
1962
1963 // The pointer argument from vloadn.
1964 auto Arg1 = CI->getOperand(1);
1965
1966 // Get types.
1967 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1968 auto ScalarNPointerTy = PointerType::get(
1969 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1970
1971 // Cast to scalarn
1972 auto Cast = CastInst::CreatePointerCast(
1973 Arg1, ScalarNPointerTy, "", CI);
1974 // Index to correct address
1975 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1976 // Load
1977 auto Load = new LoadInst(Index, "", CI);
1978
1979 CI->replaceAllUsesWith(Load);
1980 ToRemoves.push_back(CI);
1981 }
1982 }
1983
1984 Changed = !ToRemoves.empty();
1985
1986 // And cleanup the calls we don't use anymore.
1987 for (auto V : ToRemoves) {
1988 V->eraseFromParent();
1989 }
1990
1991 // And remove the function we don't need either too.
1992 F->eraseFromParent();
1993
1994 }
1995 }
1996
1997 return Changed;
1998}
1999
David Neto22f144c2017-06-12 14:26:21 -04002000bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
2001 bool Changed = false;
2002
2003 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
2004 "_Z10vload_halfjPU3AS2KDh"};
2005
2006 for (auto Name : Map) {
2007 // If we find a function with the matching name.
2008 if (auto F = M.getFunction(Name)) {
2009 SmallVector<Instruction *, 4> ToRemoves;
2010
2011 // Walk the users of the function.
2012 for (auto &U : F->uses()) {
2013 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2014 // The index argument from vload_half.
2015 auto Arg0 = CI->getOperand(0);
2016
2017 // The pointer argument from vload_half.
2018 auto Arg1 = CI->getOperand(1);
2019
David Neto22f144c2017-06-12 14:26:21 -04002020 auto IntTy = Type::getInt32Ty(M.getContext());
2021 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002022 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2023
David Neto22f144c2017-06-12 14:26:21 -04002024 // Our intrinsic to unpack a float2 from an int.
2025 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2026
2027 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2028
David Neto482550a2018-03-24 05:21:07 -07002029 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04002030 auto ShortTy = Type::getInt16Ty(M.getContext());
2031 auto ShortPointerTy = PointerType::get(
2032 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002033
David Netoac825b82017-05-30 12:49:01 -04002034 // Cast the half* pointer to short*.
2035 auto Cast =
2036 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002037
David Netoac825b82017-05-30 12:49:01 -04002038 // Index into the correct address of the casted pointer.
2039 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
2040
2041 // Load from the short* we casted to.
2042 auto Load = new LoadInst(Index, "", CI);
2043
2044 // ZExt the short -> int.
2045 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
2046
2047 // Get our float2.
2048 auto Call = CallInst::Create(NewF, ZExt, "", CI);
2049
2050 // Extract out the bottom element which is our float result.
2051 auto Extract = ExtractElementInst::Create(
2052 Call, ConstantInt::get(IntTy, 0), "", CI);
2053
2054 CI->replaceAllUsesWith(Extract);
2055 } else {
2056 // Assume the pointer argument points to storage aligned to 32bits
2057 // or more.
2058 // TODO(dneto): Do more analysis to make sure this is true?
2059 //
2060 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
2061 // with:
2062 //
2063 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
2064 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
2065 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
2066 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
2067 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
2068 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
2069 // x float> %converted, %index_is_odd32
2070
2071 auto IntPointerTy = PointerType::get(
2072 IntTy, Arg1->getType()->getPointerAddressSpace());
2073
David Neto973e6a82017-05-30 13:48:18 -04002074 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04002075 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04002076 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04002077 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
2078
2079 auto One = ConstantInt::get(IntTy, 1);
2080 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
2081 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
2082
2083 // Index into the correct address of the casted pointer.
2084 auto Ptr =
2085 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
2086
2087 // Load from the int* we casted to.
2088 auto Load = new LoadInst(Ptr, "", CI);
2089
2090 // Get our float2.
2091 auto Call = CallInst::Create(NewF, Load, "", CI);
2092
2093 // Extract out the float result, where the element number is
2094 // determined by whether the original index was even or odd.
2095 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
2096
2097 CI->replaceAllUsesWith(Extract);
2098 }
David Neto22f144c2017-06-12 14:26:21 -04002099
2100 // Lastly, remember to remove the user.
2101 ToRemoves.push_back(CI);
2102 }
2103 }
2104
2105 Changed = !ToRemoves.empty();
2106
2107 // And cleanup the calls we don't use anymore.
2108 for (auto V : ToRemoves) {
2109 V->eraseFromParent();
2110 }
2111
2112 // And remove the function we don't need either too.
2113 F->eraseFromParent();
2114 }
2115 }
2116
2117 return Changed;
2118}
2119
2120bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
2121 bool Changed = false;
2122
David Neto556c7e62018-06-08 13:45:55 -07002123 const std::vector<const char *> Map = {
2124 "_Z11vload_half2jPU3AS1KDh",
2125 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
2126 "_Z11vload_half2jPU3AS2KDh",
2127 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
2128 };
David Neto22f144c2017-06-12 14:26:21 -04002129
2130 for (auto Name : Map) {
2131 // If we find a function with the matching name.
2132 if (auto F = M.getFunction(Name)) {
2133 SmallVector<Instruction *, 4> ToRemoves;
2134
2135 // Walk the users of the function.
2136 for (auto &U : F->uses()) {
2137 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2138 // The index argument from vload_half.
2139 auto Arg0 = CI->getOperand(0);
2140
2141 // The pointer argument from vload_half.
2142 auto Arg1 = CI->getOperand(1);
2143
2144 auto IntTy = Type::getInt32Ty(M.getContext());
2145 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2146 auto NewPointerTy = PointerType::get(
2147 IntTy, Arg1->getType()->getPointerAddressSpace());
2148 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2149
2150 // Cast the half* pointer to int*.
2151 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2152
2153 // Index into the correct address of the casted pointer.
2154 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
2155
2156 // Load from the int* we casted to.
2157 auto Load = new LoadInst(Index, "", CI);
2158
2159 // Our intrinsic to unpack a float2 from an int.
2160 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2161
2162 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2163
2164 // Get our float2.
2165 auto Call = CallInst::Create(NewF, Load, "", CI);
2166
2167 CI->replaceAllUsesWith(Call);
2168
2169 // Lastly, remember to remove the user.
2170 ToRemoves.push_back(CI);
2171 }
2172 }
2173
2174 Changed = !ToRemoves.empty();
2175
2176 // And cleanup the calls we don't use anymore.
2177 for (auto V : ToRemoves) {
2178 V->eraseFromParent();
2179 }
2180
2181 // And remove the function we don't need either too.
2182 F->eraseFromParent();
2183 }
2184 }
2185
2186 return Changed;
2187}
2188
2189bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
2190 bool Changed = false;
2191
David Neto556c7e62018-06-08 13:45:55 -07002192 const std::vector<const char *> Map = {
2193 "_Z11vload_half4jPU3AS1KDh",
2194 "_Z12vloada_half4jPU3AS1KDh",
2195 "_Z11vload_half4jPU3AS2KDh",
2196 "_Z12vloada_half4jPU3AS2KDh",
2197 };
David Neto22f144c2017-06-12 14:26:21 -04002198
2199 for (auto Name : Map) {
2200 // If we find a function with the matching name.
2201 if (auto F = M.getFunction(Name)) {
2202 SmallVector<Instruction *, 4> ToRemoves;
2203
2204 // Walk the users of the function.
2205 for (auto &U : F->uses()) {
2206 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2207 // The index argument from vload_half.
2208 auto Arg0 = CI->getOperand(0);
2209
2210 // The pointer argument from vload_half.
2211 auto Arg1 = CI->getOperand(1);
2212
2213 auto IntTy = Type::getInt32Ty(M.getContext());
2214 auto Int2Ty = VectorType::get(IntTy, 2);
2215 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2216 auto NewPointerTy = PointerType::get(
2217 Int2Ty, Arg1->getType()->getPointerAddressSpace());
2218 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2219
2220 // Cast the half* pointer to int2*.
2221 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2222
2223 // Index into the correct address of the casted pointer.
2224 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
2225
2226 // Load from the int2* we casted to.
2227 auto Load = new LoadInst(Index, "", CI);
2228
2229 // Extract each element from the loaded int2.
2230 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2231 "", CI);
2232 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2233 "", CI);
2234
2235 // Our intrinsic to unpack a float2 from an int.
2236 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2237
2238 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2239
2240 // Get the lower (x & y) components of our final float4.
2241 auto Lo = CallInst::Create(NewF, X, "", CI);
2242
2243 // Get the higher (z & w) components of our final float4.
2244 auto Hi = CallInst::Create(NewF, Y, "", CI);
2245
2246 Constant *ShuffleMask[4] = {
2247 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2248 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2249
2250 // Combine our two float2's into one float4.
2251 auto Combine = new ShuffleVectorInst(
2252 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2253
2254 CI->replaceAllUsesWith(Combine);
2255
2256 // Lastly, remember to remove the user.
2257 ToRemoves.push_back(CI);
2258 }
2259 }
2260
2261 Changed = !ToRemoves.empty();
2262
2263 // And cleanup the calls we don't use anymore.
2264 for (auto V : ToRemoves) {
2265 V->eraseFromParent();
2266 }
2267
2268 // And remove the function we don't need either too.
2269 F->eraseFromParent();
2270 }
2271 }
2272
2273 return Changed;
2274}
2275
David Neto6ad93232018-06-07 15:42:58 -07002276bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
2277 bool Changed = false;
2278
2279 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2280 //
2281 // %u = load i32 %ptr
2282 // %fxy = call <2 x float> Unpack2xHalf(u)
2283 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2284 const std::vector<const char *> Map = {
2285 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2286 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2287 "_Z20__clspv_vloada_half2jPKj", // private
2288 };
2289
2290 for (auto Name : Map) {
2291 // If we find a function with the matching name.
2292 if (auto F = M.getFunction(Name)) {
2293 SmallVector<Instruction *, 4> ToRemoves;
2294
2295 // Walk the users of the function.
2296 for (auto &U : F->uses()) {
2297 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
2298 auto Index = CI->getOperand(0);
2299 auto Ptr = CI->getOperand(1);
2300
2301 auto IntTy = Type::getInt32Ty(M.getContext());
2302 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2303 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2304
2305 auto IndexedPtr =
2306 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
2307 auto Load = new LoadInst(IndexedPtr, "", CI);
2308
2309 // Our intrinsic to unpack a float2 from an int.
2310 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2311
2312 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2313
2314 // Get our final float2.
2315 auto Result = CallInst::Create(NewF, Load, "", CI);
2316
2317 CI->replaceAllUsesWith(Result);
2318
2319 // Lastly, remember to remove the user.
2320 ToRemoves.push_back(CI);
2321 }
2322 }
2323
2324 Changed = true;
2325
2326 // And cleanup the calls we don't use anymore.
2327 for (auto V : ToRemoves) {
2328 V->eraseFromParent();
2329 }
2330
2331 // And remove the function we don't need either too.
2332 F->eraseFromParent();
2333 }
2334 }
2335
2336 return Changed;
2337}
2338
2339bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
2340 bool Changed = false;
2341
2342 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2343 //
2344 // %u2 = load <2 x i32> %ptr
2345 // %u2xy = extractelement %u2, 0
2346 // %u2zw = extractelement %u2, 1
2347 // %fxy = call <2 x float> Unpack2xHalf(uint)
2348 // %fzw = call <2 x float> Unpack2xHalf(uint)
2349 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2350 const std::vector<const char *> Map = {
2351 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2352 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2353 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2354 };
2355
2356 for (auto Name : Map) {
2357 // If we find a function with the matching name.
2358 if (auto F = M.getFunction(Name)) {
2359 SmallVector<Instruction *, 4> ToRemoves;
2360
2361 // Walk the users of the function.
2362 for (auto &U : F->uses()) {
2363 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2364 auto Index = CI->getOperand(0);
2365 auto Ptr = CI->getOperand(1);
2366
2367 auto IntTy = Type::getInt32Ty(M.getContext());
2368 auto Int2Ty = VectorType::get(IntTy, 2);
2369 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2370 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2371
2372 auto IndexedPtr =
2373 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
2374 auto Load = new LoadInst(IndexedPtr, "", CI);
2375
2376 // Extract each element from the loaded int2.
2377 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2378 "", CI);
2379 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2380 "", CI);
2381
2382 // Our intrinsic to unpack a float2 from an int.
2383 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2384
2385 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2386
2387 // Get the lower (x & y) components of our final float4.
2388 auto Lo = CallInst::Create(NewF, X, "", CI);
2389
2390 // Get the higher (z & w) components of our final float4.
2391 auto Hi = CallInst::Create(NewF, Y, "", CI);
2392
2393 Constant *ShuffleMask[4] = {
2394 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2395 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2396
2397 // Combine our two float2's into one float4.
2398 auto Combine = new ShuffleVectorInst(
2399 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2400
2401 CI->replaceAllUsesWith(Combine);
2402
2403 // Lastly, remember to remove the user.
2404 ToRemoves.push_back(CI);
2405 }
2406 }
2407
2408 Changed = true;
2409
2410 // And cleanup the calls we don't use anymore.
2411 for (auto V : ToRemoves) {
2412 V->eraseFromParent();
2413 }
2414
2415 // And remove the function we don't need either too.
2416 F->eraseFromParent();
2417 }
2418 }
2419
2420 return Changed;
2421}
2422
David Neto22f144c2017-06-12 14:26:21 -04002423bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
2424 bool Changed = false;
2425
2426 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
2427 "_Z15vstore_half_rtefjPU3AS1Dh",
2428 "_Z15vstore_half_rtzfjPU3AS1Dh"};
2429
2430 for (auto Name : Map) {
2431 // If we find a function with the matching name.
2432 if (auto F = M.getFunction(Name)) {
2433 SmallVector<Instruction *, 4> ToRemoves;
2434
2435 // Walk the users of the function.
2436 for (auto &U : F->uses()) {
2437 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2438 // The value to store.
2439 auto Arg0 = CI->getOperand(0);
2440
2441 // The index argument from vstore_half.
2442 auto Arg1 = CI->getOperand(1);
2443
2444 // The pointer argument from vstore_half.
2445 auto Arg2 = CI->getOperand(2);
2446
David Neto22f144c2017-06-12 14:26:21 -04002447 auto IntTy = Type::getInt32Ty(M.getContext());
2448 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002449 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04002450 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002451
2452 // Our intrinsic to pack a float2 to an int.
2453 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2454
2455 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2456
2457 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04002458 auto TempVec =
2459 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
2460 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002461
2462 // Pack the float2 -> half2 (in an int).
2463 auto X = CallInst::Create(NewF, TempVec, "", CI);
2464
David Neto482550a2018-03-24 05:21:07 -07002465 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002466 auto ShortTy = Type::getInt16Ty(M.getContext());
2467 auto ShortPointerTy = PointerType::get(
2468 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002469
David Neto17852de2017-05-29 17:29:31 -04002470 // Truncate our i32 to an i16.
2471 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002472
David Neto17852de2017-05-29 17:29:31 -04002473 // Cast the half* pointer to short*.
2474 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002475
David Neto17852de2017-05-29 17:29:31 -04002476 // Index into the correct address of the casted pointer.
2477 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002478
David Neto17852de2017-05-29 17:29:31 -04002479 // Store to the int* we casted to.
2480 auto Store = new StoreInst(Trunc, Index, CI);
2481
2482 CI->replaceAllUsesWith(Store);
2483 } else {
2484 // We can only write to 32-bit aligned words.
2485 //
2486 // Assuming base is aligned to 32-bits, replace the equivalent of
2487 // vstore_half(value, index, base)
2488 // with:
2489 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2490 // uint32_t write_to_upper_half = index & 1u;
2491 // uint32_t shift = write_to_upper_half << 4;
2492 //
2493 // // Pack the float value as a half number in bottom 16 bits
2494 // // of an i32.
2495 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2496 //
2497 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2498 // ^ ((packed & 0xffff) << shift)
2499 // // We only need relaxed consistency, but OpenCL 1.2 only has
2500 // // sequentially consistent atomics.
2501 // // TODO(dneto): Use relaxed consistency.
2502 // atomic_xor(target_ptr, xor_value)
2503 auto IntPointerTy = PointerType::get(
2504 IntTy, Arg2->getType()->getPointerAddressSpace());
2505
2506 auto Four = ConstantInt::get(IntTy, 4);
2507 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2508
2509 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2510 // Compute index / 2
2511 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2512 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2513 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2514 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2515 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2516 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2517 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2518
2519 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2520 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2521 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2522
2523 // Generate the call to atomi_xor.
2524 SmallVector<Type *, 5> ParamTypes;
2525 // The pointer type.
2526 ParamTypes.push_back(IntPointerTy);
2527 // The Types for memory scope, semantics, and value.
2528 ParamTypes.push_back(IntTy);
2529 ParamTypes.push_back(IntTy);
2530 ParamTypes.push_back(IntTy);
2531 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2532 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2533
2534 const auto ConstantScopeDevice =
2535 ConstantInt::get(IntTy, spv::ScopeDevice);
2536 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2537 // (SPIR-V Workgroup).
2538 const auto AddrSpaceSemanticsBits =
2539 IntPointerTy->getPointerAddressSpace() == 1
2540 ? spv::MemorySemanticsUniformMemoryMask
2541 : spv::MemorySemanticsWorkgroupMemoryMask;
2542
2543 // We're using relaxed consistency here.
2544 const auto ConstantMemorySemantics =
2545 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2546 AddrSpaceSemanticsBits);
2547
2548 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2549 ConstantMemorySemantics, ValueToXor};
2550 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2551 }
David Neto22f144c2017-06-12 14:26:21 -04002552
2553 // Lastly, remember to remove the user.
2554 ToRemoves.push_back(CI);
2555 }
2556 }
2557
2558 Changed = !ToRemoves.empty();
2559
2560 // And cleanup the calls we don't use anymore.
2561 for (auto V : ToRemoves) {
2562 V->eraseFromParent();
2563 }
2564
2565 // And remove the function we don't need either too.
2566 F->eraseFromParent();
2567 }
2568 }
2569
2570 return Changed;
2571}
2572
2573bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2574 bool Changed = false;
2575
David Netoe2871522018-06-08 11:09:54 -07002576 const std::vector<const char *> Map = {
2577 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2578 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2579 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2580 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2581 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2582 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2583 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2584 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2585 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2586 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2587 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2588 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2589 };
David Neto22f144c2017-06-12 14:26:21 -04002590
2591 for (auto Name : Map) {
2592 // If we find a function with the matching name.
2593 if (auto F = M.getFunction(Name)) {
2594 SmallVector<Instruction *, 4> ToRemoves;
2595
2596 // Walk the users of the function.
2597 for (auto &U : F->uses()) {
2598 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2599 // The value to store.
2600 auto Arg0 = CI->getOperand(0);
2601
2602 // The index argument from vstore_half.
2603 auto Arg1 = CI->getOperand(1);
2604
2605 // The pointer argument from vstore_half.
2606 auto Arg2 = CI->getOperand(2);
2607
2608 auto IntTy = Type::getInt32Ty(M.getContext());
2609 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2610 auto NewPointerTy = PointerType::get(
2611 IntTy, Arg2->getType()->getPointerAddressSpace());
2612 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2613
2614 // Our intrinsic to pack a float2 to an int.
2615 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2616
2617 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2618
2619 // Turn the packed x & y into the final packing.
2620 auto X = CallInst::Create(NewF, Arg0, "", CI);
2621
2622 // Cast the half* pointer to int*.
2623 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2624
2625 // Index into the correct address of the casted pointer.
2626 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2627
2628 // Store to the int* we casted to.
2629 auto Store = new StoreInst(X, Index, CI);
2630
2631 CI->replaceAllUsesWith(Store);
2632
2633 // Lastly, remember to remove the user.
2634 ToRemoves.push_back(CI);
2635 }
2636 }
2637
2638 Changed = !ToRemoves.empty();
2639
2640 // And cleanup the calls we don't use anymore.
2641 for (auto V : ToRemoves) {
2642 V->eraseFromParent();
2643 }
2644
2645 // And remove the function we don't need either too.
2646 F->eraseFromParent();
2647 }
2648 }
2649
2650 return Changed;
2651}
2652
2653bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2654 bool Changed = false;
2655
David Netoe2871522018-06-08 11:09:54 -07002656 const std::vector<const char *> Map = {
2657 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2658 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2659 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2660 "_Z13vstorea_half4Dv4_fjPDh", // private
2661 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2662 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2663 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2664 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2665 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2666 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2667 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2668 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2669 };
David Neto22f144c2017-06-12 14:26:21 -04002670
2671 for (auto Name : Map) {
2672 // If we find a function with the matching name.
2673 if (auto F = M.getFunction(Name)) {
2674 SmallVector<Instruction *, 4> ToRemoves;
2675
2676 // Walk the users of the function.
2677 for (auto &U : F->uses()) {
2678 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2679 // The value to store.
2680 auto Arg0 = CI->getOperand(0);
2681
2682 // The index argument from vstore_half.
2683 auto Arg1 = CI->getOperand(1);
2684
2685 // The pointer argument from vstore_half.
2686 auto Arg2 = CI->getOperand(2);
2687
2688 auto IntTy = Type::getInt32Ty(M.getContext());
2689 auto Int2Ty = VectorType::get(IntTy, 2);
2690 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2691 auto NewPointerTy = PointerType::get(
2692 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2693 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2694
2695 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2696 ConstantInt::get(IntTy, 1)};
2697
2698 // Extract out the x & y components of our to store value.
2699 auto Lo =
2700 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2701 ConstantVector::get(LoShuffleMask), "", CI);
2702
2703 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2704 ConstantInt::get(IntTy, 3)};
2705
2706 // Extract out the z & w components of our to store value.
2707 auto Hi =
2708 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2709 ConstantVector::get(HiShuffleMask), "", CI);
2710
2711 // Our intrinsic to pack a float2 to an int.
2712 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2713
2714 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2715
2716 // Turn the packed x & y into the final component of our int2.
2717 auto X = CallInst::Create(NewF, Lo, "", CI);
2718
2719 // Turn the packed z & w into the final component of our int2.
2720 auto Y = CallInst::Create(NewF, Hi, "", CI);
2721
2722 auto Combine = InsertElementInst::Create(
2723 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2724 Combine = InsertElementInst::Create(
2725 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2726
2727 // Cast the half* pointer to int2*.
2728 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2729
2730 // Index into the correct address of the casted pointer.
2731 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2732
2733 // Store to the int2* we casted to.
2734 auto Store = new StoreInst(Combine, Index, CI);
2735
2736 CI->replaceAllUsesWith(Store);
2737
2738 // Lastly, remember to remove the user.
2739 ToRemoves.push_back(CI);
2740 }
2741 }
2742
2743 Changed = !ToRemoves.empty();
2744
2745 // And cleanup the calls we don't use anymore.
2746 for (auto V : ToRemoves) {
2747 V->eraseFromParent();
2748 }
2749
2750 // And remove the function we don't need either too.
2751 F->eraseFromParent();
2752 }
2753 }
2754
2755 return Changed;
2756}
2757
2758bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2759 bool Changed = false;
2760
2761 const std::map<const char *, const char*> Map = {
2762 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2763 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2764 };
2765
2766 for (auto Pair : Map) {
2767 // If we find a function with the matching name.
2768 if (auto F = M.getFunction(Pair.first)) {
2769 SmallVector<Instruction *, 4> ToRemoves;
2770
2771 // Walk the users of the function.
2772 for (auto &U : F->uses()) {
2773 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2774 // The image.
2775 auto Arg0 = CI->getOperand(0);
2776
2777 // The sampler.
2778 auto Arg1 = CI->getOperand(1);
2779
2780 // The coordinate (integer type that we can't handle).
2781 auto Arg2 = CI->getOperand(2);
2782
2783 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2784
2785 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2786
2787 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2788
2789 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2790
2791 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2792
2793 CI->replaceAllUsesWith(NewCI);
2794
2795 // Lastly, remember to remove the user.
2796 ToRemoves.push_back(CI);
2797 }
2798 }
2799
2800 Changed = !ToRemoves.empty();
2801
2802 // And cleanup the calls we don't use anymore.
2803 for (auto V : ToRemoves) {
2804 V->eraseFromParent();
2805 }
2806
2807 // And remove the function we don't need either too.
2808 F->eraseFromParent();
2809 }
2810 }
2811
2812 return Changed;
2813}
2814
2815bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2816 bool Changed = false;
2817
2818 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002819 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002820 {"_Z8atom_incPU3AS3Vi", "spirv.atomic_inc"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002821 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002822 {"_Z8atom_incPU3AS3Vj", "spirv.atomic_inc"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002823 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002824 {"_Z8atom_decPU3AS3Vi", "spirv.atomic_dec"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002825 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002826 {"_Z8atom_decPU3AS3Vj", "spirv.atomic_dec"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002827 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002828 {"_Z12atom_cmpxchgPU3AS3Viii", "spirv.atomic_compare_exchange"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002829 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002830 {"_Z12atom_cmpxchgPU3AS3Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002831 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002832 {"_Z10atomic_incPU3AS3Vi", "spirv.atomic_inc"},
David Neto22f144c2017-06-12 14:26:21 -04002833 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002834 {"_Z10atomic_incPU3AS3Vj", "spirv.atomic_inc"},
David Neto22f144c2017-06-12 14:26:21 -04002835 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002836 {"_Z10atomic_decPU3AS3Vi", "spirv.atomic_dec"},
David Neto22f144c2017-06-12 14:26:21 -04002837 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002838 {"_Z10atomic_decPU3AS3Vj", "spirv.atomic_dec"},
David Neto22f144c2017-06-12 14:26:21 -04002839 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002840 {"_Z14atomic_cmpxchgPU3AS3Viii", "spirv.atomic_compare_exchange"},
2841 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
2842 {"_Z14atomic_cmpxchgPU3AS3Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002843
2844 for (auto Pair : Map) {
2845 // If we find a function with the matching name.
2846 if (auto F = M.getFunction(Pair.first)) {
2847 SmallVector<Instruction *, 4> ToRemoves;
2848
2849 // Walk the users of the function.
2850 for (auto &U : F->uses()) {
2851 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2852 auto FType = F->getFunctionType();
2853 SmallVector<Type *, 5> ParamTypes;
2854
2855 // The pointer type.
2856 ParamTypes.push_back(FType->getParamType(0));
2857
2858 auto IntTy = Type::getInt32Ty(M.getContext());
2859
2860 // The memory scope type.
2861 ParamTypes.push_back(IntTy);
2862
2863 // The memory semantics type.
2864 ParamTypes.push_back(IntTy);
2865
2866 if (2 < CI->getNumArgOperands()) {
2867 // The unequal memory semantics type.
2868 ParamTypes.push_back(IntTy);
2869
2870 // The value type.
2871 ParamTypes.push_back(FType->getParamType(2));
2872
2873 // The comparator type.
2874 ParamTypes.push_back(FType->getParamType(1));
2875 } else if (1 < CI->getNumArgOperands()) {
2876 // The value type.
2877 ParamTypes.push_back(FType->getParamType(1));
2878 }
2879
2880 auto NewFType =
2881 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2882 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2883
2884 // We need to map the OpenCL constants to the SPIR-V equivalents.
2885 const auto ConstantScopeDevice =
2886 ConstantInt::get(IntTy, spv::ScopeDevice);
2887 const auto ConstantMemorySemantics = ConstantInt::get(
2888 IntTy, spv::MemorySemanticsUniformMemoryMask |
2889 spv::MemorySemanticsSequentiallyConsistentMask);
2890
2891 SmallVector<Value *, 5> Params;
2892
2893 // The pointer.
2894 Params.push_back(CI->getArgOperand(0));
2895
2896 // The memory scope.
2897 Params.push_back(ConstantScopeDevice);
2898
2899 // The memory semantics.
2900 Params.push_back(ConstantMemorySemantics);
2901
2902 if (2 < CI->getNumArgOperands()) {
2903 // The unequal memory semantics.
2904 Params.push_back(ConstantMemorySemantics);
2905
2906 // The value.
2907 Params.push_back(CI->getArgOperand(2));
2908
2909 // The comparator.
2910 Params.push_back(CI->getArgOperand(1));
2911 } else if (1 < CI->getNumArgOperands()) {
2912 // The value.
2913 Params.push_back(CI->getArgOperand(1));
2914 }
2915
2916 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2917
2918 CI->replaceAllUsesWith(NewCI);
2919
2920 // Lastly, remember to remove the user.
2921 ToRemoves.push_back(CI);
2922 }
2923 }
2924
2925 Changed = !ToRemoves.empty();
2926
2927 // And cleanup the calls we don't use anymore.
2928 for (auto V : ToRemoves) {
2929 V->eraseFromParent();
2930 }
2931
2932 // And remove the function we don't need either too.
2933 F->eraseFromParent();
2934 }
2935 }
2936
Neil Henning39672102017-09-29 14:33:13 +01002937 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002938 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002939 {"_Z8atom_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002940 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002941 {"_Z8atom_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002942 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002943 {"_Z8atom_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002944 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002945 {"_Z8atom_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002946 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002947 {"_Z9atom_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002948 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002949 {"_Z9atom_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002950 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002951 {"_Z8atom_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002952 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002953 {"_Z8atom_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002954 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002955 {"_Z8atom_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002956 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002957 {"_Z8atom_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002958 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002959 {"_Z8atom_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002960 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002961 {"_Z8atom_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002962 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002963 {"_Z7atom_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002964 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002965 {"_Z7atom_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002966 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002967 {"_Z8atom_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002968 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002969 {"_Z8atom_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002970 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002971 {"_Z10atomic_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002972 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002973 {"_Z10atomic_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002974 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002975 {"_Z10atomic_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002976 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002977 {"_Z10atomic_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002978 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002979 {"_Z11atomic_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002980 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002981 {"_Z11atomic_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002982 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002983 {"_Z10atomic_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Neil Henning39672102017-09-29 14:33:13 +01002984 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002985 {"_Z10atomic_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Neil Henning39672102017-09-29 14:33:13 +01002986 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002987 {"_Z10atomic_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Neil Henning39672102017-09-29 14:33:13 +01002988 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002989 {"_Z10atomic_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Neil Henning39672102017-09-29 14:33:13 +01002990 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002991 {"_Z10atomic_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002992 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002993 {"_Z10atomic_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002994 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002995 {"_Z9atomic_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002996 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002997 {"_Z9atomic_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002998 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002999 {"_Z10atomic_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
3000 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
3001 {"_Z10atomic_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor}};
Neil Henning39672102017-09-29 14:33:13 +01003002
3003 for (auto Pair : Map2) {
3004 // If we find a function with the matching name.
3005 if (auto F = M.getFunction(Pair.first)) {
3006 SmallVector<Instruction *, 4> ToRemoves;
3007
3008 // Walk the users of the function.
3009 for (auto &U : F->uses()) {
3010 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3011 auto AtomicOp = new AtomicRMWInst(
3012 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
3013 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
3014
3015 CI->replaceAllUsesWith(AtomicOp);
3016
3017 // Lastly, remember to remove the user.
3018 ToRemoves.push_back(CI);
3019 }
3020 }
3021
3022 Changed = !ToRemoves.empty();
3023
3024 // And cleanup the calls we don't use anymore.
3025 for (auto V : ToRemoves) {
3026 V->eraseFromParent();
3027 }
3028
3029 // And remove the function we don't need either too.
3030 F->eraseFromParent();
3031 }
3032 }
3033
David Neto22f144c2017-06-12 14:26:21 -04003034 return Changed;
3035}
3036
3037bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
3038 bool Changed = false;
3039
3040 // If we find a function with the matching name.
3041 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
3042 SmallVector<Instruction *, 4> ToRemoves;
3043
3044 auto IntTy = Type::getInt32Ty(M.getContext());
3045 auto FloatTy = Type::getFloatTy(M.getContext());
3046
3047 Constant *DownShuffleMask[3] = {
3048 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
3049 ConstantInt::get(IntTy, 2)};
3050
3051 Constant *UpShuffleMask[4] = {
3052 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
3053 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
3054
3055 Constant *FloatVec[3] = {
3056 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
3057 };
3058
3059 // Walk the users of the function.
3060 for (auto &U : F->uses()) {
3061 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3062 auto Vec4Ty = CI->getArgOperand(0)->getType();
3063 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
3064 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
3065 auto Vec3Ty = Arg0->getType();
3066
3067 auto NewFType =
3068 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
3069
3070 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
3071
3072 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
3073
3074 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
3075
3076 CI->replaceAllUsesWith(Result);
3077
3078 // Lastly, remember to remove the user.
3079 ToRemoves.push_back(CI);
3080 }
3081 }
3082
3083 Changed = !ToRemoves.empty();
3084
3085 // And cleanup the calls we don't use anymore.
3086 for (auto V : ToRemoves) {
3087 V->eraseFromParent();
3088 }
3089
3090 // And remove the function we don't need either too.
3091 F->eraseFromParent();
3092 }
3093
3094 return Changed;
3095}
David Neto62653202017-10-16 19:05:18 -04003096
3097bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
3098 bool Changed = false;
3099
3100 // OpenCL's float result = fract(float x, float* ptr)
3101 //
3102 // In the LLVM domain:
3103 //
3104 // %floor_result = call spir_func float @floor(float %x)
3105 // store float %floor_result, float * %ptr
3106 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
3107 // %result = call spir_func float
3108 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
3109 //
3110 // Becomes in the SPIR-V domain, where translations of floor, fmin,
3111 // and clspv.fract occur in the SPIR-V generator pass:
3112 //
3113 // %glsl_ext = OpExtInstImport "GLSL.std.450"
3114 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
3115 // ...
3116 // %floor_result = OpExtInst %float %glsl_ext Floor %x
3117 // OpStore %ptr %floor_result
3118 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
3119 // %fract_result = OpExtInst %float
3120 // %glsl_ext Fmin %fract_intermediate %just_under_1
3121
3122
3123 using std::string;
3124
3125 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
3126 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
3127 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
3128 auto make_quad = [](const char *a, const char *b, const char *c,
3129 const char *d) {
3130 return std::tuple<const char *, const char *, const char *, const char *>(
3131 a, b, c, d);
3132 };
3133 const std::vector<QuadType> Functions = {
3134 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
3135 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
3136 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
3137 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
3138 };
3139
3140 for (auto& quad : Functions) {
3141 const StringRef fract_name(std::get<0>(quad));
3142
3143 // If we find a function with the matching name.
3144 if (auto F = M.getFunction(fract_name)) {
3145 if (F->use_begin() == F->use_end())
3146 continue;
3147
3148 // We have some uses.
3149 Changed = true;
3150
3151 auto& Context = M.getContext();
3152
3153 const StringRef floor_name(std::get<1>(quad));
3154 const StringRef fmin_name(std::get<2>(quad));
3155 const StringRef clspv_fract_name(std::get<3>(quad));
3156
3157 // This is either float or a float vector. All the float-like
3158 // types are this type.
3159 auto result_ty = F->getReturnType();
3160
3161 Function* fmin_fn = M.getFunction(fmin_name);
3162 if (!fmin_fn) {
3163 // Make the fmin function.
3164 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003165 fmin_fn =
3166 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003167 fmin_fn->addFnAttr(Attribute::ReadNone);
3168 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
3169 }
3170
3171 Function* floor_fn = M.getFunction(floor_name);
3172 if (!floor_fn) {
3173 // Make the floor function.
3174 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003175 floor_fn = cast<Function>(
3176 M.getOrInsertFunction(floor_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003177 floor_fn->addFnAttr(Attribute::ReadNone);
3178 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
3179 }
3180
3181 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
3182 if (!clspv_fract_fn) {
3183 // Make the clspv_fract function.
3184 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003185 clspv_fract_fn = cast<Function>(
3186 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003187 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
3188 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
3189 }
3190
3191 // Number of significant significand bits, whether represented or not.
3192 unsigned num_significand_bits;
3193 switch (result_ty->getScalarType()->getTypeID()) {
3194 case Type::HalfTyID:
3195 num_significand_bits = 11;
3196 break;
3197 case Type::FloatTyID:
3198 num_significand_bits = 24;
3199 break;
3200 case Type::DoubleTyID:
3201 num_significand_bits = 53;
3202 break;
3203 default:
3204 assert(false && "Unhandled float type when processing fract builtin");
3205 break;
3206 }
3207 // Beware that the disassembler displays this value as
3208 // OpConstant %float 1
3209 // which is not quite right.
3210 const double kJustUnderOneScalar =
3211 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3212
3213 Constant *just_under_one =
3214 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3215 if (result_ty->isVectorTy()) {
3216 just_under_one = ConstantVector::getSplat(
3217 result_ty->getVectorNumElements(), just_under_one);
3218 }
3219
3220 IRBuilder<> Builder(Context);
3221
3222 SmallVector<Instruction *, 4> ToRemoves;
3223
3224 // Walk the users of the function.
3225 for (auto &U : F->uses()) {
3226 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3227
3228 Builder.SetInsertPoint(CI);
3229 auto arg = CI->getArgOperand(0);
3230 auto ptr = CI->getArgOperand(1);
3231
3232 // Compute floor result and store it.
3233 auto floor = Builder.CreateCall(floor_fn, {arg});
3234 Builder.CreateStore(floor, ptr);
3235
3236 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
3237 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
3238
3239 CI->replaceAllUsesWith(fract_result);
3240
3241 // Lastly, remember to remove the user.
3242 ToRemoves.push_back(CI);
3243 }
3244 }
3245
3246 // And cleanup the calls we don't use anymore.
3247 for (auto V : ToRemoves) {
3248 V->eraseFromParent();
3249 }
3250
3251 // And remove the function we don't need either too.
3252 F->eraseFromParent();
3253 }
3254 }
3255
3256 return Changed;
3257}