blob: a81bb32957c28348b79ae65b7d131c5c0f7201cf [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000024#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040025#include "llvm/Pass.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040029
David Neto118188e2018-08-24 11:27:54 -040030#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040031
David Neto482550a2018-03-24 05:21:07 -070032#include "clspv/Option.h"
33
David Neto22f144c2017-06-12 14:26:21 -040034using namespace llvm;
35
36#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
37
38namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000039
// Describes one parameter type decoded from a mangled builtin name.
struct ArgTypeInfo {
  // Signedness class recorded for a parameter's base type.
  enum class SignedNess {
    None,     // No integer signedness (recorded for 'f').
    Unsigned, // Recorded for 'h', 't', 'j' and 'm'.
    Signed    // Recorded for 'c', 'a', 's', 'i' and 'l'.
  };

  // Defaults to None so a default-constructed ArgTypeInfo never holds an
  // indeterminate value; the mangled-name parser copies such an object when
  // it sees a substitution before any other parameter.
  SignedNess signedness = SignedNess::None;
};
48
49struct FunctionInfo {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000050 StringRef name;
Kévin Petit8a560882019-03-21 15:24:34 +000051 std::vector<ArgTypeInfo> argTypeInfos;
52};
53
54bool getFunctionInfoFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
55 if (!name.consume_front("_Z")) {
56 return false;
57 }
58 size_t nameLen;
59 if (name.consumeInteger(10, nameLen)) {
60 return false;
61 }
62
Kévin Petit9d1a9d12019-03-25 15:23:46 +000063 finfo->name = name.take_front(nameLen);
Kévin Petit8a560882019-03-21 15:24:34 +000064 name = name.drop_front(nameLen);
65
66 ArgTypeInfo prev_ti;
67
68 while (name.size() != 0) {
69
70 ArgTypeInfo ti;
71
72 // Try parsing a vector prefix
73 if (name.consume_front("Dv")) {
74 int numElems;
75 if (name.consumeInteger(10, numElems)) {
76 return false;
77 }
78
79 if (!name.consume_front("_")) {
80 return false;
81 }
82 }
83
84 // Parse the base type
85 char typeCode = name.front();
86 name = name.drop_front(1);
87 switch(typeCode) {
88 case 'c': // char
89 case 'a': // signed char
90 case 's': // short
91 case 'i': // int
92 case 'l': // long
93 ti.signedness = ArgTypeInfo::SignedNess::Signed;
94 break;
95 case 'h': // unsigned char
96 case 't': // unsigned short
97 case 'j': // unsigned int
98 case 'm': // unsigned long
99 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
100 break;
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000101 case 'f':
102 ti.signedness = ArgTypeInfo::SignedNess::None;
103 break;
Kévin Petit8a560882019-03-21 15:24:34 +0000104 case 'S':
105 ti = prev_ti;
106 if (!name.consume_front("_")) {
107 return false;
108 }
109 break;
110 default:
111 return false;
112 }
113
114 finfo->argTypeInfos.push_back(ti);
115
116 prev_ti = ti;
117 }
118
119 return true;
120};
121
122void getFunctionInfoFromMangledName(StringRef name, FunctionInfo *finfo) {
123 if (!getFunctionInfoFromMangledNameCheck(name, finfo)) {
124 llvm_unreachable("Can't parse mangled function name!");
125 }
126}
127
// Returns the zero-based index of the highest set bit of |v|, i.e.
// floor(log2(v)); returns 0 when |v| is 0.
//
// NOTE: despite its name this is not a count of leading zeros. Callers in
// this file subtract two results to obtain the left-shift distance between
// two single-bit masks, which relies on the bit-index behaviour.
uint32_t clz(uint32_t v) {
  uint32_t index = 0;

  // Binary search over the bit position: test the upper half of the remaining
  // window (16, 8, 4, 2, then 1 bits wide) and descend into it when non-zero.
  for (uint32_t width = 16; width != 0; width >>= 1) {
    if ((v >> width) != 0) {
      v >>= width;
      index |= width;
    }
  }

  return index;
}
147
148Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
149 if (1 == elements) {
150 return Type::getInt1Ty(C);
151 } else {
152 return VectorType::get(Type::getInt1Ty(C), elements);
153 }
154}
155
156struct ReplaceOpenCLBuiltinPass final : public ModulePass {
157 static char ID;
158 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
159
160 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000161 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400162 bool replaceRecip(Module &M);
163 bool replaceDivide(Module &M);
164 bool replaceExp10(Module &M);
165 bool replaceLog10(Module &M);
166 bool replaceBarrier(Module &M);
167 bool replaceMemFence(Module &M);
168 bool replaceRelational(Module &M);
169 bool replaceIsInfAndIsNan(Module &M);
170 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000171 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000172 bool replaceRotate(Module &M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000173 bool replaceConvert(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000174 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000175 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000176 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000177 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400178 bool replaceSignbit(Module &M);
179 bool replaceMadandMad24andMul24(Module &M);
180 bool replaceVloadHalf(Module &M);
181 bool replaceVloadHalf2(Module &M);
182 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700183 bool replaceClspvVloadaHalf2(Module &M);
184 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400185 bool replaceVstoreHalf(Module &M);
186 bool replaceVstoreHalf2(Module &M);
187 bool replaceVstoreHalf4(Module &M);
188 bool replaceReadImageF(Module &M);
189 bool replaceAtomics(Module &M);
190 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400191 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700192 bool replaceVload(Module &M);
193 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400194};
195}
196
197char ReplaceOpenCLBuiltinPass::ID = 0;
198static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
199 "Replace OpenCL Builtins Pass");
200
201namespace clspv {
202ModulePass *createReplaceOpenCLBuiltinPass() {
203 return new ReplaceOpenCLBuiltinPass();
204}
205}
206
207bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
208 bool Changed = false;
209
Kévin Petit2444e9b2018-11-09 14:14:37 +0000210 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400211 Changed |= replaceRecip(M);
212 Changed |= replaceDivide(M);
213 Changed |= replaceExp10(M);
214 Changed |= replaceLog10(M);
215 Changed |= replaceBarrier(M);
216 Changed |= replaceMemFence(M);
217 Changed |= replaceRelational(M);
218 Changed |= replaceIsInfAndIsNan(M);
219 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000220 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000221 Changed |= replaceRotate(M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000222 Changed |= replaceConvert(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000223 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000224 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000225 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000226 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400227 Changed |= replaceSignbit(M);
228 Changed |= replaceMadandMad24andMul24(M);
229 Changed |= replaceVloadHalf(M);
230 Changed |= replaceVloadHalf2(M);
231 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700232 Changed |= replaceClspvVloadaHalf2(M);
233 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400234 Changed |= replaceVstoreHalf(M);
235 Changed |= replaceVstoreHalf2(M);
236 Changed |= replaceVstoreHalf4(M);
237 Changed |= replaceReadImageF(M);
238 Changed |= replaceAtomics(M);
239 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400240 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700241 Changed |= replaceVload(M);
242 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400243
244 return Changed;
245}
246
Kévin Petit2444e9b2018-11-09 14:14:37 +0000247bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
248 bool Changed = false;
249
250 const char *Names[] = {
251 "_Z3abst",
252 "_Z3absDv2_t",
253 "_Z3absDv3_t",
254 "_Z3absDv4_t",
255 "_Z3absj",
256 "_Z3absDv2_j",
257 "_Z3absDv3_j",
258 "_Z3absDv4_j",
259 "_Z3absm",
260 "_Z3absDv2_m",
261 "_Z3absDv3_m",
262 "_Z3absDv4_m",
263 };
264
265 for (auto Name : Names) {
266 // If we find a function with the matching name.
267 if (auto F = M.getFunction(Name)) {
268 SmallVector<Instruction *, 4> ToRemoves;
269
270 // Walk the users of the function.
271 for (auto &U : F->uses()) {
272 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
273 // Abs has one arg.
274 auto Arg = CI->getOperand(0);
275
276 // Use the argument unchanged, we know it's unsigned
277 CI->replaceAllUsesWith(Arg);
278
279 // Lastly, remember to remove the user.
280 ToRemoves.push_back(CI);
281 }
282 }
283
284 Changed = !ToRemoves.empty();
285
286 // And cleanup the calls we don't use anymore.
287 for (auto V : ToRemoves) {
288 V->eraseFromParent();
289 }
290
291 // And remove the function we don't need either too.
292 F->eraseFromParent();
293 }
294 }
295
296 return Changed;
297}
298
David Neto22f144c2017-06-12 14:26:21 -0400299bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
300 bool Changed = false;
301
302 const char *Names[] = {
303 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
304 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
305 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
306 };
307
308 for (auto Name : Names) {
309 // If we find a function with the matching name.
310 if (auto F = M.getFunction(Name)) {
311 SmallVector<Instruction *, 4> ToRemoves;
312
313 // Walk the users of the function.
314 for (auto &U : F->uses()) {
315 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
316 // Recip has one arg.
317 auto Arg = CI->getOperand(0);
318
319 auto Div = BinaryOperator::Create(
320 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
321 CI);
322
323 CI->replaceAllUsesWith(Div);
324
325 // Lastly, remember to remove the user.
326 ToRemoves.push_back(CI);
327 }
328 }
329
330 Changed = !ToRemoves.empty();
331
332 // And cleanup the calls we don't use anymore.
333 for (auto V : ToRemoves) {
334 V->eraseFromParent();
335 }
336
337 // And remove the function we don't need either too.
338 F->eraseFromParent();
339 }
340 }
341
342 return Changed;
343}
344
345bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
346 bool Changed = false;
347
348 const char *Names[] = {
349 "_Z11half_divideff", "_Z13native_divideff",
350 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
351 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
352 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
353 };
354
355 for (auto Name : Names) {
356 // If we find a function with the matching name.
357 if (auto F = M.getFunction(Name)) {
358 SmallVector<Instruction *, 4> ToRemoves;
359
360 // Walk the users of the function.
361 for (auto &U : F->uses()) {
362 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
363 auto Div = BinaryOperator::Create(
364 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
365
366 CI->replaceAllUsesWith(Div);
367
368 // Lastly, remember to remove the user.
369 ToRemoves.push_back(CI);
370 }
371 }
372
373 Changed = !ToRemoves.empty();
374
375 // And cleanup the calls we don't use anymore.
376 for (auto V : ToRemoves) {
377 V->eraseFromParent();
378 }
379
380 // And remove the function we don't need either too.
381 F->eraseFromParent();
382 }
383 }
384
385 return Changed;
386}
387
388bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
389 bool Changed = false;
390
391 const std::map<const char *, const char *> Map = {
392 {"_Z5exp10f", "_Z3expf"},
393 {"_Z10half_exp10f", "_Z8half_expf"},
394 {"_Z12native_exp10f", "_Z10native_expf"},
395 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
396 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
397 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
398 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
399 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
400 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
401 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
402 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
403 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
404
405 for (auto Pair : Map) {
406 // If we find a function with the matching name.
407 if (auto F = M.getFunction(Pair.first)) {
408 SmallVector<Instruction *, 4> ToRemoves;
409
410 // Walk the users of the function.
411 for (auto &U : F->uses()) {
412 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
413 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
414
415 auto Arg = CI->getOperand(0);
416
417 // Constant of the natural log of 10 (ln(10)).
418 const double Ln10 =
419 2.302585092994045684017991454684364207601101488628772976033;
420
421 auto Mul = BinaryOperator::Create(
422 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
423 CI);
424
425 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
426
427 CI->replaceAllUsesWith(NewCI);
428
429 // Lastly, remember to remove the user.
430 ToRemoves.push_back(CI);
431 }
432 }
433
434 Changed = !ToRemoves.empty();
435
436 // And cleanup the calls we don't use anymore.
437 for (auto V : ToRemoves) {
438 V->eraseFromParent();
439 }
440
441 // And remove the function we don't need either too.
442 F->eraseFromParent();
443 }
444 }
445
446 return Changed;
447}
448
449bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
450 bool Changed = false;
451
452 const std::map<const char *, const char *> Map = {
453 {"_Z5log10f", "_Z3logf"},
454 {"_Z10half_log10f", "_Z8half_logf"},
455 {"_Z12native_log10f", "_Z10native_logf"},
456 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
457 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
458 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
459 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
460 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
461 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
462 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
463 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
464 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
465
466 for (auto Pair : Map) {
467 // If we find a function with the matching name.
468 if (auto F = M.getFunction(Pair.first)) {
469 SmallVector<Instruction *, 4> ToRemoves;
470
471 // Walk the users of the function.
472 for (auto &U : F->uses()) {
473 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
474 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
475
476 auto Arg = CI->getOperand(0);
477
478 // Constant of the reciprocal of the natural log of 10 (ln(10)).
479 const double Ln10 =
480 0.434294481903251827651128918916605082294397005803666566114;
481
482 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
483
484 auto Mul = BinaryOperator::Create(
485 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
486 "", CI);
487
488 CI->replaceAllUsesWith(Mul);
489
490 // Lastly, remember to remove the user.
491 ToRemoves.push_back(CI);
492 }
493 }
494
495 Changed = !ToRemoves.empty();
496
497 // And cleanup the calls we don't use anymore.
498 for (auto V : ToRemoves) {
499 V->eraseFromParent();
500 }
501
502 // And remove the function we don't need either too.
503 F->eraseFromParent();
504 }
505 }
506
507 return Changed;
508}
509
510bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
511 bool Changed = false;
512
513 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
514
515 const std::map<const char *, const char *> Map = {
516 {"_Z7barrierj", "__spirv_control_barrier"}};
517
518 for (auto Pair : Map) {
519 // If we find a function with the matching name.
520 if (auto F = M.getFunction(Pair.first)) {
521 SmallVector<Instruction *, 4> ToRemoves;
522
523 // Walk the users of the function.
524 for (auto &U : F->uses()) {
525 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
526 auto FType = F->getFunctionType();
527 SmallVector<Type *, 3> Params;
528 for (unsigned i = 0; i < 3; i++) {
529 Params.push_back(FType->getParamType(0));
530 }
531 auto NewFType =
532 FunctionType::get(FType->getReturnType(), Params, false);
533 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
534
535 auto Arg = CI->getOperand(0);
536
537 // We need to map the OpenCL constants to the SPIR-V equivalents.
538 const auto LocalMemFence =
539 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
540 const auto GlobalMemFence =
541 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
542 const auto ConstantSequentiallyConsistent = ConstantInt::get(
543 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
544 const auto ConstantScopeDevice =
545 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
546 const auto ConstantScopeWorkgroup =
547 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
548
549 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
550 const auto LocalMemFenceMask = BinaryOperator::Create(
551 Instruction::And, LocalMemFence, Arg, "", CI);
552 const auto WorkgroupShiftAmount =
553 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
554 clz(CLK_LOCAL_MEM_FENCE);
555 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
556 Instruction::Shl, LocalMemFenceMask,
557 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
558
559 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
560 const auto GlobalMemFenceMask = BinaryOperator::Create(
561 Instruction::And, GlobalMemFence, Arg, "", CI);
562 const auto UniformShiftAmount =
563 clz(spv::MemorySemanticsUniformMemoryMask) -
564 clz(CLK_GLOBAL_MEM_FENCE);
565 const auto MemorySemanticsUniform = BinaryOperator::Create(
566 Instruction::Shl, GlobalMemFenceMask,
567 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
568
569 // And combine the above together, also adding in
570 // MemorySemanticsSequentiallyConsistentMask.
571 auto MemorySemantics =
572 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
573 ConstantSequentiallyConsistent, "", CI);
574 MemorySemantics = BinaryOperator::Create(
575 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
576
577 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
578 // Device Scope, otherwise Workgroup Scope.
579 const auto Cmp =
580 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
581 GlobalMemFenceMask, GlobalMemFence, "", CI);
582 const auto MemoryScope = SelectInst::Create(
583 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
584
585 // Lastly, the Execution Scope is always Workgroup Scope.
586 const auto ExecutionScope = ConstantScopeWorkgroup;
587
588 auto NewCI = CallInst::Create(
589 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
590
591 CI->replaceAllUsesWith(NewCI);
592
593 // Lastly, remember to remove the user.
594 ToRemoves.push_back(CI);
595 }
596 }
597
598 Changed = !ToRemoves.empty();
599
600 // And cleanup the calls we don't use anymore.
601 for (auto V : ToRemoves) {
602 V->eraseFromParent();
603 }
604
605 // And remove the function we don't need either too.
606 F->eraseFromParent();
607 }
608 }
609
610 return Changed;
611}
612
613bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
614 bool Changed = false;
615
616 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
617
Neil Henning39672102017-09-29 14:33:13 +0100618 using Tuple = std::tuple<const char *, unsigned>;
619 const std::map<const char *, Tuple> Map = {
620 {"_Z9mem_fencej",
621 Tuple("__spirv_memory_barrier",
622 spv::MemorySemanticsSequentiallyConsistentMask)},
623 {"_Z14read_mem_fencej",
624 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
625 {"_Z15write_mem_fencej",
626 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400627
628 for (auto Pair : Map) {
629 // If we find a function with the matching name.
630 if (auto F = M.getFunction(Pair.first)) {
631 SmallVector<Instruction *, 4> ToRemoves;
632
633 // Walk the users of the function.
634 for (auto &U : F->uses()) {
635 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
636 auto FType = F->getFunctionType();
637 SmallVector<Type *, 2> Params;
638 for (unsigned i = 0; i < 2; i++) {
639 Params.push_back(FType->getParamType(0));
640 }
641 auto NewFType =
642 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100643 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400644
645 auto Arg = CI->getOperand(0);
646
647 // We need to map the OpenCL constants to the SPIR-V equivalents.
648 const auto LocalMemFence =
649 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
650 const auto GlobalMemFence =
651 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
652 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100653 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400654 const auto ConstantScopeDevice =
655 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
656
657 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
658 const auto LocalMemFenceMask = BinaryOperator::Create(
659 Instruction::And, LocalMemFence, Arg, "", CI);
660 const auto WorkgroupShiftAmount =
661 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
662 clz(CLK_LOCAL_MEM_FENCE);
663 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
664 Instruction::Shl, LocalMemFenceMask,
665 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
666
667 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
668 const auto GlobalMemFenceMask = BinaryOperator::Create(
669 Instruction::And, GlobalMemFence, Arg, "", CI);
670 const auto UniformShiftAmount =
671 clz(spv::MemorySemanticsUniformMemoryMask) -
672 clz(CLK_GLOBAL_MEM_FENCE);
673 const auto MemorySemanticsUniform = BinaryOperator::Create(
674 Instruction::Shl, GlobalMemFenceMask,
675 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
676
677 // And combine the above together, also adding in
678 // MemorySemanticsSequentiallyConsistentMask.
679 auto MemorySemantics =
680 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
681 ConstantMemorySemantics, "", CI);
682 MemorySemantics = BinaryOperator::Create(
683 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
684
685 // Memory Scope is always device.
686 const auto MemoryScope = ConstantScopeDevice;
687
688 auto NewCI =
689 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
690
691 CI->replaceAllUsesWith(NewCI);
692
693 // Lastly, remember to remove the user.
694 ToRemoves.push_back(CI);
695 }
696 }
697
698 Changed = !ToRemoves.empty();
699
700 // And cleanup the calls we don't use anymore.
701 for (auto V : ToRemoves) {
702 V->eraseFromParent();
703 }
704
705 // And remove the function we don't need either too.
706 F->eraseFromParent();
707 }
708 }
709
710 return Changed;
711}
712
713bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
714 bool Changed = false;
715
716 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
717 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
718 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
719 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
720 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
721 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
722 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
723 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
724 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
725 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
726 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
727 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
728 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
729 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
730 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
731 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
732 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
733 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
734 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
735 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
736 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
737 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
738 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
739 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
740 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
741 };
742
743 for (auto Pair : Map) {
744 // If we find a function with the matching name.
745 if (auto F = M.getFunction(Pair.first)) {
746 SmallVector<Instruction *, 4> ToRemoves;
747
748 // Walk the users of the function.
749 for (auto &U : F->uses()) {
750 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
751 // The predicate to use in the CmpInst.
752 auto Predicate = Pair.second.first;
753
754 // The value to return for true.
755 auto TrueValue =
756 ConstantInt::getSigned(CI->getType(), Pair.second.second);
757
758 // The value to return for false.
759 auto FalseValue = Constant::getNullValue(CI->getType());
760
761 auto Arg1 = CI->getOperand(0);
762 auto Arg2 = CI->getOperand(1);
763
764 const auto Cmp =
765 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
766
767 const auto Select =
768 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
769
770 CI->replaceAllUsesWith(Select);
771
772 // Lastly, remember to remove the user.
773 ToRemoves.push_back(CI);
774 }
775 }
776
777 Changed = !ToRemoves.empty();
778
779 // And cleanup the calls we don't use anymore.
780 for (auto V : ToRemoves) {
781 V->eraseFromParent();
782 }
783
784 // And remove the function we don't need either too.
785 F->eraseFromParent();
786 }
787 }
788
789 return Changed;
790}
791
792bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
793 bool Changed = false;
794
795 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
796 {"_Z5isinff", {"__spirv_isinff", 1}},
797 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
798 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
799 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
800 {"_Z5isnanf", {"__spirv_isnanf", 1}},
801 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
802 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
803 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
804 };
805
806 for (auto Pair : Map) {
807 // If we find a function with the matching name.
808 if (auto F = M.getFunction(Pair.first)) {
809 SmallVector<Instruction *, 4> ToRemoves;
810
811 // Walk the users of the function.
812 for (auto &U : F->uses()) {
813 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
814 const auto CITy = CI->getType();
815
816 // The fake SPIR-V intrinsic to generate.
817 auto SPIRVIntrinsic = Pair.second.first;
818
819 // The value to return for true.
820 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
821
822 // The value to return for false.
823 auto FalseValue = Constant::getNullValue(CITy);
824
825 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
826 M.getContext(),
827 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
828
829 auto NewFType =
830 FunctionType::get(CorrespondingBoolTy,
831 F->getFunctionType()->getParamType(0), false);
832
833 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
834
835 auto Arg = CI->getOperand(0);
836
837 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
838
839 const auto Select =
840 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
841
842 CI->replaceAllUsesWith(Select);
843
844 // Lastly, remember to remove the user.
845 ToRemoves.push_back(CI);
846 }
847 }
848
849 Changed = !ToRemoves.empty();
850
851 // And cleanup the calls we don't use anymore.
852 for (auto V : ToRemoves) {
853 V->eraseFromParent();
854 }
855
856 // And remove the function we don't need either too.
857 F->eraseFromParent();
858 }
859 }
860
861 return Changed;
862}
863
864bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
865 bool Changed = false;
866
867 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000868 // all
alan-bakerb39c8262019-03-08 14:03:37 -0500869 {"_Z3allc", ""},
870 {"_Z3allDv2_c", "__spirv_allDv2_c"},
871 {"_Z3allDv3_c", "__spirv_allDv3_c"},
872 {"_Z3allDv4_c", "__spirv_allDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000873 {"_Z3alls", ""},
874 {"_Z3allDv2_s", "__spirv_allDv2_s"},
875 {"_Z3allDv3_s", "__spirv_allDv3_s"},
876 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400877 {"_Z3alli", ""},
878 {"_Z3allDv2_i", "__spirv_allDv2_i"},
879 {"_Z3allDv3_i", "__spirv_allDv3_i"},
880 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000881 {"_Z3alll", ""},
882 {"_Z3allDv2_l", "__spirv_allDv2_l"},
883 {"_Z3allDv3_l", "__spirv_allDv3_l"},
884 {"_Z3allDv4_l", "__spirv_allDv4_l"},
885
886 // any
alan-bakerb39c8262019-03-08 14:03:37 -0500887 {"_Z3anyc", ""},
888 {"_Z3anyDv2_c", "__spirv_anyDv2_c"},
889 {"_Z3anyDv3_c", "__spirv_anyDv3_c"},
890 {"_Z3anyDv4_c", "__spirv_anyDv4_c"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000891 {"_Z3anys", ""},
892 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
893 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
894 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400895 {"_Z3anyi", ""},
896 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
897 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
898 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000899 {"_Z3anyl", ""},
900 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
901 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
902 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400903 };
904
905 for (auto Pair : Map) {
906 // If we find a function with the matching name.
907 if (auto F = M.getFunction(Pair.first)) {
908 SmallVector<Instruction *, 4> ToRemoves;
909
910 // Walk the users of the function.
911 for (auto &U : F->uses()) {
912 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
913 // The fake SPIR-V intrinsic to generate.
914 auto SPIRVIntrinsic = Pair.second;
915
916 auto Arg = CI->getOperand(0);
917
918 Value *V;
919
Kévin Petitfd27cca2018-10-31 13:00:17 +0000920 // If the argument is a 32-bit int, just use a shift
921 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
922 V = BinaryOperator::Create(Instruction::LShr, Arg,
923 ConstantInt::get(Arg->getType(), 31), "",
924 CI);
925 } else {
David Neto22f144c2017-06-12 14:26:21 -0400926 // The value for zero to compare against.
927 const auto ZeroValue = Constant::getNullValue(Arg->getType());
928
David Neto22f144c2017-06-12 14:26:21 -0400929 // The value to return for true.
930 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
931
932 // The value to return for false.
933 const auto FalseValue = Constant::getNullValue(CI->getType());
934
Kévin Petitfd27cca2018-10-31 13:00:17 +0000935 const auto Cmp = CmpInst::Create(
936 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
937
938 Value* SelectSource;
939
940 // If we have a function to call, call it!
941 if (0 < strlen(SPIRVIntrinsic)) {
942
943 const auto NewFType = FunctionType::get(
944 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
945
946 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
947
948 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
949
950 SelectSource = NewCI;
951
952 } else {
953 SelectSource = Cmp;
954 }
955
956 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400957 }
958
959 CI->replaceAllUsesWith(V);
960
961 // Lastly, remember to remove the user.
962 ToRemoves.push_back(CI);
963 }
964 }
965
966 Changed = !ToRemoves.empty();
967
968 // And cleanup the calls we don't use anymore.
969 for (auto V : ToRemoves) {
970 V->eraseFromParent();
971 }
972
973 // And remove the function we don't need either too.
974 F->eraseFromParent();
975 }
976 }
977
978 return Changed;
979}
980
Kévin Petitbf0036c2019-03-06 13:57:10 +0000981bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
982 bool Changed = false;
983
984 for (auto const &SymVal : M.getValueSymbolTable()) {
985 // Skip symbols whose name doesn't match
986 if (!SymVal.getKey().startswith("_Z8upsample")) {
987 continue;
988 }
989 // Is there a function going by that name?
990 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
991
992 SmallVector<Instruction *, 4> ToRemoves;
993
994 // Walk the users of the function.
995 for (auto &U : F->uses()) {
996 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
997
998 // Get arguments
999 auto HiValue = CI->getOperand(0);
1000 auto LoValue = CI->getOperand(1);
1001
1002 // Don't touch overloads that aren't in OpenCL C
1003 auto HiType = HiValue->getType();
1004 auto LoType = LoValue->getType();
1005
1006 if (HiType != LoType) {
1007 continue;
1008 }
1009
1010 if (!HiType->isIntOrIntVectorTy()) {
1011 continue;
1012 }
1013
1014 if (HiType->getScalarSizeInBits() * 2 !=
1015 CI->getType()->getScalarSizeInBits()) {
1016 continue;
1017 }
1018
1019 if ((HiType->getScalarSizeInBits() != 8) &&
1020 (HiType->getScalarSizeInBits() != 16) &&
1021 (HiType->getScalarSizeInBits() != 32)) {
1022 continue;
1023 }
1024
1025 if (HiType->isVectorTy()) {
1026 if ((HiType->getVectorNumElements() != 2) &&
1027 (HiType->getVectorNumElements() != 3) &&
1028 (HiType->getVectorNumElements() != 4) &&
1029 (HiType->getVectorNumElements() != 8) &&
1030 (HiType->getVectorNumElements() != 16)) {
1031 continue;
1032 }
1033 }
1034
1035 // Convert both operands to the result type
1036 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
1037 "", CI);
1038 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
1039 "", CI);
1040
1041 // Shift high operand
1042 auto ShiftAmount = ConstantInt::get(CI->getType(),
1043 HiType->getScalarSizeInBits());
1044 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1045 ShiftAmount, "", CI);
1046
1047 // OR both results
1048 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1049 "", CI);
1050
1051 // Replace call with the expression
1052 CI->replaceAllUsesWith(V);
1053
1054 // Lastly, remember to remove the user.
1055 ToRemoves.push_back(CI);
1056 }
1057 }
1058
1059 Changed = !ToRemoves.empty();
1060
1061 // And cleanup the calls we don't use anymore.
1062 for (auto V : ToRemoves) {
1063 V->eraseFromParent();
1064 }
1065
1066 // And remove the function we don't need either too.
1067 F->eraseFromParent();
1068 }
1069 }
1070
1071 return Changed;
1072}
1073
Kévin Petitd44eef52019-03-08 13:22:14 +00001074bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1075 bool Changed = false;
1076
1077 for (auto const &SymVal : M.getValueSymbolTable()) {
1078 // Skip symbols whose name doesn't match
1079 if (!SymVal.getKey().startswith("_Z6rotate")) {
1080 continue;
1081 }
1082 // Is there a function going by that name?
1083 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1084
1085 SmallVector<Instruction *, 4> ToRemoves;
1086
1087 // Walk the users of the function.
1088 for (auto &U : F->uses()) {
1089 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1090
1091 // Get arguments
1092 auto SrcValue = CI->getOperand(0);
1093 auto RotAmount = CI->getOperand(1);
1094
1095 // Don't touch overloads that aren't in OpenCL C
1096 auto SrcType = SrcValue->getType();
1097 auto RotType = RotAmount->getType();
1098
1099 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1100 continue;
1101 }
1102
1103 if (!SrcType->isIntOrIntVectorTy()) {
1104 continue;
1105 }
1106
1107 if ((SrcType->getScalarSizeInBits() != 8) &&
1108 (SrcType->getScalarSizeInBits() != 16) &&
1109 (SrcType->getScalarSizeInBits() != 32) &&
1110 (SrcType->getScalarSizeInBits() != 64)) {
1111 continue;
1112 }
1113
1114 if (SrcType->isVectorTy()) {
1115 if ((SrcType->getVectorNumElements() != 2) &&
1116 (SrcType->getVectorNumElements() != 3) &&
1117 (SrcType->getVectorNumElements() != 4) &&
1118 (SrcType->getVectorNumElements() != 8) &&
1119 (SrcType->getVectorNumElements() != 16)) {
1120 continue;
1121 }
1122 }
1123
1124 // The approach used is to shift the top bits down, the bottom bits up
1125 // and OR the two shifted values.
1126
1127 // The rotation amount is to be treated modulo the element size.
1128 // Since SPIR-V shift ops don't support this, let's apply the
1129 // modulo ahead of shifting. The element size is always a power of
1130 // two so we can just AND with a mask.
1131 auto ModMask = ConstantInt::get(SrcType,
1132 SrcType->getScalarSizeInBits() - 1);
1133 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1134 ModMask, "", CI);
1135
1136 // Let's calc the amount by which to shift top bits down
1137 auto ScalarSize = ConstantInt::get(SrcType,
1138 SrcType->getScalarSizeInBits());
1139 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1140 RotAmount, "", CI);
1141
1142 // Now shift the bottom bits up and the top bits down
1143 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1144 RotAmount, "", CI);
1145 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1146 DownAmount, "", CI);
1147
1148 // Finally OR the two shifted values
1149 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1150 HiRotated, "", CI);
1151
1152 // Replace call with the expression
1153 CI->replaceAllUsesWith(V);
1154
1155 // Lastly, remember to remove the user.
1156 ToRemoves.push_back(CI);
1157 }
1158 }
1159
1160 Changed = !ToRemoves.empty();
1161
1162 // And cleanup the calls we don't use anymore.
1163 for (auto V : ToRemoves) {
1164 V->eraseFromParent();
1165 }
1166
1167 // And remove the function we don't need either too.
1168 F->eraseFromParent();
1169 }
1170 }
1171
1172 return Changed;
1173}
1174
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001175bool ReplaceOpenCLBuiltinPass::replaceConvert(Module &M) {
1176 bool Changed = false;
1177
1178 for (auto const &SymVal : M.getValueSymbolTable()) {
1179
1180 // Skip symbols whose name obviously doesn't match
1181 if (!SymVal.getKey().contains("convert_")) {
1182 continue;
1183 }
1184
1185 // Is there a function going by that name?
1186 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1187
1188 // Get info from the mangled name
1189 FunctionInfo finfo;
1190 bool parsed = getFunctionInfoFromMangledNameCheck(F->getName(), &finfo);
1191
1192 // All functions of interest are handled by our mangled name parser
1193 if (!parsed) {
1194 continue;
1195 }
1196
1197 // Move on if this isn't a call to convert_
1198 if (!finfo.name.startswith("convert_")) {
1199 continue;
1200 }
1201
1202 // Extract the destination type from the function name
1203 StringRef DstTypeName = finfo.name;
1204 DstTypeName.consume_front("convert_");
1205
1206 auto DstSignedNess = StringSwitch<ArgTypeInfo::SignedNess>(DstTypeName)
1207 .StartsWith("char", ArgTypeInfo::SignedNess::Signed)
1208 .StartsWith("short", ArgTypeInfo::SignedNess::Signed)
1209 .StartsWith("int", ArgTypeInfo::SignedNess::Signed)
1210 .StartsWith("long", ArgTypeInfo::SignedNess::Signed)
1211 .StartsWith("uchar", ArgTypeInfo::SignedNess::Unsigned)
1212 .StartsWith("ushort", ArgTypeInfo::SignedNess::Unsigned)
1213 .StartsWith("uint", ArgTypeInfo::SignedNess::Unsigned)
1214 .StartsWith("ulong", ArgTypeInfo::SignedNess::Unsigned)
1215 .Default(ArgTypeInfo::SignedNess::None);
1216
1217 auto SrcSignedNess = finfo.argTypeInfos[0].signedness;
1218
1219 bool DstIsSigned = DstSignedNess == ArgTypeInfo::SignedNess::Signed;
1220 bool SrcIsSigned = SrcSignedNess == ArgTypeInfo::SignedNess::Signed;
1221
1222 SmallVector<Instruction *, 4> ToRemoves;
1223
1224 // Walk the users of the function.
1225 for (auto &U : F->uses()) {
1226 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1227
1228 // Get arguments
1229 auto SrcValue = CI->getOperand(0);
1230
1231 // Don't touch overloads that aren't in OpenCL C
1232 auto SrcType = SrcValue->getType();
1233 auto DstType = CI->getType();
1234
1235 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1236 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1237 continue;
1238 }
1239
1240 if (SrcType->isVectorTy()) {
1241
1242 if (SrcType->getVectorNumElements() !=
1243 DstType->getVectorNumElements()) {
1244 continue;
1245 }
1246
1247 if ((SrcType->getVectorNumElements() != 2) &&
1248 (SrcType->getVectorNumElements() != 3) &&
1249 (SrcType->getVectorNumElements() != 4) &&
1250 (SrcType->getVectorNumElements() != 8) &&
1251 (SrcType->getVectorNumElements() != 16)) {
1252 continue;
1253 }
1254 }
1255
1256 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1257 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1258
1259 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1260 bool DstIsInt = DstType->isIntOrIntVectorTy();
1261
1262 Value *V;
1263 if (SrcIsFloat && DstIsFloat) {
1264 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1265 } else if (SrcIsFloat && DstIsInt) {
1266 if (DstIsSigned) {
1267 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "", CI);
1268 } else {
1269 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "", CI);
1270 }
1271 } else if (SrcIsInt && DstIsFloat) {
1272 if (SrcIsSigned) {
1273 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "", CI);
1274 } else {
1275 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "", CI);
1276 }
1277 } else if (SrcIsInt && DstIsInt) {
1278 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "", CI);
1279 } else {
1280 // Not something we're supposed to handle, just move on
1281 continue;
1282 }
1283
1284 // Replace call with the expression
1285 CI->replaceAllUsesWith(V);
1286
1287 // Lastly, remember to remove the user.
1288 ToRemoves.push_back(CI);
1289 }
1290 }
1291
1292 Changed = !ToRemoves.empty();
1293
1294 // And cleanup the calls we don't use anymore.
1295 for (auto V : ToRemoves) {
1296 V->eraseFromParent();
1297 }
1298
1299 // And remove the function we don't need either too.
1300 F->eraseFromParent();
1301 }
1302 }
1303
1304 return Changed;
1305}
1306
Kévin Petit8a560882019-03-21 15:24:34 +00001307bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1308 bool Changed = false;
1309
1310 for (auto const &SymVal : M.getValueSymbolTable()) {
1311
1312 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1313 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1314
1315 // Skip symbols whose name doesn't match
1316 if (!isMad && !isMul) {
1317 continue;
1318 }
1319
1320 // Is there a function going by that name?
1321 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1322
1323 SmallVector<Instruction *, 4> ToRemoves;
1324
1325 // Walk the users of the function.
1326 for (auto &U : F->uses()) {
1327 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1328
1329 // Get arguments
1330 auto AValue = CI->getOperand(0);
1331 auto BValue = CI->getOperand(1);
1332 auto CValue = CI->getOperand(2);
1333
1334 // Don't touch overloads that aren't in OpenCL C
1335 auto AType = AValue->getType();
1336 auto BType = BValue->getType();
1337 auto CType = CValue->getType();
1338
1339 if ((AType != BType) || (CI->getType() != AType) ||
1340 (isMad && (AType != CType))) {
1341 continue;
1342 }
1343
1344 if (!AType->isIntOrIntVectorTy()) {
1345 continue;
1346 }
1347
1348 if ((AType->getScalarSizeInBits() != 8) &&
1349 (AType->getScalarSizeInBits() != 16) &&
1350 (AType->getScalarSizeInBits() != 32) &&
1351 (AType->getScalarSizeInBits() != 64)) {
1352 continue;
1353 }
1354
1355 if (AType->isVectorTy()) {
1356 if ((AType->getVectorNumElements() != 2) &&
1357 (AType->getVectorNumElements() != 3) &&
1358 (AType->getVectorNumElements() != 4) &&
1359 (AType->getVectorNumElements() != 8) &&
1360 (AType->getVectorNumElements() != 16)) {
1361 continue;
1362 }
1363 }
1364
1365 // Create struct type for the return type of our SPIR-V intrinsic
1366 SmallVector<Type*, 2> TwoValueType = {
1367 AType,
1368 AType
1369 };
1370
1371 auto ExMulRetType = StructType::create(TwoValueType);
1372
1373 // And a function type
1374 auto NewFType = FunctionType::get(ExMulRetType, TwoValueType, false);
1375
1376 // Get infos from the mangled OpenCL built-in function name
1377 FunctionInfo finfo;
1378 getFunctionInfoFromMangledName(F->getName(), &finfo);
1379
1380 // Use it to select the appropriate signed/unsigned SPIR-V intrinsic
1381 StringRef intrinsic;
1382 if (finfo.argTypeInfos[0].signedness == ArgTypeInfo::SignedNess::Signed) {
1383 intrinsic = "spirv.smul_extended";
1384 } else {
1385 intrinsic = "spirv.umul_extended";
1386 }
1387
1388 // Add the intrinsic function to the module
1389 auto NewF = M.getOrInsertFunction(intrinsic, NewFType);
1390
1391 // Call it
1392 SmallVector<Value*, 4> NewFArgs = {
1393 AValue,
1394 BValue,
1395 };
1396
1397 auto Call = CallInst::Create(NewF, NewFArgs, "", CI);
1398
1399 // Get the high part of the result
1400 unsigned Idxs[] = {1};
1401 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1402
1403 // If we're handling a mad_hi, add the third argument to the result
1404 if (isMad) {
1405 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1406 }
1407
1408 // Replace call with the expression
1409 CI->replaceAllUsesWith(V);
1410
1411 // Lastly, remember to remove the user.
1412 ToRemoves.push_back(CI);
1413 }
1414 }
1415
1416 Changed = !ToRemoves.empty();
1417
1418 // And cleanup the calls we don't use anymore.
1419 for (auto V : ToRemoves) {
1420 V->eraseFromParent();
1421 }
1422
1423 // And remove the function we don't need either too.
1424 F->eraseFromParent();
1425 }
1426 }
1427
1428 return Changed;
1429}
1430
Kévin Petitf5b78a22018-10-25 14:32:17 +00001431bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1432 bool Changed = false;
1433
1434 for (auto const &SymVal : M.getValueSymbolTable()) {
1435 // Skip symbols whose name doesn't match
1436 if (!SymVal.getKey().startswith("_Z6select")) {
1437 continue;
1438 }
1439 // Is there a function going by that name?
1440 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1441
1442 SmallVector<Instruction *, 4> ToRemoves;
1443
1444 // Walk the users of the function.
1445 for (auto &U : F->uses()) {
1446 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1447
1448 // Get arguments
1449 auto FalseValue = CI->getOperand(0);
1450 auto TrueValue = CI->getOperand(1);
1451 auto PredicateValue = CI->getOperand(2);
1452
1453 // Don't touch overloads that aren't in OpenCL C
1454 auto FalseType = FalseValue->getType();
1455 auto TrueType = TrueValue->getType();
1456 auto PredicateType = PredicateValue->getType();
1457
1458 if (FalseType != TrueType) {
1459 continue;
1460 }
1461
1462 if (!PredicateType->isIntOrIntVectorTy()) {
1463 continue;
1464 }
1465
1466 if (!FalseType->isIntOrIntVectorTy() &&
1467 !FalseType->getScalarType()->isFloatingPointTy()) {
1468 continue;
1469 }
1470
1471 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1472 continue;
1473 }
1474
1475 if (FalseType->getScalarSizeInBits() !=
1476 PredicateType->getScalarSizeInBits()) {
1477 continue;
1478 }
1479
1480 if (FalseType->isVectorTy()) {
1481 if (FalseType->getVectorNumElements() !=
1482 PredicateType->getVectorNumElements()) {
1483 continue;
1484 }
1485
1486 if ((FalseType->getVectorNumElements() != 2) &&
1487 (FalseType->getVectorNumElements() != 3) &&
1488 (FalseType->getVectorNumElements() != 4) &&
1489 (FalseType->getVectorNumElements() != 8) &&
1490 (FalseType->getVectorNumElements() != 16)) {
1491 continue;
1492 }
1493 }
1494
1495 // Create constant
1496 const auto ZeroValue = Constant::getNullValue(PredicateType);
1497
1498 // Scalar and vector are to be treated differently
1499 CmpInst::Predicate Pred;
1500 if (PredicateType->isVectorTy()) {
1501 Pred = CmpInst::ICMP_SLT;
1502 } else {
1503 Pred = CmpInst::ICMP_NE;
1504 }
1505
1506 // Create comparison instruction
1507 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1508 ZeroValue, "", CI);
1509
1510 // Create select
1511 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1512
1513 // Replace call with the selection
1514 CI->replaceAllUsesWith(V);
1515
1516 // Lastly, remember to remove the user.
1517 ToRemoves.push_back(CI);
1518 }
1519 }
1520
1521 Changed = !ToRemoves.empty();
1522
1523 // And cleanup the calls we don't use anymore.
1524 for (auto V : ToRemoves) {
1525 V->eraseFromParent();
1526 }
1527
1528 // And remove the function we don't need either too.
1529 F->eraseFromParent();
1530 }
1531 }
1532
1533 return Changed;
1534}
1535
Kévin Petite7d0cce2018-10-31 12:38:56 +00001536bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1537 bool Changed = false;
1538
1539 for (auto const &SymVal : M.getValueSymbolTable()) {
1540 // Skip symbols whose name doesn't match
1541 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1542 continue;
1543 }
1544 // Is there a function going by that name?
1545 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1546
1547 SmallVector<Instruction *, 4> ToRemoves;
1548
1549 // Walk the users of the function.
1550 for (auto &U : F->uses()) {
1551 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1552
1553 if (CI->getNumOperands() != 4) {
1554 continue;
1555 }
1556
1557 // Get arguments
1558 auto FalseValue = CI->getOperand(0);
1559 auto TrueValue = CI->getOperand(1);
1560 auto PredicateValue = CI->getOperand(2);
1561
1562 // Don't touch overloads that aren't in OpenCL C
1563 auto FalseType = FalseValue->getType();
1564 auto TrueType = TrueValue->getType();
1565 auto PredicateType = PredicateValue->getType();
1566
1567 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1568 continue;
1569 }
1570
1571 if (TrueType->isVectorTy()) {
1572 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1573 !TrueType->getScalarType()->isIntegerTy()) {
1574 continue;
1575 }
1576 if ((TrueType->getVectorNumElements() != 2) &&
1577 (TrueType->getVectorNumElements() != 3) &&
1578 (TrueType->getVectorNumElements() != 4) &&
1579 (TrueType->getVectorNumElements() != 8) &&
1580 (TrueType->getVectorNumElements() != 16)) {
1581 continue;
1582 }
1583 }
1584
1585 // Remember the type of the operands
1586 auto OpType = TrueType;
1587
1588 // The actual bit selection will always be done on an integer type,
1589 // declare it here
1590 Type *BitType;
1591
1592 // If the operands are float, then bitcast them to int
1593 if (OpType->getScalarType()->isFloatingPointTy()) {
1594
1595 // First create the new type
1596 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1597 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1598 if (OpType->isVectorTy()) {
1599 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1600 }
1601
1602 // Then bitcast all operands
1603 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1604 BitType, "", CI);
1605 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1606 BitType, "", CI);
1607 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1608
1609 } else {
1610 // The operands have an integer type, use it directly
1611 BitType = OpType;
1612 }
1613
1614 // All the operands are now always integers
1615 // implement as (c & b) | (~c & a)
1616
1617 // Create our negated predicate value
1618 auto AllOnes = Constant::getAllOnesValue(BitType);
1619 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1620 PredicateValue,
1621 AllOnes, "", CI);
1622
1623 // Then put everything together
1624 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1625 NotPredicateValue,
1626 FalseValue, "", CI);
1627 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1628 PredicateValue,
1629 TrueValue, "", CI);
1630
1631 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1632 BitsTrue, "", CI);
1633
1634 // If we were dealing with a floating point type, we must bitcast
1635 // the result back to that
1636 if (OpType->getScalarType()->isFloatingPointTy()) {
1637 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1638 }
1639
1640 // Replace call with our new code
1641 CI->replaceAllUsesWith(V);
1642
1643 // Lastly, remember to remove the user.
1644 ToRemoves.push_back(CI);
1645 }
1646 }
1647
1648 Changed = !ToRemoves.empty();
1649
1650 // And cleanup the calls we don't use anymore.
1651 for (auto V : ToRemoves) {
1652 V->eraseFromParent();
1653 }
1654
1655 // And remove the function we don't need either too.
1656 F->eraseFromParent();
1657 }
1658 }
1659
1660 return Changed;
1661}
1662
Kévin Petit6b0a9532018-10-30 20:00:39 +00001663bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1664 bool Changed = false;
1665
1666 const std::map<const char *, const char *> Map = {
1667 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1668 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1669 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1670 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1671 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1672 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1673 };
1674
1675 for (auto Pair : Map) {
1676 // If we find a function with the matching name.
1677 if (auto F = M.getFunction(Pair.first)) {
1678 SmallVector<Instruction *, 4> ToRemoves;
1679
1680 // Walk the users of the function.
1681 for (auto &U : F->uses()) {
1682 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1683
1684 auto ReplacementFn = Pair.second;
1685
1686 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1687 Value *VectorArg;
1688
1689 // First figure out which function we're dealing with
1690 if (F->getName().startswith("_Z10smoothstep")) {
1691 ArgsToSplat.push_back(CI->getOperand(1));
1692 VectorArg = CI->getOperand(2);
1693 } else {
1694 VectorArg = CI->getOperand(1);
1695 }
1696
1697 // Splat arguments that need to be
1698 SmallVector<Value*, 2> SplatArgs;
1699 auto VecType = VectorArg->getType();
1700
1701 for (auto arg : ArgsToSplat) {
1702 Value* NewVectorArg = UndefValue::get(VecType);
1703 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1704 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1705 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1706 }
1707 SplatArgs.push_back(NewVectorArg);
1708 }
1709
1710 // Replace the call with the vector/vector flavour
1711 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1712 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1713
1714 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1715
1716 SmallVector<Value*, 3> NewArgs;
1717 for (auto arg : SplatArgs) {
1718 NewArgs.push_back(arg);
1719 }
1720 NewArgs.push_back(VectorArg);
1721
1722 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1723
1724 CI->replaceAllUsesWith(NewCI);
1725
1726 // Lastly, remember to remove the user.
1727 ToRemoves.push_back(CI);
1728 }
1729 }
1730
1731 Changed = !ToRemoves.empty();
1732
1733 // And cleanup the calls we don't use anymore.
1734 for (auto V : ToRemoves) {
1735 V->eraseFromParent();
1736 }
1737
1738 // And remove the function we don't need either too.
1739 F->eraseFromParent();
1740 }
1741 }
1742
1743 return Changed;
1744}
1745
David Neto22f144c2017-06-12 14:26:21 -04001746bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1747 bool Changed = false;
1748
1749 const std::map<const char *, Instruction::BinaryOps> Map = {
1750 {"_Z7signbitf", Instruction::LShr},
1751 {"_Z7signbitDv2_f", Instruction::AShr},
1752 {"_Z7signbitDv3_f", Instruction::AShr},
1753 {"_Z7signbitDv4_f", Instruction::AShr},
1754 };
1755
1756 for (auto Pair : Map) {
1757 // If we find a function with the matching name.
1758 if (auto F = M.getFunction(Pair.first)) {
1759 SmallVector<Instruction *, 4> ToRemoves;
1760
1761 // Walk the users of the function.
1762 for (auto &U : F->uses()) {
1763 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1764 auto Arg = CI->getOperand(0);
1765
1766 auto Bitcast =
1767 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1768
1769 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1770 ConstantInt::get(CI->getType(), 31),
1771 "", CI);
1772
1773 CI->replaceAllUsesWith(Shr);
1774
1775 // Lastly, remember to remove the user.
1776 ToRemoves.push_back(CI);
1777 }
1778 }
1779
1780 Changed = !ToRemoves.empty();
1781
1782 // And cleanup the calls we don't use anymore.
1783 for (auto V : ToRemoves) {
1784 V->eraseFromParent();
1785 }
1786
1787 // And remove the function we don't need either too.
1788 F->eraseFromParent();
1789 }
1790 }
1791
1792 return Changed;
1793}
1794
1795bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1796 bool Changed = false;
1797
1798 const std::map<const char *,
1799 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1800 Map = {
1801 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1802 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1803 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1804 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1805 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1806 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1807 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1808 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1809 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1810 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1811 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1812 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1813 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1814 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1815 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1816 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1817 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1818 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1819 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1820 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1821 };
1822
1823 for (auto Pair : Map) {
1824 // If we find a function with the matching name.
1825 if (auto F = M.getFunction(Pair.first)) {
1826 SmallVector<Instruction *, 4> ToRemoves;
1827
1828 // Walk the users of the function.
1829 for (auto &U : F->uses()) {
1830 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1831 // The multiply instruction to use.
1832 auto MulInst = Pair.second.first;
1833
1834 // The add instruction to use.
1835 auto AddInst = Pair.second.second;
1836
1837 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1838
1839 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1840 CI->getArgOperand(1), "", CI);
1841
1842 if (Instruction::BinaryOpsEnd != AddInst) {
1843 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1844 CI);
1845 }
1846
1847 CI->replaceAllUsesWith(I);
1848
1849 // Lastly, remember to remove the user.
1850 ToRemoves.push_back(CI);
1851 }
1852 }
1853
1854 Changed = !ToRemoves.empty();
1855
1856 // And cleanup the calls we don't use anymore.
1857 for (auto V : ToRemoves) {
1858 V->eraseFromParent();
1859 }
1860
1861 // And remove the function we don't need either too.
1862 F->eraseFromParent();
1863 }
1864 }
1865
1866 return Changed;
1867}
1868
Derek Chowcfd368b2017-10-19 20:58:45 -07001869bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1870 bool Changed = false;
1871
1872 struct VectorStoreOps {
1873 const char* name;
1874 int n;
1875 Type* (*get_scalar_type_function)(LLVMContext&);
1876 } vector_store_ops[] = {
1877 // TODO(derekjchow): Expand this list.
1878 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1879 };
1880
David Neto544fffc2017-11-16 18:35:14 -05001881 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001882 auto Name = Op.name;
1883 auto N = Op.n;
1884 auto TypeFn = Op.get_scalar_type_function;
1885 if (auto F = M.getFunction(Name)) {
1886 SmallVector<Instruction *, 4> ToRemoves;
1887
1888 // Walk the users of the function.
1889 for (auto &U : F->uses()) {
1890 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1891 // The value argument from vstoren.
1892 auto Arg0 = CI->getOperand(0);
1893
1894 // The index argument from vstoren.
1895 auto Arg1 = CI->getOperand(1);
1896
1897 // The pointer argument from vstoren.
1898 auto Arg2 = CI->getOperand(2);
1899
1900 // Get types.
1901 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1902 auto ScalarNPointerTy = PointerType::get(
1903 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1904
1905 // Cast to scalarn
1906 auto Cast = CastInst::CreatePointerCast(
1907 Arg2, ScalarNPointerTy, "", CI);
1908 // Index to correct address
1909 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1910 // Store
1911 auto Store = new StoreInst(Arg0, Index, CI);
1912
1913 CI->replaceAllUsesWith(Store);
1914 ToRemoves.push_back(CI);
1915 }
1916 }
1917
1918 Changed = !ToRemoves.empty();
1919
1920 // And cleanup the calls we don't use anymore.
1921 for (auto V : ToRemoves) {
1922 V->eraseFromParent();
1923 }
1924
1925 // And remove the function we don't need either too.
1926 F->eraseFromParent();
1927 }
1928 }
1929
1930 return Changed;
1931}
1932
1933bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1934 bool Changed = false;
1935
1936 struct VectorLoadOps {
1937 const char* name;
1938 int n;
1939 Type* (*get_scalar_type_function)(LLVMContext&);
1940 } vector_load_ops[] = {
1941 // TODO(derekjchow): Expand this list.
1942 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1943 };
1944
David Neto544fffc2017-11-16 18:35:14 -05001945 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001946 auto Name = Op.name;
1947 auto N = Op.n;
1948 auto TypeFn = Op.get_scalar_type_function;
1949 // If we find a function with the matching name.
1950 if (auto F = M.getFunction(Name)) {
1951 SmallVector<Instruction *, 4> ToRemoves;
1952
1953 // Walk the users of the function.
1954 for (auto &U : F->uses()) {
1955 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1956 // The index argument from vloadn.
1957 auto Arg0 = CI->getOperand(0);
1958
1959 // The pointer argument from vloadn.
1960 auto Arg1 = CI->getOperand(1);
1961
1962 // Get types.
1963 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1964 auto ScalarNPointerTy = PointerType::get(
1965 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1966
1967 // Cast to scalarn
1968 auto Cast = CastInst::CreatePointerCast(
1969 Arg1, ScalarNPointerTy, "", CI);
1970 // Index to correct address
1971 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1972 // Load
1973 auto Load = new LoadInst(Index, "", CI);
1974
1975 CI->replaceAllUsesWith(Load);
1976 ToRemoves.push_back(CI);
1977 }
1978 }
1979
1980 Changed = !ToRemoves.empty();
1981
1982 // And cleanup the calls we don't use anymore.
1983 for (auto V : ToRemoves) {
1984 V->eraseFromParent();
1985 }
1986
1987 // And remove the function we don't need either too.
1988 F->eraseFromParent();
1989
1990 }
1991 }
1992
1993 return Changed;
1994}
1995
David Neto22f144c2017-06-12 14:26:21 -04001996bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1997 bool Changed = false;
1998
1999 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
2000 "_Z10vload_halfjPU3AS2KDh"};
2001
2002 for (auto Name : Map) {
2003 // If we find a function with the matching name.
2004 if (auto F = M.getFunction(Name)) {
2005 SmallVector<Instruction *, 4> ToRemoves;
2006
2007 // Walk the users of the function.
2008 for (auto &U : F->uses()) {
2009 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2010 // The index argument from vload_half.
2011 auto Arg0 = CI->getOperand(0);
2012
2013 // The pointer argument from vload_half.
2014 auto Arg1 = CI->getOperand(1);
2015
David Neto22f144c2017-06-12 14:26:21 -04002016 auto IntTy = Type::getInt32Ty(M.getContext());
2017 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002018 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2019
David Neto22f144c2017-06-12 14:26:21 -04002020 // Our intrinsic to unpack a float2 from an int.
2021 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2022
2023 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2024
David Neto482550a2018-03-24 05:21:07 -07002025 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04002026 auto ShortTy = Type::getInt16Ty(M.getContext());
2027 auto ShortPointerTy = PointerType::get(
2028 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002029
David Netoac825b82017-05-30 12:49:01 -04002030 // Cast the half* pointer to short*.
2031 auto Cast =
2032 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002033
David Netoac825b82017-05-30 12:49:01 -04002034 // Index into the correct address of the casted pointer.
2035 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
2036
2037 // Load from the short* we casted to.
2038 auto Load = new LoadInst(Index, "", CI);
2039
2040 // ZExt the short -> int.
2041 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
2042
2043 // Get our float2.
2044 auto Call = CallInst::Create(NewF, ZExt, "", CI);
2045
2046 // Extract out the bottom element which is our float result.
2047 auto Extract = ExtractElementInst::Create(
2048 Call, ConstantInt::get(IntTy, 0), "", CI);
2049
2050 CI->replaceAllUsesWith(Extract);
2051 } else {
2052 // Assume the pointer argument points to storage aligned to 32bits
2053 // or more.
2054 // TODO(dneto): Do more analysis to make sure this is true?
2055 //
2056 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
2057 // with:
2058 //
2059 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
2060 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
2061 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
2062 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
2063 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
2064 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
2065 // x float> %converted, %index_is_odd32
2066
2067 auto IntPointerTy = PointerType::get(
2068 IntTy, Arg1->getType()->getPointerAddressSpace());
2069
David Neto973e6a82017-05-30 13:48:18 -04002070 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04002071 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04002072 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04002073 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
2074
2075 auto One = ConstantInt::get(IntTy, 1);
2076 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
2077 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
2078
2079 // Index into the correct address of the casted pointer.
2080 auto Ptr =
2081 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
2082
2083 // Load from the int* we casted to.
2084 auto Load = new LoadInst(Ptr, "", CI);
2085
2086 // Get our float2.
2087 auto Call = CallInst::Create(NewF, Load, "", CI);
2088
2089 // Extract out the float result, where the element number is
2090 // determined by whether the original index was even or odd.
2091 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
2092
2093 CI->replaceAllUsesWith(Extract);
2094 }
David Neto22f144c2017-06-12 14:26:21 -04002095
2096 // Lastly, remember to remove the user.
2097 ToRemoves.push_back(CI);
2098 }
2099 }
2100
2101 Changed = !ToRemoves.empty();
2102
2103 // And cleanup the calls we don't use anymore.
2104 for (auto V : ToRemoves) {
2105 V->eraseFromParent();
2106 }
2107
2108 // And remove the function we don't need either too.
2109 F->eraseFromParent();
2110 }
2111 }
2112
2113 return Changed;
2114}
2115
2116bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
2117 bool Changed = false;
2118
David Neto556c7e62018-06-08 13:45:55 -07002119 const std::vector<const char *> Map = {
2120 "_Z11vload_half2jPU3AS1KDh",
2121 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
2122 "_Z11vload_half2jPU3AS2KDh",
2123 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
2124 };
David Neto22f144c2017-06-12 14:26:21 -04002125
2126 for (auto Name : Map) {
2127 // If we find a function with the matching name.
2128 if (auto F = M.getFunction(Name)) {
2129 SmallVector<Instruction *, 4> ToRemoves;
2130
2131 // Walk the users of the function.
2132 for (auto &U : F->uses()) {
2133 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2134 // The index argument from vload_half.
2135 auto Arg0 = CI->getOperand(0);
2136
2137 // The pointer argument from vload_half.
2138 auto Arg1 = CI->getOperand(1);
2139
2140 auto IntTy = Type::getInt32Ty(M.getContext());
2141 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2142 auto NewPointerTy = PointerType::get(
2143 IntTy, Arg1->getType()->getPointerAddressSpace());
2144 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2145
2146 // Cast the half* pointer to int*.
2147 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2148
2149 // Index into the correct address of the casted pointer.
2150 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
2151
2152 // Load from the int* we casted to.
2153 auto Load = new LoadInst(Index, "", CI);
2154
2155 // Our intrinsic to unpack a float2 from an int.
2156 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2157
2158 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2159
2160 // Get our float2.
2161 auto Call = CallInst::Create(NewF, Load, "", CI);
2162
2163 CI->replaceAllUsesWith(Call);
2164
2165 // Lastly, remember to remove the user.
2166 ToRemoves.push_back(CI);
2167 }
2168 }
2169
2170 Changed = !ToRemoves.empty();
2171
2172 // And cleanup the calls we don't use anymore.
2173 for (auto V : ToRemoves) {
2174 V->eraseFromParent();
2175 }
2176
2177 // And remove the function we don't need either too.
2178 F->eraseFromParent();
2179 }
2180 }
2181
2182 return Changed;
2183}
2184
2185bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
2186 bool Changed = false;
2187
David Neto556c7e62018-06-08 13:45:55 -07002188 const std::vector<const char *> Map = {
2189 "_Z11vload_half4jPU3AS1KDh",
2190 "_Z12vloada_half4jPU3AS1KDh",
2191 "_Z11vload_half4jPU3AS2KDh",
2192 "_Z12vloada_half4jPU3AS2KDh",
2193 };
David Neto22f144c2017-06-12 14:26:21 -04002194
2195 for (auto Name : Map) {
2196 // If we find a function with the matching name.
2197 if (auto F = M.getFunction(Name)) {
2198 SmallVector<Instruction *, 4> ToRemoves;
2199
2200 // Walk the users of the function.
2201 for (auto &U : F->uses()) {
2202 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2203 // The index argument from vload_half.
2204 auto Arg0 = CI->getOperand(0);
2205
2206 // The pointer argument from vload_half.
2207 auto Arg1 = CI->getOperand(1);
2208
2209 auto IntTy = Type::getInt32Ty(M.getContext());
2210 auto Int2Ty = VectorType::get(IntTy, 2);
2211 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2212 auto NewPointerTy = PointerType::get(
2213 Int2Ty, Arg1->getType()->getPointerAddressSpace());
2214 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2215
2216 // Cast the half* pointer to int2*.
2217 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
2218
2219 // Index into the correct address of the casted pointer.
2220 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
2221
2222 // Load from the int2* we casted to.
2223 auto Load = new LoadInst(Index, "", CI);
2224
2225 // Extract each element from the loaded int2.
2226 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2227 "", CI);
2228 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2229 "", CI);
2230
2231 // Our intrinsic to unpack a float2 from an int.
2232 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2233
2234 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2235
2236 // Get the lower (x & y) components of our final float4.
2237 auto Lo = CallInst::Create(NewF, X, "", CI);
2238
2239 // Get the higher (z & w) components of our final float4.
2240 auto Hi = CallInst::Create(NewF, Y, "", CI);
2241
2242 Constant *ShuffleMask[4] = {
2243 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2244 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2245
2246 // Combine our two float2's into one float4.
2247 auto Combine = new ShuffleVectorInst(
2248 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2249
2250 CI->replaceAllUsesWith(Combine);
2251
2252 // Lastly, remember to remove the user.
2253 ToRemoves.push_back(CI);
2254 }
2255 }
2256
2257 Changed = !ToRemoves.empty();
2258
2259 // And cleanup the calls we don't use anymore.
2260 for (auto V : ToRemoves) {
2261 V->eraseFromParent();
2262 }
2263
2264 // And remove the function we don't need either too.
2265 F->eraseFromParent();
2266 }
2267 }
2268
2269 return Changed;
2270}
2271
David Neto6ad93232018-06-07 15:42:58 -07002272bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
2273 bool Changed = false;
2274
2275 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2276 //
2277 // %u = load i32 %ptr
2278 // %fxy = call <2 x float> Unpack2xHalf(u)
2279 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2280 const std::vector<const char *> Map = {
2281 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2282 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2283 "_Z20__clspv_vloada_half2jPKj", // private
2284 };
2285
2286 for (auto Name : Map) {
2287 // If we find a function with the matching name.
2288 if (auto F = M.getFunction(Name)) {
2289 SmallVector<Instruction *, 4> ToRemoves;
2290
2291 // Walk the users of the function.
2292 for (auto &U : F->uses()) {
2293 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
2294 auto Index = CI->getOperand(0);
2295 auto Ptr = CI->getOperand(1);
2296
2297 auto IntTy = Type::getInt32Ty(M.getContext());
2298 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2299 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2300
2301 auto IndexedPtr =
2302 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
2303 auto Load = new LoadInst(IndexedPtr, "", CI);
2304
2305 // Our intrinsic to unpack a float2 from an int.
2306 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2307
2308 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2309
2310 // Get our final float2.
2311 auto Result = CallInst::Create(NewF, Load, "", CI);
2312
2313 CI->replaceAllUsesWith(Result);
2314
2315 // Lastly, remember to remove the user.
2316 ToRemoves.push_back(CI);
2317 }
2318 }
2319
2320 Changed = true;
2321
2322 // And cleanup the calls we don't use anymore.
2323 for (auto V : ToRemoves) {
2324 V->eraseFromParent();
2325 }
2326
2327 // And remove the function we don't need either too.
2328 F->eraseFromParent();
2329 }
2330 }
2331
2332 return Changed;
2333}
2334
2335bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
2336 bool Changed = false;
2337
2338 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2339 //
2340 // %u2 = load <2 x i32> %ptr
2341 // %u2xy = extractelement %u2, 0
2342 // %u2zw = extractelement %u2, 1
2343 // %fxy = call <2 x float> Unpack2xHalf(uint)
2344 // %fzw = call <2 x float> Unpack2xHalf(uint)
2345 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
2346 const std::vector<const char *> Map = {
2347 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2348 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2349 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2350 };
2351
2352 for (auto Name : Map) {
2353 // If we find a function with the matching name.
2354 if (auto F = M.getFunction(Name)) {
2355 SmallVector<Instruction *, 4> ToRemoves;
2356
2357 // Walk the users of the function.
2358 for (auto &U : F->uses()) {
2359 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2360 auto Index = CI->getOperand(0);
2361 auto Ptr = CI->getOperand(1);
2362
2363 auto IntTy = Type::getInt32Ty(M.getContext());
2364 auto Int2Ty = VectorType::get(IntTy, 2);
2365 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2366 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2367
2368 auto IndexedPtr =
2369 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
2370 auto Load = new LoadInst(IndexedPtr, "", CI);
2371
2372 // Extract each element from the loaded int2.
2373 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2374 "", CI);
2375 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2376 "", CI);
2377
2378 // Our intrinsic to unpack a float2 from an int.
2379 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2380
2381 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2382
2383 // Get the lower (x & y) components of our final float4.
2384 auto Lo = CallInst::Create(NewF, X, "", CI);
2385
2386 // Get the higher (z & w) components of our final float4.
2387 auto Hi = CallInst::Create(NewF, Y, "", CI);
2388
2389 Constant *ShuffleMask[4] = {
2390 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2391 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2392
2393 // Combine our two float2's into one float4.
2394 auto Combine = new ShuffleVectorInst(
2395 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2396
2397 CI->replaceAllUsesWith(Combine);
2398
2399 // Lastly, remember to remove the user.
2400 ToRemoves.push_back(CI);
2401 }
2402 }
2403
2404 Changed = true;
2405
2406 // And cleanup the calls we don't use anymore.
2407 for (auto V : ToRemoves) {
2408 V->eraseFromParent();
2409 }
2410
2411 // And remove the function we don't need either too.
2412 F->eraseFromParent();
2413 }
2414 }
2415
2416 return Changed;
2417}
2418
David Neto22f144c2017-06-12 14:26:21 -04002419bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
2420 bool Changed = false;
2421
2422 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
2423 "_Z15vstore_half_rtefjPU3AS1Dh",
2424 "_Z15vstore_half_rtzfjPU3AS1Dh"};
2425
2426 for (auto Name : Map) {
2427 // If we find a function with the matching name.
2428 if (auto F = M.getFunction(Name)) {
2429 SmallVector<Instruction *, 4> ToRemoves;
2430
2431 // Walk the users of the function.
2432 for (auto &U : F->uses()) {
2433 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2434 // The value to store.
2435 auto Arg0 = CI->getOperand(0);
2436
2437 // The index argument from vstore_half.
2438 auto Arg1 = CI->getOperand(1);
2439
2440 // The pointer argument from vstore_half.
2441 auto Arg2 = CI->getOperand(2);
2442
David Neto22f144c2017-06-12 14:26:21 -04002443 auto IntTy = Type::getInt32Ty(M.getContext());
2444 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002445 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04002446 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002447
2448 // Our intrinsic to pack a float2 to an int.
2449 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2450
2451 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2452
2453 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04002454 auto TempVec =
2455 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
2456 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002457
2458 // Pack the float2 -> half2 (in an int).
2459 auto X = CallInst::Create(NewF, TempVec, "", CI);
2460
David Neto482550a2018-03-24 05:21:07 -07002461 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002462 auto ShortTy = Type::getInt16Ty(M.getContext());
2463 auto ShortPointerTy = PointerType::get(
2464 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002465
David Neto17852de2017-05-29 17:29:31 -04002466 // Truncate our i32 to an i16.
2467 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002468
David Neto17852de2017-05-29 17:29:31 -04002469 // Cast the half* pointer to short*.
2470 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002471
David Neto17852de2017-05-29 17:29:31 -04002472 // Index into the correct address of the casted pointer.
2473 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002474
David Neto17852de2017-05-29 17:29:31 -04002475 // Store to the int* we casted to.
2476 auto Store = new StoreInst(Trunc, Index, CI);
2477
2478 CI->replaceAllUsesWith(Store);
2479 } else {
2480 // We can only write to 32-bit aligned words.
2481 //
2482 // Assuming base is aligned to 32-bits, replace the equivalent of
2483 // vstore_half(value, index, base)
2484 // with:
2485 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2486 // uint32_t write_to_upper_half = index & 1u;
2487 // uint32_t shift = write_to_upper_half << 4;
2488 //
2489 // // Pack the float value as a half number in bottom 16 bits
2490 // // of an i32.
2491 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2492 //
2493 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2494 // ^ ((packed & 0xffff) << shift)
2495 // // We only need relaxed consistency, but OpenCL 1.2 only has
2496 // // sequentially consistent atomics.
2497 // // TODO(dneto): Use relaxed consistency.
2498 // atomic_xor(target_ptr, xor_value)
2499 auto IntPointerTy = PointerType::get(
2500 IntTy, Arg2->getType()->getPointerAddressSpace());
2501
2502 auto Four = ConstantInt::get(IntTy, 4);
2503 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2504
2505 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2506 // Compute index / 2
2507 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2508 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2509 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2510 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2511 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2512 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2513 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2514
2515 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2516 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2517 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2518
2519 // Generate the call to atomi_xor.
2520 SmallVector<Type *, 5> ParamTypes;
2521 // The pointer type.
2522 ParamTypes.push_back(IntPointerTy);
2523 // The Types for memory scope, semantics, and value.
2524 ParamTypes.push_back(IntTy);
2525 ParamTypes.push_back(IntTy);
2526 ParamTypes.push_back(IntTy);
2527 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2528 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2529
2530 const auto ConstantScopeDevice =
2531 ConstantInt::get(IntTy, spv::ScopeDevice);
2532 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2533 // (SPIR-V Workgroup).
2534 const auto AddrSpaceSemanticsBits =
2535 IntPointerTy->getPointerAddressSpace() == 1
2536 ? spv::MemorySemanticsUniformMemoryMask
2537 : spv::MemorySemanticsWorkgroupMemoryMask;
2538
2539 // We're using relaxed consistency here.
2540 const auto ConstantMemorySemantics =
2541 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2542 AddrSpaceSemanticsBits);
2543
2544 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2545 ConstantMemorySemantics, ValueToXor};
2546 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2547 }
David Neto22f144c2017-06-12 14:26:21 -04002548
2549 // Lastly, remember to remove the user.
2550 ToRemoves.push_back(CI);
2551 }
2552 }
2553
2554 Changed = !ToRemoves.empty();
2555
2556 // And cleanup the calls we don't use anymore.
2557 for (auto V : ToRemoves) {
2558 V->eraseFromParent();
2559 }
2560
2561 // And remove the function we don't need either too.
2562 F->eraseFromParent();
2563 }
2564 }
2565
2566 return Changed;
2567}
2568
2569bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2570 bool Changed = false;
2571
David Netoe2871522018-06-08 11:09:54 -07002572 const std::vector<const char *> Map = {
2573 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2574 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2575 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2576 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2577 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2578 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2579 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2580 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2581 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2582 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2583 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2584 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2585 };
David Neto22f144c2017-06-12 14:26:21 -04002586
2587 for (auto Name : Map) {
2588 // If we find a function with the matching name.
2589 if (auto F = M.getFunction(Name)) {
2590 SmallVector<Instruction *, 4> ToRemoves;
2591
2592 // Walk the users of the function.
2593 for (auto &U : F->uses()) {
2594 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2595 // The value to store.
2596 auto Arg0 = CI->getOperand(0);
2597
2598 // The index argument from vstore_half.
2599 auto Arg1 = CI->getOperand(1);
2600
2601 // The pointer argument from vstore_half.
2602 auto Arg2 = CI->getOperand(2);
2603
2604 auto IntTy = Type::getInt32Ty(M.getContext());
2605 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2606 auto NewPointerTy = PointerType::get(
2607 IntTy, Arg2->getType()->getPointerAddressSpace());
2608 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2609
2610 // Our intrinsic to pack a float2 to an int.
2611 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2612
2613 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2614
2615 // Turn the packed x & y into the final packing.
2616 auto X = CallInst::Create(NewF, Arg0, "", CI);
2617
2618 // Cast the half* pointer to int*.
2619 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2620
2621 // Index into the correct address of the casted pointer.
2622 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2623
2624 // Store to the int* we casted to.
2625 auto Store = new StoreInst(X, Index, CI);
2626
2627 CI->replaceAllUsesWith(Store);
2628
2629 // Lastly, remember to remove the user.
2630 ToRemoves.push_back(CI);
2631 }
2632 }
2633
2634 Changed = !ToRemoves.empty();
2635
2636 // And cleanup the calls we don't use anymore.
2637 for (auto V : ToRemoves) {
2638 V->eraseFromParent();
2639 }
2640
2641 // And remove the function we don't need either too.
2642 F->eraseFromParent();
2643 }
2644 }
2645
2646 return Changed;
2647}
2648
2649bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2650 bool Changed = false;
2651
David Netoe2871522018-06-08 11:09:54 -07002652 const std::vector<const char *> Map = {
2653 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2654 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2655 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2656 "_Z13vstorea_half4Dv4_fjPDh", // private
2657 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2658 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2659 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2660 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2661 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2662 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2663 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2664 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2665 };
David Neto22f144c2017-06-12 14:26:21 -04002666
2667 for (auto Name : Map) {
2668 // If we find a function with the matching name.
2669 if (auto F = M.getFunction(Name)) {
2670 SmallVector<Instruction *, 4> ToRemoves;
2671
2672 // Walk the users of the function.
2673 for (auto &U : F->uses()) {
2674 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2675 // The value to store.
2676 auto Arg0 = CI->getOperand(0);
2677
2678 // The index argument from vstore_half.
2679 auto Arg1 = CI->getOperand(1);
2680
2681 // The pointer argument from vstore_half.
2682 auto Arg2 = CI->getOperand(2);
2683
2684 auto IntTy = Type::getInt32Ty(M.getContext());
2685 auto Int2Ty = VectorType::get(IntTy, 2);
2686 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2687 auto NewPointerTy = PointerType::get(
2688 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2689 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2690
2691 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2692 ConstantInt::get(IntTy, 1)};
2693
2694 // Extract out the x & y components of our to store value.
2695 auto Lo =
2696 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2697 ConstantVector::get(LoShuffleMask), "", CI);
2698
2699 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2700 ConstantInt::get(IntTy, 3)};
2701
2702 // Extract out the z & w components of our to store value.
2703 auto Hi =
2704 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2705 ConstantVector::get(HiShuffleMask), "", CI);
2706
2707 // Our intrinsic to pack a float2 to an int.
2708 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2709
2710 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2711
2712 // Turn the packed x & y into the final component of our int2.
2713 auto X = CallInst::Create(NewF, Lo, "", CI);
2714
2715 // Turn the packed z & w into the final component of our int2.
2716 auto Y = CallInst::Create(NewF, Hi, "", CI);
2717
2718 auto Combine = InsertElementInst::Create(
2719 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2720 Combine = InsertElementInst::Create(
2721 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2722
2723 // Cast the half* pointer to int2*.
2724 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2725
2726 // Index into the correct address of the casted pointer.
2727 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2728
2729 // Store to the int2* we casted to.
2730 auto Store = new StoreInst(Combine, Index, CI);
2731
2732 CI->replaceAllUsesWith(Store);
2733
2734 // Lastly, remember to remove the user.
2735 ToRemoves.push_back(CI);
2736 }
2737 }
2738
2739 Changed = !ToRemoves.empty();
2740
2741 // And cleanup the calls we don't use anymore.
2742 for (auto V : ToRemoves) {
2743 V->eraseFromParent();
2744 }
2745
2746 // And remove the function we don't need either too.
2747 F->eraseFromParent();
2748 }
2749 }
2750
2751 return Changed;
2752}
2753
2754bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2755 bool Changed = false;
2756
2757 const std::map<const char *, const char*> Map = {
2758 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2759 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2760 };
2761
2762 for (auto Pair : Map) {
2763 // If we find a function with the matching name.
2764 if (auto F = M.getFunction(Pair.first)) {
2765 SmallVector<Instruction *, 4> ToRemoves;
2766
2767 // Walk the users of the function.
2768 for (auto &U : F->uses()) {
2769 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2770 // The image.
2771 auto Arg0 = CI->getOperand(0);
2772
2773 // The sampler.
2774 auto Arg1 = CI->getOperand(1);
2775
2776 // The coordinate (integer type that we can't handle).
2777 auto Arg2 = CI->getOperand(2);
2778
2779 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2780
2781 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2782
2783 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2784
2785 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2786
2787 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2788
2789 CI->replaceAllUsesWith(NewCI);
2790
2791 // Lastly, remember to remove the user.
2792 ToRemoves.push_back(CI);
2793 }
2794 }
2795
2796 Changed = !ToRemoves.empty();
2797
2798 // And cleanup the calls we don't use anymore.
2799 for (auto V : ToRemoves) {
2800 V->eraseFromParent();
2801 }
2802
2803 // And remove the function we don't need either too.
2804 F->eraseFromParent();
2805 }
2806 }
2807
2808 return Changed;
2809}
2810
2811bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2812 bool Changed = false;
2813
2814 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002815 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2816 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2817 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2818 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2819 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2820 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002821 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2822 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2823 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2824 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2825 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002826 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002827
2828 for (auto Pair : Map) {
2829 // If we find a function with the matching name.
2830 if (auto F = M.getFunction(Pair.first)) {
2831 SmallVector<Instruction *, 4> ToRemoves;
2832
2833 // Walk the users of the function.
2834 for (auto &U : F->uses()) {
2835 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2836 auto FType = F->getFunctionType();
2837 SmallVector<Type *, 5> ParamTypes;
2838
2839 // The pointer type.
2840 ParamTypes.push_back(FType->getParamType(0));
2841
2842 auto IntTy = Type::getInt32Ty(M.getContext());
2843
2844 // The memory scope type.
2845 ParamTypes.push_back(IntTy);
2846
2847 // The memory semantics type.
2848 ParamTypes.push_back(IntTy);
2849
2850 if (2 < CI->getNumArgOperands()) {
2851 // The unequal memory semantics type.
2852 ParamTypes.push_back(IntTy);
2853
2854 // The value type.
2855 ParamTypes.push_back(FType->getParamType(2));
2856
2857 // The comparator type.
2858 ParamTypes.push_back(FType->getParamType(1));
2859 } else if (1 < CI->getNumArgOperands()) {
2860 // The value type.
2861 ParamTypes.push_back(FType->getParamType(1));
2862 }
2863
2864 auto NewFType =
2865 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2866 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2867
2868 // We need to map the OpenCL constants to the SPIR-V equivalents.
2869 const auto ConstantScopeDevice =
2870 ConstantInt::get(IntTy, spv::ScopeDevice);
2871 const auto ConstantMemorySemantics = ConstantInt::get(
2872 IntTy, spv::MemorySemanticsUniformMemoryMask |
2873 spv::MemorySemanticsSequentiallyConsistentMask);
2874
2875 SmallVector<Value *, 5> Params;
2876
2877 // The pointer.
2878 Params.push_back(CI->getArgOperand(0));
2879
2880 // The memory scope.
2881 Params.push_back(ConstantScopeDevice);
2882
2883 // The memory semantics.
2884 Params.push_back(ConstantMemorySemantics);
2885
2886 if (2 < CI->getNumArgOperands()) {
2887 // The unequal memory semantics.
2888 Params.push_back(ConstantMemorySemantics);
2889
2890 // The value.
2891 Params.push_back(CI->getArgOperand(2));
2892
2893 // The comparator.
2894 Params.push_back(CI->getArgOperand(1));
2895 } else if (1 < CI->getNumArgOperands()) {
2896 // The value.
2897 Params.push_back(CI->getArgOperand(1));
2898 }
2899
2900 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2901
2902 CI->replaceAllUsesWith(NewCI);
2903
2904 // Lastly, remember to remove the user.
2905 ToRemoves.push_back(CI);
2906 }
2907 }
2908
2909 Changed = !ToRemoves.empty();
2910
2911 // And cleanup the calls we don't use anymore.
2912 for (auto V : ToRemoves) {
2913 V->eraseFromParent();
2914 }
2915
2916 // And remove the function we don't need either too.
2917 F->eraseFromParent();
2918 }
2919 }
2920
Neil Henning39672102017-09-29 14:33:13 +01002921 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002922 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2923 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2924 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2925 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2926 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2927 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2928 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2929 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2930 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2931 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2932 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2933 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2934 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2935 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2936 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2937 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002938 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2939 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2940 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2941 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2942 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2943 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2944 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2945 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2946 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2947 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2948 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2949 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2950 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2951 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2952 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2953 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2954
2955 for (auto Pair : Map2) {
2956 // If we find a function with the matching name.
2957 if (auto F = M.getFunction(Pair.first)) {
2958 SmallVector<Instruction *, 4> ToRemoves;
2959
2960 // Walk the users of the function.
2961 for (auto &U : F->uses()) {
2962 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2963 auto AtomicOp = new AtomicRMWInst(
2964 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2965 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2966
2967 CI->replaceAllUsesWith(AtomicOp);
2968
2969 // Lastly, remember to remove the user.
2970 ToRemoves.push_back(CI);
2971 }
2972 }
2973
2974 Changed = !ToRemoves.empty();
2975
2976 // And cleanup the calls we don't use anymore.
2977 for (auto V : ToRemoves) {
2978 V->eraseFromParent();
2979 }
2980
2981 // And remove the function we don't need either too.
2982 F->eraseFromParent();
2983 }
2984 }
2985
David Neto22f144c2017-06-12 14:26:21 -04002986 return Changed;
2987}
2988
2989bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2990 bool Changed = false;
2991
2992 // If we find a function with the matching name.
2993 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2994 SmallVector<Instruction *, 4> ToRemoves;
2995
2996 auto IntTy = Type::getInt32Ty(M.getContext());
2997 auto FloatTy = Type::getFloatTy(M.getContext());
2998
2999 Constant *DownShuffleMask[3] = {
3000 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
3001 ConstantInt::get(IntTy, 2)};
3002
3003 Constant *UpShuffleMask[4] = {
3004 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
3005 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
3006
3007 Constant *FloatVec[3] = {
3008 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
3009 };
3010
3011 // Walk the users of the function.
3012 for (auto &U : F->uses()) {
3013 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3014 auto Vec4Ty = CI->getArgOperand(0)->getType();
3015 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
3016 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
3017 auto Vec3Ty = Arg0->getType();
3018
3019 auto NewFType =
3020 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
3021
3022 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
3023
3024 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
3025
3026 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
3027
3028 CI->replaceAllUsesWith(Result);
3029
3030 // Lastly, remember to remove the user.
3031 ToRemoves.push_back(CI);
3032 }
3033 }
3034
3035 Changed = !ToRemoves.empty();
3036
3037 // And cleanup the calls we don't use anymore.
3038 for (auto V : ToRemoves) {
3039 V->eraseFromParent();
3040 }
3041
3042 // And remove the function we don't need either too.
3043 F->eraseFromParent();
3044 }
3045
3046 return Changed;
3047}
David Neto62653202017-10-16 19:05:18 -04003048
3049bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
3050 bool Changed = false;
3051
3052 // OpenCL's float result = fract(float x, float* ptr)
3053 //
3054 // In the LLVM domain:
3055 //
3056 // %floor_result = call spir_func float @floor(float %x)
3057 // store float %floor_result, float * %ptr
3058 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
3059 // %result = call spir_func float
3060 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
3061 //
3062 // Becomes in the SPIR-V domain, where translations of floor, fmin,
3063 // and clspv.fract occur in the SPIR-V generator pass:
3064 //
3065 // %glsl_ext = OpExtInstImport "GLSL.std.450"
3066 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
3067 // ...
3068 // %floor_result = OpExtInst %float %glsl_ext Floor %x
3069 // OpStore %ptr %floor_result
3070 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
3071 // %fract_result = OpExtInst %float
3072 // %glsl_ext Fmin %fract_intermediate %just_under_1
3073
3074
3075 using std::string;
3076
3077 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
3078 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
3079 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
3080 auto make_quad = [](const char *a, const char *b, const char *c,
3081 const char *d) {
3082 return std::tuple<const char *, const char *, const char *, const char *>(
3083 a, b, c, d);
3084 };
3085 const std::vector<QuadType> Functions = {
3086 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
3087 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
3088 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
3089 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
3090 };
3091
3092 for (auto& quad : Functions) {
3093 const StringRef fract_name(std::get<0>(quad));
3094
3095 // If we find a function with the matching name.
3096 if (auto F = M.getFunction(fract_name)) {
3097 if (F->use_begin() == F->use_end())
3098 continue;
3099
3100 // We have some uses.
3101 Changed = true;
3102
3103 auto& Context = M.getContext();
3104
3105 const StringRef floor_name(std::get<1>(quad));
3106 const StringRef fmin_name(std::get<2>(quad));
3107 const StringRef clspv_fract_name(std::get<3>(quad));
3108
3109 // This is either float or a float vector. All the float-like
3110 // types are this type.
3111 auto result_ty = F->getReturnType();
3112
3113 Function* fmin_fn = M.getFunction(fmin_name);
3114 if (!fmin_fn) {
3115 // Make the fmin function.
3116 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
3117 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04003118 fmin_fn->addFnAttr(Attribute::ReadNone);
3119 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
3120 }
3121
3122 Function* floor_fn = M.getFunction(floor_name);
3123 if (!floor_fn) {
3124 // Make the floor function.
3125 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
3126 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04003127 floor_fn->addFnAttr(Attribute::ReadNone);
3128 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
3129 }
3130
3131 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
3132 if (!clspv_fract_fn) {
3133 // Make the clspv_fract function.
3134 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
3135 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04003136 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
3137 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
3138 }
3139
3140 // Number of significant significand bits, whether represented or not.
3141 unsigned num_significand_bits;
3142 switch (result_ty->getScalarType()->getTypeID()) {
3143 case Type::HalfTyID:
3144 num_significand_bits = 11;
3145 break;
3146 case Type::FloatTyID:
3147 num_significand_bits = 24;
3148 break;
3149 case Type::DoubleTyID:
3150 num_significand_bits = 53;
3151 break;
3152 default:
3153 assert(false && "Unhandled float type when processing fract builtin");
3154 break;
3155 }
3156 // Beware that the disassembler displays this value as
3157 // OpConstant %float 1
3158 // which is not quite right.
3159 const double kJustUnderOneScalar =
3160 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3161
3162 Constant *just_under_one =
3163 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3164 if (result_ty->isVectorTy()) {
3165 just_under_one = ConstantVector::getSplat(
3166 result_ty->getVectorNumElements(), just_under_one);
3167 }
3168
3169 IRBuilder<> Builder(Context);
3170
3171 SmallVector<Instruction *, 4> ToRemoves;
3172
3173 // Walk the users of the function.
3174 for (auto &U : F->uses()) {
3175 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3176
3177 Builder.SetInsertPoint(CI);
3178 auto arg = CI->getArgOperand(0);
3179 auto ptr = CI->getArgOperand(1);
3180
3181 // Compute floor result and store it.
3182 auto floor = Builder.CreateCall(floor_fn, {arg});
3183 Builder.CreateStore(floor, ptr);
3184
3185 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
3186 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
3187
3188 CI->replaceAllUsesWith(fract_result);
3189
3190 // Lastly, remember to remove the user.
3191 ToRemoves.push_back(CI);
3192 }
3193 }
3194
3195 // And cleanup the calls we don't use anymore.
3196 for (auto V : ToRemoves) {
3197 V->eraseFromParent();
3198 }
3199
3200 // And remove the function we don't need either too.
3201 F->eraseFromParent();
3202 }
3203 }
3204
3205 return Changed;
3206}