blob: 72e584afe4fb264157d8ed95a98c47c87dd8531a [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
David Neto118188e2018-08-24 11:27:54 -040019#include "llvm/IR/Constants.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000023#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/Pass.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/raw_ostream.h"
27#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040028
David Neto118188e2018-08-24 11:27:54 -040029#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040030
David Neto482550a2018-03-24 05:21:07 -070031#include "clspv/Option.h"
32
David Neto22f144c2017-06-12 14:26:21 -040033using namespace llvm;
34
35#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
36
37namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000038
// Information about a single argument type, as recovered from the mangled
// name of a builtin function.
struct ArgTypeInfo {
  // Signedness of the argument's (element) integer type.
  enum class SignedNess { Unsigned, Signed };

  // Not initialised by default; the mangled-name parser assigns it.
  SignedNess signedness;
};
46
47struct FunctionInfo {
48 std::vector<ArgTypeInfo> argTypeInfos;
49};
50
51bool getFunctionInfoFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
52 if (!name.consume_front("_Z")) {
53 return false;
54 }
55 size_t nameLen;
56 if (name.consumeInteger(10, nameLen)) {
57 return false;
58 }
59
60 name = name.drop_front(nameLen);
61
62 ArgTypeInfo prev_ti;
63
64 while (name.size() != 0) {
65
66 ArgTypeInfo ti;
67
68 // Try parsing a vector prefix
69 if (name.consume_front("Dv")) {
70 int numElems;
71 if (name.consumeInteger(10, numElems)) {
72 return false;
73 }
74
75 if (!name.consume_front("_")) {
76 return false;
77 }
78 }
79
80 // Parse the base type
81 char typeCode = name.front();
82 name = name.drop_front(1);
83 switch(typeCode) {
84 case 'c': // char
85 case 'a': // signed char
86 case 's': // short
87 case 'i': // int
88 case 'l': // long
89 ti.signedness = ArgTypeInfo::SignedNess::Signed;
90 break;
91 case 'h': // unsigned char
92 case 't': // unsigned short
93 case 'j': // unsigned int
94 case 'm': // unsigned long
95 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
96 break;
97 case 'S':
98 ti = prev_ti;
99 if (!name.consume_front("_")) {
100 return false;
101 }
102 break;
103 default:
104 return false;
105 }
106
107 finfo->argTypeInfos.push_back(ti);
108
109 prev_ti = ti;
110 }
111
112 return true;
113};
114
115void getFunctionInfoFromMangledName(StringRef name, FunctionInfo *finfo) {
116 if (!getFunctionInfoFromMangledNameCheck(name, finfo)) {
117 llvm_unreachable("Can't parse mangled function name!");
118 }
119}
120
// Despite its name, this does NOT count leading zeros: it returns the index
// of the highest set bit of v, i.e. floor(log2(v)). For v == 0 the result is
// 0. Callers in this file subtract two results to obtain the shift distance
// between two single-bit masks.
uint32_t clz(uint32_t v) {
  uint32_t highestBit = 0;

  // Each remaining right shift that still leaves a non-zero value means the
  // top set bit is one position higher.
  for (uint32_t rest = v >> 1; rest != 0; rest >>= 1) {
    ++highestBit;
  }

  return highestBit;
}
140
141Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
142 if (1 == elements) {
143 return Type::getInt1Ty(C);
144 } else {
145 return VectorType::get(Type::getInt1Ty(C), elements);
146 }
147}
148
149struct ReplaceOpenCLBuiltinPass final : public ModulePass {
150 static char ID;
151 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
152
153 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000154 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400155 bool replaceRecip(Module &M);
156 bool replaceDivide(Module &M);
157 bool replaceExp10(Module &M);
158 bool replaceLog10(Module &M);
159 bool replaceBarrier(Module &M);
160 bool replaceMemFence(Module &M);
161 bool replaceRelational(Module &M);
162 bool replaceIsInfAndIsNan(Module &M);
163 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000164 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000165 bool replaceRotate(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000166 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000167 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000168 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000169 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400170 bool replaceSignbit(Module &M);
171 bool replaceMadandMad24andMul24(Module &M);
172 bool replaceVloadHalf(Module &M);
173 bool replaceVloadHalf2(Module &M);
174 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700175 bool replaceClspvVloadaHalf2(Module &M);
176 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400177 bool replaceVstoreHalf(Module &M);
178 bool replaceVstoreHalf2(Module &M);
179 bool replaceVstoreHalf4(Module &M);
180 bool replaceReadImageF(Module &M);
181 bool replaceAtomics(Module &M);
182 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400183 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700184 bool replaceVload(Module &M);
185 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400186};
187}
188
189char ReplaceOpenCLBuiltinPass::ID = 0;
190static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
191 "Replace OpenCL Builtins Pass");
192
193namespace clspv {
194ModulePass *createReplaceOpenCLBuiltinPass() {
195 return new ReplaceOpenCLBuiltinPass();
196}
197}
198
199bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
200 bool Changed = false;
201
Kévin Petit2444e9b2018-11-09 14:14:37 +0000202 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400203 Changed |= replaceRecip(M);
204 Changed |= replaceDivide(M);
205 Changed |= replaceExp10(M);
206 Changed |= replaceLog10(M);
207 Changed |= replaceBarrier(M);
208 Changed |= replaceMemFence(M);
209 Changed |= replaceRelational(M);
210 Changed |= replaceIsInfAndIsNan(M);
211 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000212 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000213 Changed |= replaceRotate(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000214 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000215 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000216 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000217 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400218 Changed |= replaceSignbit(M);
219 Changed |= replaceMadandMad24andMul24(M);
220 Changed |= replaceVloadHalf(M);
221 Changed |= replaceVloadHalf2(M);
222 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700223 Changed |= replaceClspvVloadaHalf2(M);
224 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400225 Changed |= replaceVstoreHalf(M);
226 Changed |= replaceVstoreHalf2(M);
227 Changed |= replaceVstoreHalf4(M);
228 Changed |= replaceReadImageF(M);
229 Changed |= replaceAtomics(M);
230 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400231 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700232 Changed |= replaceVload(M);
233 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400234
235 return Changed;
236}
237
Kévin Petit2444e9b2018-11-09 14:14:37 +0000238bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
239 bool Changed = false;
240
241 const char *Names[] = {
242 "_Z3abst",
243 "_Z3absDv2_t",
244 "_Z3absDv3_t",
245 "_Z3absDv4_t",
246 "_Z3absj",
247 "_Z3absDv2_j",
248 "_Z3absDv3_j",
249 "_Z3absDv4_j",
250 "_Z3absm",
251 "_Z3absDv2_m",
252 "_Z3absDv3_m",
253 "_Z3absDv4_m",
254 };
255
256 for (auto Name : Names) {
257 // If we find a function with the matching name.
258 if (auto F = M.getFunction(Name)) {
259 SmallVector<Instruction *, 4> ToRemoves;
260
261 // Walk the users of the function.
262 for (auto &U : F->uses()) {
263 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
264 // Abs has one arg.
265 auto Arg = CI->getOperand(0);
266
267 // Use the argument unchanged, we know it's unsigned
268 CI->replaceAllUsesWith(Arg);
269
270 // Lastly, remember to remove the user.
271 ToRemoves.push_back(CI);
272 }
273 }
274
275 Changed = !ToRemoves.empty();
276
277 // And cleanup the calls we don't use anymore.
278 for (auto V : ToRemoves) {
279 V->eraseFromParent();
280 }
281
282 // And remove the function we don't need either too.
283 F->eraseFromParent();
284 }
285 }
286
287 return Changed;
288}
289
David Neto22f144c2017-06-12 14:26:21 -0400290bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
291 bool Changed = false;
292
293 const char *Names[] = {
294 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
295 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
296 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
297 };
298
299 for (auto Name : Names) {
300 // If we find a function with the matching name.
301 if (auto F = M.getFunction(Name)) {
302 SmallVector<Instruction *, 4> ToRemoves;
303
304 // Walk the users of the function.
305 for (auto &U : F->uses()) {
306 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
307 // Recip has one arg.
308 auto Arg = CI->getOperand(0);
309
310 auto Div = BinaryOperator::Create(
311 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
312 CI);
313
314 CI->replaceAllUsesWith(Div);
315
316 // Lastly, remember to remove the user.
317 ToRemoves.push_back(CI);
318 }
319 }
320
321 Changed = !ToRemoves.empty();
322
323 // And cleanup the calls we don't use anymore.
324 for (auto V : ToRemoves) {
325 V->eraseFromParent();
326 }
327
328 // And remove the function we don't need either too.
329 F->eraseFromParent();
330 }
331 }
332
333 return Changed;
334}
335
336bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
337 bool Changed = false;
338
339 const char *Names[] = {
340 "_Z11half_divideff", "_Z13native_divideff",
341 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
342 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
343 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
344 };
345
346 for (auto Name : Names) {
347 // If we find a function with the matching name.
348 if (auto F = M.getFunction(Name)) {
349 SmallVector<Instruction *, 4> ToRemoves;
350
351 // Walk the users of the function.
352 for (auto &U : F->uses()) {
353 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
354 auto Div = BinaryOperator::Create(
355 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
356
357 CI->replaceAllUsesWith(Div);
358
359 // Lastly, remember to remove the user.
360 ToRemoves.push_back(CI);
361 }
362 }
363
364 Changed = !ToRemoves.empty();
365
366 // And cleanup the calls we don't use anymore.
367 for (auto V : ToRemoves) {
368 V->eraseFromParent();
369 }
370
371 // And remove the function we don't need either too.
372 F->eraseFromParent();
373 }
374 }
375
376 return Changed;
377}
378
379bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
380 bool Changed = false;
381
382 const std::map<const char *, const char *> Map = {
383 {"_Z5exp10f", "_Z3expf"},
384 {"_Z10half_exp10f", "_Z8half_expf"},
385 {"_Z12native_exp10f", "_Z10native_expf"},
386 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
387 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
388 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
389 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
390 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
391 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
392 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
393 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
394 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
395
396 for (auto Pair : Map) {
397 // If we find a function with the matching name.
398 if (auto F = M.getFunction(Pair.first)) {
399 SmallVector<Instruction *, 4> ToRemoves;
400
401 // Walk the users of the function.
402 for (auto &U : F->uses()) {
403 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
404 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
405
406 auto Arg = CI->getOperand(0);
407
408 // Constant of the natural log of 10 (ln(10)).
409 const double Ln10 =
410 2.302585092994045684017991454684364207601101488628772976033;
411
412 auto Mul = BinaryOperator::Create(
413 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
414 CI);
415
416 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
417
418 CI->replaceAllUsesWith(NewCI);
419
420 // Lastly, remember to remove the user.
421 ToRemoves.push_back(CI);
422 }
423 }
424
425 Changed = !ToRemoves.empty();
426
427 // And cleanup the calls we don't use anymore.
428 for (auto V : ToRemoves) {
429 V->eraseFromParent();
430 }
431
432 // And remove the function we don't need either too.
433 F->eraseFromParent();
434 }
435 }
436
437 return Changed;
438}
439
440bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
441 bool Changed = false;
442
443 const std::map<const char *, const char *> Map = {
444 {"_Z5log10f", "_Z3logf"},
445 {"_Z10half_log10f", "_Z8half_logf"},
446 {"_Z12native_log10f", "_Z10native_logf"},
447 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
448 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
449 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
450 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
451 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
452 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
453 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
454 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
455 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
456
457 for (auto Pair : Map) {
458 // If we find a function with the matching name.
459 if (auto F = M.getFunction(Pair.first)) {
460 SmallVector<Instruction *, 4> ToRemoves;
461
462 // Walk the users of the function.
463 for (auto &U : F->uses()) {
464 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
465 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
466
467 auto Arg = CI->getOperand(0);
468
469 // Constant of the reciprocal of the natural log of 10 (ln(10)).
470 const double Ln10 =
471 0.434294481903251827651128918916605082294397005803666566114;
472
473 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
474
475 auto Mul = BinaryOperator::Create(
476 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
477 "", CI);
478
479 CI->replaceAllUsesWith(Mul);
480
481 // Lastly, remember to remove the user.
482 ToRemoves.push_back(CI);
483 }
484 }
485
486 Changed = !ToRemoves.empty();
487
488 // And cleanup the calls we don't use anymore.
489 for (auto V : ToRemoves) {
490 V->eraseFromParent();
491 }
492
493 // And remove the function we don't need either too.
494 F->eraseFromParent();
495 }
496 }
497
498 return Changed;
499}
500
501bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
502 bool Changed = false;
503
504 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
505
506 const std::map<const char *, const char *> Map = {
507 {"_Z7barrierj", "__spirv_control_barrier"}};
508
509 for (auto Pair : Map) {
510 // If we find a function with the matching name.
511 if (auto F = M.getFunction(Pair.first)) {
512 SmallVector<Instruction *, 4> ToRemoves;
513
514 // Walk the users of the function.
515 for (auto &U : F->uses()) {
516 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
517 auto FType = F->getFunctionType();
518 SmallVector<Type *, 3> Params;
519 for (unsigned i = 0; i < 3; i++) {
520 Params.push_back(FType->getParamType(0));
521 }
522 auto NewFType =
523 FunctionType::get(FType->getReturnType(), Params, false);
524 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
525
526 auto Arg = CI->getOperand(0);
527
528 // We need to map the OpenCL constants to the SPIR-V equivalents.
529 const auto LocalMemFence =
530 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
531 const auto GlobalMemFence =
532 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
533 const auto ConstantSequentiallyConsistent = ConstantInt::get(
534 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
535 const auto ConstantScopeDevice =
536 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
537 const auto ConstantScopeWorkgroup =
538 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
539
540 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
541 const auto LocalMemFenceMask = BinaryOperator::Create(
542 Instruction::And, LocalMemFence, Arg, "", CI);
543 const auto WorkgroupShiftAmount =
544 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
545 clz(CLK_LOCAL_MEM_FENCE);
546 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
547 Instruction::Shl, LocalMemFenceMask,
548 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
549
550 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
551 const auto GlobalMemFenceMask = BinaryOperator::Create(
552 Instruction::And, GlobalMemFence, Arg, "", CI);
553 const auto UniformShiftAmount =
554 clz(spv::MemorySemanticsUniformMemoryMask) -
555 clz(CLK_GLOBAL_MEM_FENCE);
556 const auto MemorySemanticsUniform = BinaryOperator::Create(
557 Instruction::Shl, GlobalMemFenceMask,
558 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
559
560 // And combine the above together, also adding in
561 // MemorySemanticsSequentiallyConsistentMask.
562 auto MemorySemantics =
563 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
564 ConstantSequentiallyConsistent, "", CI);
565 MemorySemantics = BinaryOperator::Create(
566 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
567
568 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
569 // Device Scope, otherwise Workgroup Scope.
570 const auto Cmp =
571 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
572 GlobalMemFenceMask, GlobalMemFence, "", CI);
573 const auto MemoryScope = SelectInst::Create(
574 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
575
576 // Lastly, the Execution Scope is always Workgroup Scope.
577 const auto ExecutionScope = ConstantScopeWorkgroup;
578
579 auto NewCI = CallInst::Create(
580 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
581
582 CI->replaceAllUsesWith(NewCI);
583
584 // Lastly, remember to remove the user.
585 ToRemoves.push_back(CI);
586 }
587 }
588
589 Changed = !ToRemoves.empty();
590
591 // And cleanup the calls we don't use anymore.
592 for (auto V : ToRemoves) {
593 V->eraseFromParent();
594 }
595
596 // And remove the function we don't need either too.
597 F->eraseFromParent();
598 }
599 }
600
601 return Changed;
602}
603
604bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
605 bool Changed = false;
606
607 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
608
Neil Henning39672102017-09-29 14:33:13 +0100609 using Tuple = std::tuple<const char *, unsigned>;
610 const std::map<const char *, Tuple> Map = {
611 {"_Z9mem_fencej",
612 Tuple("__spirv_memory_barrier",
613 spv::MemorySemanticsSequentiallyConsistentMask)},
614 {"_Z14read_mem_fencej",
615 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
616 {"_Z15write_mem_fencej",
617 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400618
619 for (auto Pair : Map) {
620 // If we find a function with the matching name.
621 if (auto F = M.getFunction(Pair.first)) {
622 SmallVector<Instruction *, 4> ToRemoves;
623
624 // Walk the users of the function.
625 for (auto &U : F->uses()) {
626 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
627 auto FType = F->getFunctionType();
628 SmallVector<Type *, 2> Params;
629 for (unsigned i = 0; i < 2; i++) {
630 Params.push_back(FType->getParamType(0));
631 }
632 auto NewFType =
633 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100634 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400635
636 auto Arg = CI->getOperand(0);
637
638 // We need to map the OpenCL constants to the SPIR-V equivalents.
639 const auto LocalMemFence =
640 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
641 const auto GlobalMemFence =
642 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
643 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100644 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400645 const auto ConstantScopeDevice =
646 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
647
648 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
649 const auto LocalMemFenceMask = BinaryOperator::Create(
650 Instruction::And, LocalMemFence, Arg, "", CI);
651 const auto WorkgroupShiftAmount =
652 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
653 clz(CLK_LOCAL_MEM_FENCE);
654 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
655 Instruction::Shl, LocalMemFenceMask,
656 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
657
658 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
659 const auto GlobalMemFenceMask = BinaryOperator::Create(
660 Instruction::And, GlobalMemFence, Arg, "", CI);
661 const auto UniformShiftAmount =
662 clz(spv::MemorySemanticsUniformMemoryMask) -
663 clz(CLK_GLOBAL_MEM_FENCE);
664 const auto MemorySemanticsUniform = BinaryOperator::Create(
665 Instruction::Shl, GlobalMemFenceMask,
666 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
667
668 // And combine the above together, also adding in
669 // MemorySemanticsSequentiallyConsistentMask.
670 auto MemorySemantics =
671 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
672 ConstantMemorySemantics, "", CI);
673 MemorySemantics = BinaryOperator::Create(
674 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
675
676 // Memory Scope is always device.
677 const auto MemoryScope = ConstantScopeDevice;
678
679 auto NewCI =
680 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
681
682 CI->replaceAllUsesWith(NewCI);
683
684 // Lastly, remember to remove the user.
685 ToRemoves.push_back(CI);
686 }
687 }
688
689 Changed = !ToRemoves.empty();
690
691 // And cleanup the calls we don't use anymore.
692 for (auto V : ToRemoves) {
693 V->eraseFromParent();
694 }
695
696 // And remove the function we don't need either too.
697 F->eraseFromParent();
698 }
699 }
700
701 return Changed;
702}
703
704bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
705 bool Changed = false;
706
707 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
708 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
709 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
710 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
711 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
712 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
713 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
714 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
715 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
716 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
717 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
718 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
719 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
720 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
721 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
722 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
723 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
724 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
725 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
726 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
727 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
728 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
729 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
730 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
731 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
732 };
733
734 for (auto Pair : Map) {
735 // If we find a function with the matching name.
736 if (auto F = M.getFunction(Pair.first)) {
737 SmallVector<Instruction *, 4> ToRemoves;
738
739 // Walk the users of the function.
740 for (auto &U : F->uses()) {
741 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
742 // The predicate to use in the CmpInst.
743 auto Predicate = Pair.second.first;
744
745 // The value to return for true.
746 auto TrueValue =
747 ConstantInt::getSigned(CI->getType(), Pair.second.second);
748
749 // The value to return for false.
750 auto FalseValue = Constant::getNullValue(CI->getType());
751
752 auto Arg1 = CI->getOperand(0);
753 auto Arg2 = CI->getOperand(1);
754
755 const auto Cmp =
756 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
757
758 const auto Select =
759 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
760
761 CI->replaceAllUsesWith(Select);
762
763 // Lastly, remember to remove the user.
764 ToRemoves.push_back(CI);
765 }
766 }
767
768 Changed = !ToRemoves.empty();
769
770 // And cleanup the calls we don't use anymore.
771 for (auto V : ToRemoves) {
772 V->eraseFromParent();
773 }
774
775 // And remove the function we don't need either too.
776 F->eraseFromParent();
777 }
778 }
779
780 return Changed;
781}
782
783bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
784 bool Changed = false;
785
786 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
787 {"_Z5isinff", {"__spirv_isinff", 1}},
788 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
789 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
790 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
791 {"_Z5isnanf", {"__spirv_isnanf", 1}},
792 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
793 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
794 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
795 };
796
797 for (auto Pair : Map) {
798 // If we find a function with the matching name.
799 if (auto F = M.getFunction(Pair.first)) {
800 SmallVector<Instruction *, 4> ToRemoves;
801
802 // Walk the users of the function.
803 for (auto &U : F->uses()) {
804 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
805 const auto CITy = CI->getType();
806
807 // The fake SPIR-V intrinsic to generate.
808 auto SPIRVIntrinsic = Pair.second.first;
809
810 // The value to return for true.
811 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
812
813 // The value to return for false.
814 auto FalseValue = Constant::getNullValue(CITy);
815
816 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
817 M.getContext(),
818 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
819
820 auto NewFType =
821 FunctionType::get(CorrespondingBoolTy,
822 F->getFunctionType()->getParamType(0), false);
823
824 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
825
826 auto Arg = CI->getOperand(0);
827
828 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
829
830 const auto Select =
831 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
832
833 CI->replaceAllUsesWith(Select);
834
835 // Lastly, remember to remove the user.
836 ToRemoves.push_back(CI);
837 }
838 }
839
840 Changed = !ToRemoves.empty();
841
842 // And cleanup the calls we don't use anymore.
843 for (auto V : ToRemoves) {
844 V->eraseFromParent();
845 }
846
847 // And remove the function we don't need either too.
848 F->eraseFromParent();
849 }
850 }
851
852 return Changed;
853}
854
855bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
856 bool Changed = false;
857
858 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000859 // all
alan-bakerb39c8262019-03-08 14:03:37 -0500860 {"_Z3allc", ""},
861 {"_Z3allDv2_c", "__spirv_allDv2_c"},
862 {"_Z3allDv3_c", "__spirv_allDv3_c"},
863 {"_Z3allDv4_c", "__spirv_allDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000864 {"_Z3alls", ""},
865 {"_Z3allDv2_s", "__spirv_allDv2_s"},
866 {"_Z3allDv3_s", "__spirv_allDv3_s"},
867 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400868 {"_Z3alli", ""},
869 {"_Z3allDv2_i", "__spirv_allDv2_i"},
870 {"_Z3allDv3_i", "__spirv_allDv3_i"},
871 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000872 {"_Z3alll", ""},
873 {"_Z3allDv2_l", "__spirv_allDv2_l"},
874 {"_Z3allDv3_l", "__spirv_allDv3_l"},
875 {"_Z3allDv4_l", "__spirv_allDv4_l"},
876
877 // any
alan-bakerb39c8262019-03-08 14:03:37 -0500878 {"_Z3anyc", ""},
879 {"_Z3anyDv2_c", "__spirv_anyDv2_c"},
880 {"_Z3anyDv3_c", "__spirv_anyDv3_c"},
881 {"_Z3anyDv4_c", "__spirv_anyDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000882 {"_Z3anys", ""},
883 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
884 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
885 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400886 {"_Z3anyi", ""},
887 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
888 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
889 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000890 {"_Z3anyl", ""},
891 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
892 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
893 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400894 };
895
896 for (auto Pair : Map) {
897 // If we find a function with the matching name.
898 if (auto F = M.getFunction(Pair.first)) {
899 SmallVector<Instruction *, 4> ToRemoves;
900
901 // Walk the users of the function.
902 for (auto &U : F->uses()) {
903 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
904 // The fake SPIR-V intrinsic to generate.
905 auto SPIRVIntrinsic = Pair.second;
906
907 auto Arg = CI->getOperand(0);
908
909 Value *V;
910
Kévin Petitfd27cca2018-10-31 13:00:17 +0000911 // If the argument is a 32-bit int, just use a shift
912 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
913 V = BinaryOperator::Create(Instruction::LShr, Arg,
914 ConstantInt::get(Arg->getType(), 31), "",
915 CI);
916 } else {
David Neto22f144c2017-06-12 14:26:21 -0400917 // The value for zero to compare against.
918 const auto ZeroValue = Constant::getNullValue(Arg->getType());
919
David Neto22f144c2017-06-12 14:26:21 -0400920 // The value to return for true.
921 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
922
923 // The value to return for false.
924 const auto FalseValue = Constant::getNullValue(CI->getType());
925
Kévin Petitfd27cca2018-10-31 13:00:17 +0000926 const auto Cmp = CmpInst::Create(
927 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
928
929 Value* SelectSource;
930
931 // If we have a function to call, call it!
932 if (0 < strlen(SPIRVIntrinsic)) {
933
934 const auto NewFType = FunctionType::get(
935 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
936
937 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
938
939 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
940
941 SelectSource = NewCI;
942
943 } else {
944 SelectSource = Cmp;
945 }
946
947 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400948 }
949
950 CI->replaceAllUsesWith(V);
951
952 // Lastly, remember to remove the user.
953 ToRemoves.push_back(CI);
954 }
955 }
956
957 Changed = !ToRemoves.empty();
958
959 // And cleanup the calls we don't use anymore.
960 for (auto V : ToRemoves) {
961 V->eraseFromParent();
962 }
963
964 // And remove the function we don't need either too.
965 F->eraseFromParent();
966 }
967 }
968
969 return Changed;
970}
971
Kévin Petitbf0036c2019-03-06 13:57:10 +0000972bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
973 bool Changed = false;
974
975 for (auto const &SymVal : M.getValueSymbolTable()) {
976 // Skip symbols whose name doesn't match
977 if (!SymVal.getKey().startswith("_Z8upsample")) {
978 continue;
979 }
980 // Is there a function going by that name?
981 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
982
983 SmallVector<Instruction *, 4> ToRemoves;
984
985 // Walk the users of the function.
986 for (auto &U : F->uses()) {
987 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
988
989 // Get arguments
990 auto HiValue = CI->getOperand(0);
991 auto LoValue = CI->getOperand(1);
992
993 // Don't touch overloads that aren't in OpenCL C
994 auto HiType = HiValue->getType();
995 auto LoType = LoValue->getType();
996
997 if (HiType != LoType) {
998 continue;
999 }
1000
1001 if (!HiType->isIntOrIntVectorTy()) {
1002 continue;
1003 }
1004
1005 if (HiType->getScalarSizeInBits() * 2 !=
1006 CI->getType()->getScalarSizeInBits()) {
1007 continue;
1008 }
1009
1010 if ((HiType->getScalarSizeInBits() != 8) &&
1011 (HiType->getScalarSizeInBits() != 16) &&
1012 (HiType->getScalarSizeInBits() != 32)) {
1013 continue;
1014 }
1015
1016 if (HiType->isVectorTy()) {
1017 if ((HiType->getVectorNumElements() != 2) &&
1018 (HiType->getVectorNumElements() != 3) &&
1019 (HiType->getVectorNumElements() != 4) &&
1020 (HiType->getVectorNumElements() != 8) &&
1021 (HiType->getVectorNumElements() != 16)) {
1022 continue;
1023 }
1024 }
1025
1026 // Convert both operands to the result type
1027 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
1028 "", CI);
1029 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
1030 "", CI);
1031
1032 // Shift high operand
1033 auto ShiftAmount = ConstantInt::get(CI->getType(),
1034 HiType->getScalarSizeInBits());
1035 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1036 ShiftAmount, "", CI);
1037
1038 // OR both results
1039 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1040 "", CI);
1041
1042 // Replace call with the expression
1043 CI->replaceAllUsesWith(V);
1044
1045 // Lastly, remember to remove the user.
1046 ToRemoves.push_back(CI);
1047 }
1048 }
1049
1050 Changed = !ToRemoves.empty();
1051
1052 // And cleanup the calls we don't use anymore.
1053 for (auto V : ToRemoves) {
1054 V->eraseFromParent();
1055 }
1056
1057 // And remove the function we don't need either too.
1058 F->eraseFromParent();
1059 }
1060 }
1061
1062 return Changed;
1063}
1064
Kévin Petitd44eef52019-03-08 13:22:14 +00001065bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1066 bool Changed = false;
1067
1068 for (auto const &SymVal : M.getValueSymbolTable()) {
1069 // Skip symbols whose name doesn't match
1070 if (!SymVal.getKey().startswith("_Z6rotate")) {
1071 continue;
1072 }
1073 // Is there a function going by that name?
1074 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1075
1076 SmallVector<Instruction *, 4> ToRemoves;
1077
1078 // Walk the users of the function.
1079 for (auto &U : F->uses()) {
1080 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1081
1082 // Get arguments
1083 auto SrcValue = CI->getOperand(0);
1084 auto RotAmount = CI->getOperand(1);
1085
1086 // Don't touch overloads that aren't in OpenCL C
1087 auto SrcType = SrcValue->getType();
1088 auto RotType = RotAmount->getType();
1089
1090 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1091 continue;
1092 }
1093
1094 if (!SrcType->isIntOrIntVectorTy()) {
1095 continue;
1096 }
1097
1098 if ((SrcType->getScalarSizeInBits() != 8) &&
1099 (SrcType->getScalarSizeInBits() != 16) &&
1100 (SrcType->getScalarSizeInBits() != 32) &&
1101 (SrcType->getScalarSizeInBits() != 64)) {
1102 continue;
1103 }
1104
1105 if (SrcType->isVectorTy()) {
1106 if ((SrcType->getVectorNumElements() != 2) &&
1107 (SrcType->getVectorNumElements() != 3) &&
1108 (SrcType->getVectorNumElements() != 4) &&
1109 (SrcType->getVectorNumElements() != 8) &&
1110 (SrcType->getVectorNumElements() != 16)) {
1111 continue;
1112 }
1113 }
1114
1115 // The approach used is to shift the top bits down, the bottom bits up
1116 // and OR the two shifted values.
1117
1118 // The rotation amount is to be treated modulo the element size.
1119 // Since SPIR-V shift ops don't support this, let's apply the
1120 // modulo ahead of shifting. The element size is always a power of
1121 // two so we can just AND with a mask.
1122 auto ModMask = ConstantInt::get(SrcType,
1123 SrcType->getScalarSizeInBits() - 1);
1124 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1125 ModMask, "", CI);
1126
1127 // Let's calc the amount by which to shift top bits down
1128 auto ScalarSize = ConstantInt::get(SrcType,
1129 SrcType->getScalarSizeInBits());
1130 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1131 RotAmount, "", CI);
1132
1133 // Now shift the bottom bits up and the top bits down
1134 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1135 RotAmount, "", CI);
1136 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1137 DownAmount, "", CI);
1138
1139 // Finally OR the two shifted values
1140 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1141 HiRotated, "", CI);
1142
1143 // Replace call with the expression
1144 CI->replaceAllUsesWith(V);
1145
1146 // Lastly, remember to remove the user.
1147 ToRemoves.push_back(CI);
1148 }
1149 }
1150
1151 Changed = !ToRemoves.empty();
1152
1153 // And cleanup the calls we don't use anymore.
1154 for (auto V : ToRemoves) {
1155 V->eraseFromParent();
1156 }
1157
1158 // And remove the function we don't need either too.
1159 F->eraseFromParent();
1160 }
1161 }
1162
1163 return Changed;
1164}
1165
Kévin Petit8a560882019-03-21 15:24:34 +00001166bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1167 bool Changed = false;
1168
1169 for (auto const &SymVal : M.getValueSymbolTable()) {
1170
1171 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1172 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1173
1174 // Skip symbols whose name doesn't match
1175 if (!isMad && !isMul) {
1176 continue;
1177 }
1178
1179 // Is there a function going by that name?
1180 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1181
1182 SmallVector<Instruction *, 4> ToRemoves;
1183
1184 // Walk the users of the function.
1185 for (auto &U : F->uses()) {
1186 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1187
1188 // Get arguments
1189 auto AValue = CI->getOperand(0);
1190 auto BValue = CI->getOperand(1);
1191 auto CValue = CI->getOperand(2);
1192
1193 // Don't touch overloads that aren't in OpenCL C
1194 auto AType = AValue->getType();
1195 auto BType = BValue->getType();
1196 auto CType = CValue->getType();
1197
1198 if ((AType != BType) || (CI->getType() != AType) ||
1199 (isMad && (AType != CType))) {
1200 continue;
1201 }
1202
1203 if (!AType->isIntOrIntVectorTy()) {
1204 continue;
1205 }
1206
1207 if ((AType->getScalarSizeInBits() != 8) &&
1208 (AType->getScalarSizeInBits() != 16) &&
1209 (AType->getScalarSizeInBits() != 32) &&
1210 (AType->getScalarSizeInBits() != 64)) {
1211 continue;
1212 }
1213
1214 if (AType->isVectorTy()) {
1215 if ((AType->getVectorNumElements() != 2) &&
1216 (AType->getVectorNumElements() != 3) &&
1217 (AType->getVectorNumElements() != 4) &&
1218 (AType->getVectorNumElements() != 8) &&
1219 (AType->getVectorNumElements() != 16)) {
1220 continue;
1221 }
1222 }
1223
1224 // Create struct type for the return type of our SPIR-V intrinsic
1225 SmallVector<Type*, 2> TwoValueType = {
1226 AType,
1227 AType
1228 };
1229
1230 auto ExMulRetType = StructType::create(TwoValueType);
1231
1232 // And a function type
1233 auto NewFType = FunctionType::get(ExMulRetType, TwoValueType, false);
1234
1235 // Get infos from the mangled OpenCL built-in function name
1236 FunctionInfo finfo;
1237 getFunctionInfoFromMangledName(F->getName(), &finfo);
1238
1239 // Use it to select the appropriate signed/unsigned SPIR-V intrinsic
1240 StringRef intrinsic;
1241 if (finfo.argTypeInfos[0].signedness == ArgTypeInfo::SignedNess::Signed) {
1242 intrinsic = "spirv.smul_extended";
1243 } else {
1244 intrinsic = "spirv.umul_extended";
1245 }
1246
1247 // Add the intrinsic function to the module
1248 auto NewF = M.getOrInsertFunction(intrinsic, NewFType);
1249
1250 // Call it
1251 SmallVector<Value*, 4> NewFArgs = {
1252 AValue,
1253 BValue,
1254 };
1255
1256 auto Call = CallInst::Create(NewF, NewFArgs, "", CI);
1257
1258 // Get the high part of the result
1259 unsigned Idxs[] = {1};
1260 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1261
1262 // If we're handling a mad_hi, add the third argument to the result
1263 if (isMad) {
1264 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1265 }
1266
1267 // Replace call with the expression
1268 CI->replaceAllUsesWith(V);
1269
1270 // Lastly, remember to remove the user.
1271 ToRemoves.push_back(CI);
1272 }
1273 }
1274
1275 Changed = !ToRemoves.empty();
1276
1277 // And cleanup the calls we don't use anymore.
1278 for (auto V : ToRemoves) {
1279 V->eraseFromParent();
1280 }
1281
1282 // And remove the function we don't need either too.
1283 F->eraseFromParent();
1284 }
1285 }
1286
1287 return Changed;
1288}
1289
Kévin Petitf5b78a22018-10-25 14:32:17 +00001290bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1291 bool Changed = false;
1292
1293 for (auto const &SymVal : M.getValueSymbolTable()) {
1294 // Skip symbols whose name doesn't match
1295 if (!SymVal.getKey().startswith("_Z6select")) {
1296 continue;
1297 }
1298 // Is there a function going by that name?
1299 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1300
1301 SmallVector<Instruction *, 4> ToRemoves;
1302
1303 // Walk the users of the function.
1304 for (auto &U : F->uses()) {
1305 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1306
1307 // Get arguments
1308 auto FalseValue = CI->getOperand(0);
1309 auto TrueValue = CI->getOperand(1);
1310 auto PredicateValue = CI->getOperand(2);
1311
1312 // Don't touch overloads that aren't in OpenCL C
1313 auto FalseType = FalseValue->getType();
1314 auto TrueType = TrueValue->getType();
1315 auto PredicateType = PredicateValue->getType();
1316
1317 if (FalseType != TrueType) {
1318 continue;
1319 }
1320
1321 if (!PredicateType->isIntOrIntVectorTy()) {
1322 continue;
1323 }
1324
1325 if (!FalseType->isIntOrIntVectorTy() &&
1326 !FalseType->getScalarType()->isFloatingPointTy()) {
1327 continue;
1328 }
1329
1330 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1331 continue;
1332 }
1333
1334 if (FalseType->getScalarSizeInBits() !=
1335 PredicateType->getScalarSizeInBits()) {
1336 continue;
1337 }
1338
1339 if (FalseType->isVectorTy()) {
1340 if (FalseType->getVectorNumElements() !=
1341 PredicateType->getVectorNumElements()) {
1342 continue;
1343 }
1344
1345 if ((FalseType->getVectorNumElements() != 2) &&
1346 (FalseType->getVectorNumElements() != 3) &&
1347 (FalseType->getVectorNumElements() != 4) &&
1348 (FalseType->getVectorNumElements() != 8) &&
1349 (FalseType->getVectorNumElements() != 16)) {
1350 continue;
1351 }
1352 }
1353
1354 // Create constant
1355 const auto ZeroValue = Constant::getNullValue(PredicateType);
1356
1357 // Scalar and vector are to be treated differently
1358 CmpInst::Predicate Pred;
1359 if (PredicateType->isVectorTy()) {
1360 Pred = CmpInst::ICMP_SLT;
1361 } else {
1362 Pred = CmpInst::ICMP_NE;
1363 }
1364
1365 // Create comparison instruction
1366 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1367 ZeroValue, "", CI);
1368
1369 // Create select
1370 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1371
1372 // Replace call with the selection
1373 CI->replaceAllUsesWith(V);
1374
1375 // Lastly, remember to remove the user.
1376 ToRemoves.push_back(CI);
1377 }
1378 }
1379
1380 Changed = !ToRemoves.empty();
1381
1382 // And cleanup the calls we don't use anymore.
1383 for (auto V : ToRemoves) {
1384 V->eraseFromParent();
1385 }
1386
1387 // And remove the function we don't need either too.
1388 F->eraseFromParent();
1389 }
1390 }
1391
1392 return Changed;
1393}
1394
Kévin Petite7d0cce2018-10-31 12:38:56 +00001395bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1396 bool Changed = false;
1397
1398 for (auto const &SymVal : M.getValueSymbolTable()) {
1399 // Skip symbols whose name doesn't match
1400 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1401 continue;
1402 }
1403 // Is there a function going by that name?
1404 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1405
1406 SmallVector<Instruction *, 4> ToRemoves;
1407
1408 // Walk the users of the function.
1409 for (auto &U : F->uses()) {
1410 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1411
1412 if (CI->getNumOperands() != 4) {
1413 continue;
1414 }
1415
1416 // Get arguments
1417 auto FalseValue = CI->getOperand(0);
1418 auto TrueValue = CI->getOperand(1);
1419 auto PredicateValue = CI->getOperand(2);
1420
1421 // Don't touch overloads that aren't in OpenCL C
1422 auto FalseType = FalseValue->getType();
1423 auto TrueType = TrueValue->getType();
1424 auto PredicateType = PredicateValue->getType();
1425
1426 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1427 continue;
1428 }
1429
1430 if (TrueType->isVectorTy()) {
1431 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1432 !TrueType->getScalarType()->isIntegerTy()) {
1433 continue;
1434 }
1435 if ((TrueType->getVectorNumElements() != 2) &&
1436 (TrueType->getVectorNumElements() != 3) &&
1437 (TrueType->getVectorNumElements() != 4) &&
1438 (TrueType->getVectorNumElements() != 8) &&
1439 (TrueType->getVectorNumElements() != 16)) {
1440 continue;
1441 }
1442 }
1443
1444 // Remember the type of the operands
1445 auto OpType = TrueType;
1446
1447 // The actual bit selection will always be done on an integer type,
1448 // declare it here
1449 Type *BitType;
1450
1451 // If the operands are float, then bitcast them to int
1452 if (OpType->getScalarType()->isFloatingPointTy()) {
1453
1454 // First create the new type
1455 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1456 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1457 if (OpType->isVectorTy()) {
1458 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1459 }
1460
1461 // Then bitcast all operands
1462 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1463 BitType, "", CI);
1464 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1465 BitType, "", CI);
1466 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1467
1468 } else {
1469 // The operands have an integer type, use it directly
1470 BitType = OpType;
1471 }
1472
1473 // All the operands are now always integers
1474 // implement as (c & b) | (~c & a)
1475
1476 // Create our negated predicate value
1477 auto AllOnes = Constant::getAllOnesValue(BitType);
1478 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1479 PredicateValue,
1480 AllOnes, "", CI);
1481
1482 // Then put everything together
1483 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1484 NotPredicateValue,
1485 FalseValue, "", CI);
1486 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1487 PredicateValue,
1488 TrueValue, "", CI);
1489
1490 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1491 BitsTrue, "", CI);
1492
1493 // If we were dealing with a floating point type, we must bitcast
1494 // the result back to that
1495 if (OpType->getScalarType()->isFloatingPointTy()) {
1496 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1497 }
1498
1499 // Replace call with our new code
1500 CI->replaceAllUsesWith(V);
1501
1502 // Lastly, remember to remove the user.
1503 ToRemoves.push_back(CI);
1504 }
1505 }
1506
1507 Changed = !ToRemoves.empty();
1508
1509 // And cleanup the calls we don't use anymore.
1510 for (auto V : ToRemoves) {
1511 V->eraseFromParent();
1512 }
1513
1514 // And remove the function we don't need either too.
1515 F->eraseFromParent();
1516 }
1517 }
1518
1519 return Changed;
1520}
1521
Kévin Petit6b0a9532018-10-30 20:00:39 +00001522bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1523 bool Changed = false;
1524
1525 const std::map<const char *, const char *> Map = {
1526 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1527 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1528 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1529 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1530 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1531 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1532 };
1533
1534 for (auto Pair : Map) {
1535 // If we find a function with the matching name.
1536 if (auto F = M.getFunction(Pair.first)) {
1537 SmallVector<Instruction *, 4> ToRemoves;
1538
1539 // Walk the users of the function.
1540 for (auto &U : F->uses()) {
1541 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1542
1543 auto ReplacementFn = Pair.second;
1544
1545 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1546 Value *VectorArg;
1547
1548 // First figure out which function we're dealing with
1549 if (F->getName().startswith("_Z10smoothstep")) {
1550 ArgsToSplat.push_back(CI->getOperand(1));
1551 VectorArg = CI->getOperand(2);
1552 } else {
1553 VectorArg = CI->getOperand(1);
1554 }
1555
1556 // Splat arguments that need to be
1557 SmallVector<Value*, 2> SplatArgs;
1558 auto VecType = VectorArg->getType();
1559
1560 for (auto arg : ArgsToSplat) {
1561 Value* NewVectorArg = UndefValue::get(VecType);
1562 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1563 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1564 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1565 }
1566 SplatArgs.push_back(NewVectorArg);
1567 }
1568
1569 // Replace the call with the vector/vector flavour
1570 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1571 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1572
1573 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1574
1575 SmallVector<Value*, 3> NewArgs;
1576 for (auto arg : SplatArgs) {
1577 NewArgs.push_back(arg);
1578 }
1579 NewArgs.push_back(VectorArg);
1580
1581 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1582
1583 CI->replaceAllUsesWith(NewCI);
1584
1585 // Lastly, remember to remove the user.
1586 ToRemoves.push_back(CI);
1587 }
1588 }
1589
1590 Changed = !ToRemoves.empty();
1591
1592 // And cleanup the calls we don't use anymore.
1593 for (auto V : ToRemoves) {
1594 V->eraseFromParent();
1595 }
1596
1597 // And remove the function we don't need either too.
1598 F->eraseFromParent();
1599 }
1600 }
1601
1602 return Changed;
1603}
1604
David Neto22f144c2017-06-12 14:26:21 -04001605bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1606 bool Changed = false;
1607
1608 const std::map<const char *, Instruction::BinaryOps> Map = {
1609 {"_Z7signbitf", Instruction::LShr},
1610 {"_Z7signbitDv2_f", Instruction::AShr},
1611 {"_Z7signbitDv3_f", Instruction::AShr},
1612 {"_Z7signbitDv4_f", Instruction::AShr},
1613 };
1614
1615 for (auto Pair : Map) {
1616 // If we find a function with the matching name.
1617 if (auto F = M.getFunction(Pair.first)) {
1618 SmallVector<Instruction *, 4> ToRemoves;
1619
1620 // Walk the users of the function.
1621 for (auto &U : F->uses()) {
1622 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1623 auto Arg = CI->getOperand(0);
1624
1625 auto Bitcast =
1626 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1627
1628 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1629 ConstantInt::get(CI->getType(), 31),
1630 "", CI);
1631
1632 CI->replaceAllUsesWith(Shr);
1633
1634 // Lastly, remember to remove the user.
1635 ToRemoves.push_back(CI);
1636 }
1637 }
1638
1639 Changed = !ToRemoves.empty();
1640
1641 // And cleanup the calls we don't use anymore.
1642 for (auto V : ToRemoves) {
1643 V->eraseFromParent();
1644 }
1645
1646 // And remove the function we don't need either too.
1647 F->eraseFromParent();
1648 }
1649 }
1650
1651 return Changed;
1652}
1653
1654bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1655 bool Changed = false;
1656
1657 const std::map<const char *,
1658 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1659 Map = {
1660 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1661 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1662 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1663 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1664 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1665 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1666 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1667 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1668 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1669 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1670 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1671 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1672 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1673 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1674 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1675 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1676 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1677 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1678 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1679 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1680 };
1681
1682 for (auto Pair : Map) {
1683 // If we find a function with the matching name.
1684 if (auto F = M.getFunction(Pair.first)) {
1685 SmallVector<Instruction *, 4> ToRemoves;
1686
1687 // Walk the users of the function.
1688 for (auto &U : F->uses()) {
1689 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1690 // The multiply instruction to use.
1691 auto MulInst = Pair.second.first;
1692
1693 // The add instruction to use.
1694 auto AddInst = Pair.second.second;
1695
1696 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1697
1698 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1699 CI->getArgOperand(1), "", CI);
1700
1701 if (Instruction::BinaryOpsEnd != AddInst) {
1702 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1703 CI);
1704 }
1705
1706 CI->replaceAllUsesWith(I);
1707
1708 // Lastly, remember to remove the user.
1709 ToRemoves.push_back(CI);
1710 }
1711 }
1712
1713 Changed = !ToRemoves.empty();
1714
1715 // And cleanup the calls we don't use anymore.
1716 for (auto V : ToRemoves) {
1717 V->eraseFromParent();
1718 }
1719
1720 // And remove the function we don't need either too.
1721 F->eraseFromParent();
1722 }
1723 }
1724
1725 return Changed;
1726}
1727
Derek Chowcfd368b2017-10-19 20:58:45 -07001728bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1729 bool Changed = false;
1730
1731 struct VectorStoreOps {
1732 const char* name;
1733 int n;
1734 Type* (*get_scalar_type_function)(LLVMContext&);
1735 } vector_store_ops[] = {
1736 // TODO(derekjchow): Expand this list.
1737 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1738 };
1739
David Neto544fffc2017-11-16 18:35:14 -05001740 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001741 auto Name = Op.name;
1742 auto N = Op.n;
1743 auto TypeFn = Op.get_scalar_type_function;
1744 if (auto F = M.getFunction(Name)) {
1745 SmallVector<Instruction *, 4> ToRemoves;
1746
1747 // Walk the users of the function.
1748 for (auto &U : F->uses()) {
1749 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1750 // The value argument from vstoren.
1751 auto Arg0 = CI->getOperand(0);
1752
1753 // The index argument from vstoren.
1754 auto Arg1 = CI->getOperand(1);
1755
1756 // The pointer argument from vstoren.
1757 auto Arg2 = CI->getOperand(2);
1758
1759 // Get types.
1760 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1761 auto ScalarNPointerTy = PointerType::get(
1762 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1763
1764 // Cast to scalarn
1765 auto Cast = CastInst::CreatePointerCast(
1766 Arg2, ScalarNPointerTy, "", CI);
1767 // Index to correct address
1768 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1769 // Store
1770 auto Store = new StoreInst(Arg0, Index, CI);
1771
1772 CI->replaceAllUsesWith(Store);
1773 ToRemoves.push_back(CI);
1774 }
1775 }
1776
1777 Changed = !ToRemoves.empty();
1778
1779 // And cleanup the calls we don't use anymore.
1780 for (auto V : ToRemoves) {
1781 V->eraseFromParent();
1782 }
1783
1784 // And remove the function we don't need either too.
1785 F->eraseFromParent();
1786 }
1787 }
1788
1789 return Changed;
1790}
1791
1792bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1793 bool Changed = false;
1794
1795 struct VectorLoadOps {
1796 const char* name;
1797 int n;
1798 Type* (*get_scalar_type_function)(LLVMContext&);
1799 } vector_load_ops[] = {
1800 // TODO(derekjchow): Expand this list.
1801 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1802 };
1803
David Neto544fffc2017-11-16 18:35:14 -05001804 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001805 auto Name = Op.name;
1806 auto N = Op.n;
1807 auto TypeFn = Op.get_scalar_type_function;
1808 // If we find a function with the matching name.
1809 if (auto F = M.getFunction(Name)) {
1810 SmallVector<Instruction *, 4> ToRemoves;
1811
1812 // Walk the users of the function.
1813 for (auto &U : F->uses()) {
1814 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1815 // The index argument from vloadn.
1816 auto Arg0 = CI->getOperand(0);
1817
1818 // The pointer argument from vloadn.
1819 auto Arg1 = CI->getOperand(1);
1820
1821 // Get types.
1822 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1823 auto ScalarNPointerTy = PointerType::get(
1824 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1825
1826 // Cast to scalarn
1827 auto Cast = CastInst::CreatePointerCast(
1828 Arg1, ScalarNPointerTy, "", CI);
1829 // Index to correct address
1830 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1831 // Load
1832 auto Load = new LoadInst(Index, "", CI);
1833
1834 CI->replaceAllUsesWith(Load);
1835 ToRemoves.push_back(CI);
1836 }
1837 }
1838
1839 Changed = !ToRemoves.empty();
1840
1841 // And cleanup the calls we don't use anymore.
1842 for (auto V : ToRemoves) {
1843 V->eraseFromParent();
1844 }
1845
1846 // And remove the function we don't need either too.
1847 F->eraseFromParent();
1848
1849 }
1850 }
1851
1852 return Changed;
1853}
1854
David Neto22f144c2017-06-12 14:26:21 -04001855bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1856 bool Changed = false;
1857
1858 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
1859 "_Z10vload_halfjPU3AS2KDh"};
1860
1861 for (auto Name : Map) {
1862 // If we find a function with the matching name.
1863 if (auto F = M.getFunction(Name)) {
1864 SmallVector<Instruction *, 4> ToRemoves;
1865
1866 // Walk the users of the function.
1867 for (auto &U : F->uses()) {
1868 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1869 // The index argument from vload_half.
1870 auto Arg0 = CI->getOperand(0);
1871
1872 // The pointer argument from vload_half.
1873 auto Arg1 = CI->getOperand(1);
1874
David Neto22f144c2017-06-12 14:26:21 -04001875 auto IntTy = Type::getInt32Ty(M.getContext());
1876 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001877 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1878
David Neto22f144c2017-06-12 14:26:21 -04001879 // Our intrinsic to unpack a float2 from an int.
1880 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1881
1882 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1883
David Neto482550a2018-03-24 05:21:07 -07001884 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04001885 auto ShortTy = Type::getInt16Ty(M.getContext());
1886 auto ShortPointerTy = PointerType::get(
1887 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001888
David Netoac825b82017-05-30 12:49:01 -04001889 // Cast the half* pointer to short*.
1890 auto Cast =
1891 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001892
David Netoac825b82017-05-30 12:49:01 -04001893 // Index into the correct address of the casted pointer.
1894 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1895
1896 // Load from the short* we casted to.
1897 auto Load = new LoadInst(Index, "", CI);
1898
1899 // ZExt the short -> int.
1900 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1901
1902 // Get our float2.
1903 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1904
1905 // Extract out the bottom element which is our float result.
1906 auto Extract = ExtractElementInst::Create(
1907 Call, ConstantInt::get(IntTy, 0), "", CI);
1908
1909 CI->replaceAllUsesWith(Extract);
1910 } else {
1911 // Assume the pointer argument points to storage aligned to 32bits
1912 // or more.
1913 // TODO(dneto): Do more analysis to make sure this is true?
1914 //
1915 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1916 // with:
1917 //
1918 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1919 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1920 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1921 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1922 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1923 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1924 // x float> %converted, %index_is_odd32
1925
1926 auto IntPointerTy = PointerType::get(
1927 IntTy, Arg1->getType()->getPointerAddressSpace());
1928
David Neto973e6a82017-05-30 13:48:18 -04001929 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04001930 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04001931 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04001932 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1933
1934 auto One = ConstantInt::get(IntTy, 1);
1935 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1936 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1937
1938 // Index into the correct address of the casted pointer.
1939 auto Ptr =
1940 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1941
1942 // Load from the int* we casted to.
1943 auto Load = new LoadInst(Ptr, "", CI);
1944
1945 // Get our float2.
1946 auto Call = CallInst::Create(NewF, Load, "", CI);
1947
1948 // Extract out the float result, where the element number is
1949 // determined by whether the original index was even or odd.
1950 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1951
1952 CI->replaceAllUsesWith(Extract);
1953 }
David Neto22f144c2017-06-12 14:26:21 -04001954
1955 // Lastly, remember to remove the user.
1956 ToRemoves.push_back(CI);
1957 }
1958 }
1959
1960 Changed = !ToRemoves.empty();
1961
1962 // And cleanup the calls we don't use anymore.
1963 for (auto V : ToRemoves) {
1964 V->eraseFromParent();
1965 }
1966
1967 // And remove the function we don't need either too.
1968 F->eraseFromParent();
1969 }
1970 }
1971
1972 return Changed;
1973}
1974
1975bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
1976 bool Changed = false;
1977
David Neto556c7e62018-06-08 13:45:55 -07001978 const std::vector<const char *> Map = {
1979 "_Z11vload_half2jPU3AS1KDh",
1980 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
1981 "_Z11vload_half2jPU3AS2KDh",
1982 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
1983 };
David Neto22f144c2017-06-12 14:26:21 -04001984
1985 for (auto Name : Map) {
1986 // If we find a function with the matching name.
1987 if (auto F = M.getFunction(Name)) {
1988 SmallVector<Instruction *, 4> ToRemoves;
1989
1990 // Walk the users of the function.
1991 for (auto &U : F->uses()) {
1992 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1993 // The index argument from vload_half.
1994 auto Arg0 = CI->getOperand(0);
1995
1996 // The pointer argument from vload_half.
1997 auto Arg1 = CI->getOperand(1);
1998
1999 auto IntTy = Type::getInt32Ty(M.getContext());
2000 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2001 auto NewPointerTy = PointerType::get(
2002 IntTy, Arg1->getType()->getPointerAddressSpace());
2003 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2004
2005 // Cast the half* pointer to int*.
2006 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2007
2008 // Index into the correct address of the casted pointer.
2009 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
2010
2011 // Load from the int* we casted to.
2012 auto Load = new LoadInst(Index, "", CI);
2013
2014 // Our intrinsic to unpack a float2 from an int.
2015 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2016
2017 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2018
2019 // Get our float2.
2020 auto Call = CallInst::Create(NewF, Load, "", CI);
2021
2022 CI->replaceAllUsesWith(Call);
2023
2024 // Lastly, remember to remove the user.
2025 ToRemoves.push_back(CI);
2026 }
2027 }
2028
2029 Changed = !ToRemoves.empty();
2030
2031 // And cleanup the calls we don't use anymore.
2032 for (auto V : ToRemoves) {
2033 V->eraseFromParent();
2034 }
2035
2036 // And remove the function we don't need either too.
2037 F->eraseFromParent();
2038 }
2039 }
2040
2041 return Changed;
2042}
2043
2044bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
2045 bool Changed = false;
2046
David Neto556c7e62018-06-08 13:45:55 -07002047 const std::vector<const char *> Map = {
2048 "_Z11vload_half4jPU3AS1KDh",
2049 "_Z12vloada_half4jPU3AS1KDh",
2050 "_Z11vload_half4jPU3AS2KDh",
2051 "_Z12vloada_half4jPU3AS2KDh",
2052 };
David Neto22f144c2017-06-12 14:26:21 -04002053
2054 for (auto Name : Map) {
2055 // If we find a function with the matching name.
2056 if (auto F = M.getFunction(Name)) {
2057 SmallVector<Instruction *, 4> ToRemoves;
2058
2059 // Walk the users of the function.
2060 for (auto &U : F->uses()) {
2061 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2062 // The index argument from vload_half.
2063 auto Arg0 = CI->getOperand(0);
2064
2065 // The pointer argument from vload_half.
2066 auto Arg1 = CI->getOperand(1);
2067
2068 auto IntTy = Type::getInt32Ty(M.getContext());
2069 auto Int2Ty = VectorType::get(IntTy, 2);
2070 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2071 auto NewPointerTy = PointerType::get(
2072 Int2Ty, Arg1->getType()->getPointerAddressSpace());
2073 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2074
2075 // Cast the half* pointer to int2*.
2076 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2077
2078 // Index into the correct address of the casted pointer.
2079 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
2080
2081 // Load from the int2* we casted to.
2082 auto Load = new LoadInst(Index, "", CI);
2083
2084 // Extract each element from the loaded int2.
2085 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2086 "", CI);
2087 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2088 "", CI);
2089
2090 // Our intrinsic to unpack a float2 from an int.
2091 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2092
2093 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2094
2095 // Get the lower (x & y) components of our final float4.
2096 auto Lo = CallInst::Create(NewF, X, "", CI);
2097
2098 // Get the higher (z & w) components of our final float4.
2099 auto Hi = CallInst::Create(NewF, Y, "", CI);
2100
2101 Constant *ShuffleMask[4] = {
2102 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2103 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2104
2105 // Combine our two float2's into one float4.
2106 auto Combine = new ShuffleVectorInst(
2107 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2108
2109 CI->replaceAllUsesWith(Combine);
2110
2111 // Lastly, remember to remove the user.
2112 ToRemoves.push_back(CI);
2113 }
2114 }
2115
2116 Changed = !ToRemoves.empty();
2117
2118 // And cleanup the calls we don't use anymore.
2119 for (auto V : ToRemoves) {
2120 V->eraseFromParent();
2121 }
2122
2123 // And remove the function we don't need either too.
2124 F->eraseFromParent();
2125 }
2126 }
2127
2128 return Changed;
2129}
2130
David Neto6ad93232018-06-07 15:42:58 -07002131bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
2132 bool Changed = false;
2133
2134 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2135 //
2136 // %u = load i32 %ptr
2137 // %fxy = call <2 x float> Unpack2xHalf(u)
2138 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2139 const std::vector<const char *> Map = {
2140 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2141 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2142 "_Z20__clspv_vloada_half2jPKj", // private
2143 };
2144
2145 for (auto Name : Map) {
2146 // If we find a function with the matching name.
2147 if (auto F = M.getFunction(Name)) {
2148 SmallVector<Instruction *, 4> ToRemoves;
2149
2150 // Walk the users of the function.
2151 for (auto &U : F->uses()) {
2152 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
2153 auto Index = CI->getOperand(0);
2154 auto Ptr = CI->getOperand(1);
2155
2156 auto IntTy = Type::getInt32Ty(M.getContext());
2157 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2158 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2159
2160 auto IndexedPtr =
2161 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
2162 auto Load = new LoadInst(IndexedPtr, "", CI);
2163
2164 // Our intrinsic to unpack a float2 from an int.
2165 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2166
2167 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2168
2169 // Get our final float2.
2170 auto Result = CallInst::Create(NewF, Load, "", CI);
2171
2172 CI->replaceAllUsesWith(Result);
2173
2174 // Lastly, remember to remove the user.
2175 ToRemoves.push_back(CI);
2176 }
2177 }
2178
2179 Changed = true;
2180
2181 // And cleanup the calls we don't use anymore.
2182 for (auto V : ToRemoves) {
2183 V->eraseFromParent();
2184 }
2185
2186 // And remove the function we don't need either too.
2187 F->eraseFromParent();
2188 }
2189 }
2190
2191 return Changed;
2192}
2193
2194bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
2195 bool Changed = false;
2196
2197 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2198 //
2199 // %u2 = load <2 x i32> %ptr
2200 // %u2xy = extractelement %u2, 0
2201 // %u2zw = extractelement %u2, 1
2202 // %fxy = call <2 x float> Unpack2xHalf(uint)
2203 // %fzw = call <2 x float> Unpack2xHalf(uint)
2204 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2205 const std::vector<const char *> Map = {
2206 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2207 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2208 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2209 };
2210
2211 for (auto Name : Map) {
2212 // If we find a function with the matching name.
2213 if (auto F = M.getFunction(Name)) {
2214 SmallVector<Instruction *, 4> ToRemoves;
2215
2216 // Walk the users of the function.
2217 for (auto &U : F->uses()) {
2218 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2219 auto Index = CI->getOperand(0);
2220 auto Ptr = CI->getOperand(1);
2221
2222 auto IntTy = Type::getInt32Ty(M.getContext());
2223 auto Int2Ty = VectorType::get(IntTy, 2);
2224 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2225 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2226
2227 auto IndexedPtr =
2228 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
2229 auto Load = new LoadInst(IndexedPtr, "", CI);
2230
2231 // Extract each element from the loaded int2.
2232 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2233 "", CI);
2234 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2235 "", CI);
2236
2237 // Our intrinsic to unpack a float2 from an int.
2238 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2239
2240 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2241
2242 // Get the lower (x & y) components of our final float4.
2243 auto Lo = CallInst::Create(NewF, X, "", CI);
2244
2245 // Get the higher (z & w) components of our final float4.
2246 auto Hi = CallInst::Create(NewF, Y, "", CI);
2247
2248 Constant *ShuffleMask[4] = {
2249 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2250 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2251
2252 // Combine our two float2's into one float4.
2253 auto Combine = new ShuffleVectorInst(
2254 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2255
2256 CI->replaceAllUsesWith(Combine);
2257
2258 // Lastly, remember to remove the user.
2259 ToRemoves.push_back(CI);
2260 }
2261 }
2262
2263 Changed = true;
2264
2265 // And cleanup the calls we don't use anymore.
2266 for (auto V : ToRemoves) {
2267 V->eraseFromParent();
2268 }
2269
2270 // And remove the function we don't need either too.
2271 F->eraseFromParent();
2272 }
2273 }
2274
2275 return Changed;
2276}
2277
David Neto22f144c2017-06-12 14:26:21 -04002278bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
2279 bool Changed = false;
2280
2281 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
2282 "_Z15vstore_half_rtefjPU3AS1Dh",
2283 "_Z15vstore_half_rtzfjPU3AS1Dh"};
2284
2285 for (auto Name : Map) {
2286 // If we find a function with the matching name.
2287 if (auto F = M.getFunction(Name)) {
2288 SmallVector<Instruction *, 4> ToRemoves;
2289
2290 // Walk the users of the function.
2291 for (auto &U : F->uses()) {
2292 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2293 // The value to store.
2294 auto Arg0 = CI->getOperand(0);
2295
2296 // The index argument from vstore_half.
2297 auto Arg1 = CI->getOperand(1);
2298
2299 // The pointer argument from vstore_half.
2300 auto Arg2 = CI->getOperand(2);
2301
David Neto22f144c2017-06-12 14:26:21 -04002302 auto IntTy = Type::getInt32Ty(M.getContext());
2303 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002304 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04002305 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002306
2307 // Our intrinsic to pack a float2 to an int.
2308 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2309
2310 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2311
2312 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04002313 auto TempVec =
2314 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
2315 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002316
2317 // Pack the float2 -> half2 (in an int).
2318 auto X = CallInst::Create(NewF, TempVec, "", CI);
2319
David Neto482550a2018-03-24 05:21:07 -07002320 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002321 auto ShortTy = Type::getInt16Ty(M.getContext());
2322 auto ShortPointerTy = PointerType::get(
2323 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002324
David Neto17852de2017-05-29 17:29:31 -04002325 // Truncate our i32 to an i16.
2326 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002327
David Neto17852de2017-05-29 17:29:31 -04002328 // Cast the half* pointer to short*.
2329 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002330
David Neto17852de2017-05-29 17:29:31 -04002331 // Index into the correct address of the casted pointer.
2332 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002333
David Neto17852de2017-05-29 17:29:31 -04002334 // Store to the int* we casted to.
2335 auto Store = new StoreInst(Trunc, Index, CI);
2336
2337 CI->replaceAllUsesWith(Store);
2338 } else {
2339 // We can only write to 32-bit aligned words.
2340 //
2341 // Assuming base is aligned to 32-bits, replace the equivalent of
2342 // vstore_half(value, index, base)
2343 // with:
2344 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2345 // uint32_t write_to_upper_half = index & 1u;
2346 // uint32_t shift = write_to_upper_half << 4;
2347 //
2348 // // Pack the float value as a half number in bottom 16 bits
2349 // // of an i32.
2350 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2351 //
2352 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2353 // ^ ((packed & 0xffff) << shift)
2354 // // We only need relaxed consistency, but OpenCL 1.2 only has
2355 // // sequentially consistent atomics.
2356 // // TODO(dneto): Use relaxed consistency.
2357 // atomic_xor(target_ptr, xor_value)
2358 auto IntPointerTy = PointerType::get(
2359 IntTy, Arg2->getType()->getPointerAddressSpace());
2360
2361 auto Four = ConstantInt::get(IntTy, 4);
2362 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2363
2364 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2365 // Compute index / 2
2366 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2367 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2368 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2369 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2370 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2371 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2372 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2373
2374 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2375 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2376 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2377
2378 // Generate the call to atomi_xor.
2379 SmallVector<Type *, 5> ParamTypes;
2380 // The pointer type.
2381 ParamTypes.push_back(IntPointerTy);
2382 // The Types for memory scope, semantics, and value.
2383 ParamTypes.push_back(IntTy);
2384 ParamTypes.push_back(IntTy);
2385 ParamTypes.push_back(IntTy);
2386 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2387 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2388
2389 const auto ConstantScopeDevice =
2390 ConstantInt::get(IntTy, spv::ScopeDevice);
2391 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2392 // (SPIR-V Workgroup).
2393 const auto AddrSpaceSemanticsBits =
2394 IntPointerTy->getPointerAddressSpace() == 1
2395 ? spv::MemorySemanticsUniformMemoryMask
2396 : spv::MemorySemanticsWorkgroupMemoryMask;
2397
2398 // We're using relaxed consistency here.
2399 const auto ConstantMemorySemantics =
2400 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2401 AddrSpaceSemanticsBits);
2402
2403 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2404 ConstantMemorySemantics, ValueToXor};
2405 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2406 }
David Neto22f144c2017-06-12 14:26:21 -04002407
2408 // Lastly, remember to remove the user.
2409 ToRemoves.push_back(CI);
2410 }
2411 }
2412
2413 Changed = !ToRemoves.empty();
2414
2415 // And cleanup the calls we don't use anymore.
2416 for (auto V : ToRemoves) {
2417 V->eraseFromParent();
2418 }
2419
2420 // And remove the function we don't need either too.
2421 F->eraseFromParent();
2422 }
2423 }
2424
2425 return Changed;
2426}
2427
2428bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2429 bool Changed = false;
2430
David Netoe2871522018-06-08 11:09:54 -07002431 const std::vector<const char *> Map = {
2432 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2433 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2434 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2435 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2436 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2437 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2438 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2439 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2440 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2441 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2442 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2443 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2444 };
David Neto22f144c2017-06-12 14:26:21 -04002445
2446 for (auto Name : Map) {
2447 // If we find a function with the matching name.
2448 if (auto F = M.getFunction(Name)) {
2449 SmallVector<Instruction *, 4> ToRemoves;
2450
2451 // Walk the users of the function.
2452 for (auto &U : F->uses()) {
2453 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2454 // The value to store.
2455 auto Arg0 = CI->getOperand(0);
2456
2457 // The index argument from vstore_half.
2458 auto Arg1 = CI->getOperand(1);
2459
2460 // The pointer argument from vstore_half.
2461 auto Arg2 = CI->getOperand(2);
2462
2463 auto IntTy = Type::getInt32Ty(M.getContext());
2464 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2465 auto NewPointerTy = PointerType::get(
2466 IntTy, Arg2->getType()->getPointerAddressSpace());
2467 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2468
2469 // Our intrinsic to pack a float2 to an int.
2470 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2471
2472 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2473
2474 // Turn the packed x & y into the final packing.
2475 auto X = CallInst::Create(NewF, Arg0, "", CI);
2476
2477 // Cast the half* pointer to int*.
2478 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2479
2480 // Index into the correct address of the casted pointer.
2481 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2482
2483 // Store to the int* we casted to.
2484 auto Store = new StoreInst(X, Index, CI);
2485
2486 CI->replaceAllUsesWith(Store);
2487
2488 // Lastly, remember to remove the user.
2489 ToRemoves.push_back(CI);
2490 }
2491 }
2492
2493 Changed = !ToRemoves.empty();
2494
2495 // And cleanup the calls we don't use anymore.
2496 for (auto V : ToRemoves) {
2497 V->eraseFromParent();
2498 }
2499
2500 // And remove the function we don't need either too.
2501 F->eraseFromParent();
2502 }
2503 }
2504
2505 return Changed;
2506}
2507
2508bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2509 bool Changed = false;
2510
David Netoe2871522018-06-08 11:09:54 -07002511 const std::vector<const char *> Map = {
2512 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2513 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2514 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2515 "_Z13vstorea_half4Dv4_fjPDh", // private
2516 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2517 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2518 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2519 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2520 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2521 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2522 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2523 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2524 };
David Neto22f144c2017-06-12 14:26:21 -04002525
2526 for (auto Name : Map) {
2527 // If we find a function with the matching name.
2528 if (auto F = M.getFunction(Name)) {
2529 SmallVector<Instruction *, 4> ToRemoves;
2530
2531 // Walk the users of the function.
2532 for (auto &U : F->uses()) {
2533 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2534 // The value to store.
2535 auto Arg0 = CI->getOperand(0);
2536
2537 // The index argument from vstore_half.
2538 auto Arg1 = CI->getOperand(1);
2539
2540 // The pointer argument from vstore_half.
2541 auto Arg2 = CI->getOperand(2);
2542
2543 auto IntTy = Type::getInt32Ty(M.getContext());
2544 auto Int2Ty = VectorType::get(IntTy, 2);
2545 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2546 auto NewPointerTy = PointerType::get(
2547 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2548 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2549
2550 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2551 ConstantInt::get(IntTy, 1)};
2552
2553 // Extract out the x & y components of our to store value.
2554 auto Lo =
2555 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2556 ConstantVector::get(LoShuffleMask), "", CI);
2557
2558 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2559 ConstantInt::get(IntTy, 3)};
2560
2561 // Extract out the z & w components of our to store value.
2562 auto Hi =
2563 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2564 ConstantVector::get(HiShuffleMask), "", CI);
2565
2566 // Our intrinsic to pack a float2 to an int.
2567 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2568
2569 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2570
2571 // Turn the packed x & y into the final component of our int2.
2572 auto X = CallInst::Create(NewF, Lo, "", CI);
2573
2574 // Turn the packed z & w into the final component of our int2.
2575 auto Y = CallInst::Create(NewF, Hi, "", CI);
2576
2577 auto Combine = InsertElementInst::Create(
2578 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2579 Combine = InsertElementInst::Create(
2580 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2581
2582 // Cast the half* pointer to int2*.
2583 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2584
2585 // Index into the correct address of the casted pointer.
2586 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2587
2588 // Store to the int2* we casted to.
2589 auto Store = new StoreInst(Combine, Index, CI);
2590
2591 CI->replaceAllUsesWith(Store);
2592
2593 // Lastly, remember to remove the user.
2594 ToRemoves.push_back(CI);
2595 }
2596 }
2597
2598 Changed = !ToRemoves.empty();
2599
2600 // And cleanup the calls we don't use anymore.
2601 for (auto V : ToRemoves) {
2602 V->eraseFromParent();
2603 }
2604
2605 // And remove the function we don't need either too.
2606 F->eraseFromParent();
2607 }
2608 }
2609
2610 return Changed;
2611}
2612
2613bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2614 bool Changed = false;
2615
2616 const std::map<const char *, const char*> Map = {
2617 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2618 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2619 };
2620
2621 for (auto Pair : Map) {
2622 // If we find a function with the matching name.
2623 if (auto F = M.getFunction(Pair.first)) {
2624 SmallVector<Instruction *, 4> ToRemoves;
2625
2626 // Walk the users of the function.
2627 for (auto &U : F->uses()) {
2628 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2629 // The image.
2630 auto Arg0 = CI->getOperand(0);
2631
2632 // The sampler.
2633 auto Arg1 = CI->getOperand(1);
2634
2635 // The coordinate (integer type that we can't handle).
2636 auto Arg2 = CI->getOperand(2);
2637
2638 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2639
2640 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2641
2642 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2643
2644 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2645
2646 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2647
2648 CI->replaceAllUsesWith(NewCI);
2649
2650 // Lastly, remember to remove the user.
2651 ToRemoves.push_back(CI);
2652 }
2653 }
2654
2655 Changed = !ToRemoves.empty();
2656
2657 // And cleanup the calls we don't use anymore.
2658 for (auto V : ToRemoves) {
2659 V->eraseFromParent();
2660 }
2661
2662 // And remove the function we don't need either too.
2663 F->eraseFromParent();
2664 }
2665 }
2666
2667 return Changed;
2668}
2669
2670bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2671 bool Changed = false;
2672
2673 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002674 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2675 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2676 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2677 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2678 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2679 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002680 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2681 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2682 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2683 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2684 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002685 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002686
2687 for (auto Pair : Map) {
2688 // If we find a function with the matching name.
2689 if (auto F = M.getFunction(Pair.first)) {
2690 SmallVector<Instruction *, 4> ToRemoves;
2691
2692 // Walk the users of the function.
2693 for (auto &U : F->uses()) {
2694 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2695 auto FType = F->getFunctionType();
2696 SmallVector<Type *, 5> ParamTypes;
2697
2698 // The pointer type.
2699 ParamTypes.push_back(FType->getParamType(0));
2700
2701 auto IntTy = Type::getInt32Ty(M.getContext());
2702
2703 // The memory scope type.
2704 ParamTypes.push_back(IntTy);
2705
2706 // The memory semantics type.
2707 ParamTypes.push_back(IntTy);
2708
2709 if (2 < CI->getNumArgOperands()) {
2710 // The unequal memory semantics type.
2711 ParamTypes.push_back(IntTy);
2712
2713 // The value type.
2714 ParamTypes.push_back(FType->getParamType(2));
2715
2716 // The comparator type.
2717 ParamTypes.push_back(FType->getParamType(1));
2718 } else if (1 < CI->getNumArgOperands()) {
2719 // The value type.
2720 ParamTypes.push_back(FType->getParamType(1));
2721 }
2722
2723 auto NewFType =
2724 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2725 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2726
2727 // We need to map the OpenCL constants to the SPIR-V equivalents.
2728 const auto ConstantScopeDevice =
2729 ConstantInt::get(IntTy, spv::ScopeDevice);
2730 const auto ConstantMemorySemantics = ConstantInt::get(
2731 IntTy, spv::MemorySemanticsUniformMemoryMask |
2732 spv::MemorySemanticsSequentiallyConsistentMask);
2733
2734 SmallVector<Value *, 5> Params;
2735
2736 // The pointer.
2737 Params.push_back(CI->getArgOperand(0));
2738
2739 // The memory scope.
2740 Params.push_back(ConstantScopeDevice);
2741
2742 // The memory semantics.
2743 Params.push_back(ConstantMemorySemantics);
2744
2745 if (2 < CI->getNumArgOperands()) {
2746 // The unequal memory semantics.
2747 Params.push_back(ConstantMemorySemantics);
2748
2749 // The value.
2750 Params.push_back(CI->getArgOperand(2));
2751
2752 // The comparator.
2753 Params.push_back(CI->getArgOperand(1));
2754 } else if (1 < CI->getNumArgOperands()) {
2755 // The value.
2756 Params.push_back(CI->getArgOperand(1));
2757 }
2758
2759 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2760
2761 CI->replaceAllUsesWith(NewCI);
2762
2763 // Lastly, remember to remove the user.
2764 ToRemoves.push_back(CI);
2765 }
2766 }
2767
2768 Changed = !ToRemoves.empty();
2769
2770 // And cleanup the calls we don't use anymore.
2771 for (auto V : ToRemoves) {
2772 V->eraseFromParent();
2773 }
2774
2775 // And remove the function we don't need either too.
2776 F->eraseFromParent();
2777 }
2778 }
2779
Neil Henning39672102017-09-29 14:33:13 +01002780 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002781 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2782 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2783 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2784 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2785 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2786 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2787 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2788 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2789 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2790 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2791 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2792 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2793 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2794 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2795 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2796 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002797 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2798 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2799 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2800 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2801 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2802 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2803 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2804 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2805 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2806 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2807 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2808 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2809 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2810 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2811 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2812 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2813
2814 for (auto Pair : Map2) {
2815 // If we find a function with the matching name.
2816 if (auto F = M.getFunction(Pair.first)) {
2817 SmallVector<Instruction *, 4> ToRemoves;
2818
2819 // Walk the users of the function.
2820 for (auto &U : F->uses()) {
2821 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2822 auto AtomicOp = new AtomicRMWInst(
2823 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2824 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2825
2826 CI->replaceAllUsesWith(AtomicOp);
2827
2828 // Lastly, remember to remove the user.
2829 ToRemoves.push_back(CI);
2830 }
2831 }
2832
2833 Changed = !ToRemoves.empty();
2834
2835 // And cleanup the calls we don't use anymore.
2836 for (auto V : ToRemoves) {
2837 V->eraseFromParent();
2838 }
2839
2840 // And remove the function we don't need either too.
2841 F->eraseFromParent();
2842 }
2843 }
2844
David Neto22f144c2017-06-12 14:26:21 -04002845 return Changed;
2846}
2847
2848bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2849 bool Changed = false;
2850
2851 // If we find a function with the matching name.
2852 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2853 SmallVector<Instruction *, 4> ToRemoves;
2854
2855 auto IntTy = Type::getInt32Ty(M.getContext());
2856 auto FloatTy = Type::getFloatTy(M.getContext());
2857
2858 Constant *DownShuffleMask[3] = {
2859 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2860 ConstantInt::get(IntTy, 2)};
2861
2862 Constant *UpShuffleMask[4] = {
2863 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2864 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2865
2866 Constant *FloatVec[3] = {
2867 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
2868 };
2869
2870 // Walk the users of the function.
2871 for (auto &U : F->uses()) {
2872 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2873 auto Vec4Ty = CI->getArgOperand(0)->getType();
2874 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2875 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2876 auto Vec3Ty = Arg0->getType();
2877
2878 auto NewFType =
2879 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
2880
2881 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
2882
2883 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
2884
2885 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
2886
2887 CI->replaceAllUsesWith(Result);
2888
2889 // Lastly, remember to remove the user.
2890 ToRemoves.push_back(CI);
2891 }
2892 }
2893
2894 Changed = !ToRemoves.empty();
2895
2896 // And cleanup the calls we don't use anymore.
2897 for (auto V : ToRemoves) {
2898 V->eraseFromParent();
2899 }
2900
2901 // And remove the function we don't need either too.
2902 F->eraseFromParent();
2903 }
2904
2905 return Changed;
2906}
David Neto62653202017-10-16 19:05:18 -04002907
2908bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2909 bool Changed = false;
2910
2911 // OpenCL's float result = fract(float x, float* ptr)
2912 //
2913 // In the LLVM domain:
2914 //
2915 // %floor_result = call spir_func float @floor(float %x)
2916 // store float %floor_result, float * %ptr
2917 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2918 // %result = call spir_func float
2919 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2920 //
2921 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2922 // and clspv.fract occur in the SPIR-V generator pass:
2923 //
2924 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2925 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2926 // ...
2927 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2928 // OpStore %ptr %floor_result
2929 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2930 // %fract_result = OpExtInst %float
2931 // %glsl_ext Fmin %fract_intermediate %just_under_1
2932
2933
2934 using std::string;
2935
2936 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2937 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
2938 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
2939 auto make_quad = [](const char *a, const char *b, const char *c,
2940 const char *d) {
2941 return std::tuple<const char *, const char *, const char *, const char *>(
2942 a, b, c, d);
2943 };
2944 const std::vector<QuadType> Functions = {
2945 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
2946 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
2947 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
2948 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
2949 };
2950
2951 for (auto& quad : Functions) {
2952 const StringRef fract_name(std::get<0>(quad));
2953
2954 // If we find a function with the matching name.
2955 if (auto F = M.getFunction(fract_name)) {
2956 if (F->use_begin() == F->use_end())
2957 continue;
2958
2959 // We have some uses.
2960 Changed = true;
2961
2962 auto& Context = M.getContext();
2963
2964 const StringRef floor_name(std::get<1>(quad));
2965 const StringRef fmin_name(std::get<2>(quad));
2966 const StringRef clspv_fract_name(std::get<3>(quad));
2967
2968 // This is either float or a float vector. All the float-like
2969 // types are this type.
2970 auto result_ty = F->getReturnType();
2971
2972 Function* fmin_fn = M.getFunction(fmin_name);
2973 if (!fmin_fn) {
2974 // Make the fmin function.
2975 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
2976 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002977 fmin_fn->addFnAttr(Attribute::ReadNone);
2978 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2979 }
2980
2981 Function* floor_fn = M.getFunction(floor_name);
2982 if (!floor_fn) {
2983 // Make the floor function.
2984 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2985 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002986 floor_fn->addFnAttr(Attribute::ReadNone);
2987 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2988 }
2989
2990 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
2991 if (!clspv_fract_fn) {
2992 // Make the clspv_fract function.
2993 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2994 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002995 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2996 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2997 }
2998
2999 // Number of significant significand bits, whether represented or not.
3000 unsigned num_significand_bits;
3001 switch (result_ty->getScalarType()->getTypeID()) {
3002 case Type::HalfTyID:
3003 num_significand_bits = 11;
3004 break;
3005 case Type::FloatTyID:
3006 num_significand_bits = 24;
3007 break;
3008 case Type::DoubleTyID:
3009 num_significand_bits = 53;
3010 break;
3011 default:
3012 assert(false && "Unhandled float type when processing fract builtin");
3013 break;
3014 }
3015 // Beware that the disassembler displays this value as
3016 // OpConstant %float 1
3017 // which is not quite right.
3018 const double kJustUnderOneScalar =
3019 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3020
3021 Constant *just_under_one =
3022 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3023 if (result_ty->isVectorTy()) {
3024 just_under_one = ConstantVector::getSplat(
3025 result_ty->getVectorNumElements(), just_under_one);
3026 }
3027
3028 IRBuilder<> Builder(Context);
3029
3030 SmallVector<Instruction *, 4> ToRemoves;
3031
3032 // Walk the users of the function.
3033 for (auto &U : F->uses()) {
3034 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3035
3036 Builder.SetInsertPoint(CI);
3037 auto arg = CI->getArgOperand(0);
3038 auto ptr = CI->getArgOperand(1);
3039
3040 // Compute floor result and store it.
3041 auto floor = Builder.CreateCall(floor_fn, {arg});
3042 Builder.CreateStore(floor, ptr);
3043
3044 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
3045 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
3046
3047 CI->replaceAllUsesWith(fract_result);
3048
3049 // Lastly, remember to remove the user.
3050 ToRemoves.push_back(CI);
3051 }
3052 }
3053
3054 // And cleanup the calls we don't use anymore.
3055 for (auto V : ToRemoves) {
3056 V->eraseFromParent();
3057 }
3058
3059 // And remove the function we don't need either too.
3060 F->eraseFromParent();
3061 }
3062 }
3063
3064 return Changed;
3065}