blob: 02c284bb339d89a86bff2c12a31c80844b2d5559 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
David Neto118188e2018-08-24 11:27:54 -040019#include "llvm/IR/Constants.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000023#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/Pass.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/raw_ostream.h"
27#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040028
David Neto118188e2018-08-24 11:27:54 -040029#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040030
David Neto482550a2018-03-24 05:21:07 -070031#include "clspv/Option.h"
32
David Neto22f144c2017-06-12 14:26:21 -040033using namespace llvm;
34
35#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
36
37namespace {
// NOTE: despite its name, this does NOT count leading zeros. It returns the
// index of the most significant set bit of v, i.e. floor(log2(v)), and 0 when
// v is 0 (so clz(0) == clz(1) == 0).
//
// Callers subtract two such indices to obtain the left-shift amount that
// moves one single-bit flag onto the position of another single-bit flag.
uint32_t clz(uint32_t v) {
  uint32_t HighestBit = 0;
  for (v >>= 1; v != 0; v >>= 1) {
    ++HighestBit;
  }
  return HighestBit;
}
57
58Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
59 if (1 == elements) {
60 return Type::getInt1Ty(C);
61 } else {
62 return VectorType::get(Type::getInt1Ty(C), elements);
63 }
64}
65
66struct ReplaceOpenCLBuiltinPass final : public ModulePass {
67 static char ID;
68 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
69
70 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +000071 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040072 bool replaceRecip(Module &M);
73 bool replaceDivide(Module &M);
74 bool replaceExp10(Module &M);
75 bool replaceLog10(Module &M);
76 bool replaceBarrier(Module &M);
77 bool replaceMemFence(Module &M);
78 bool replaceRelational(Module &M);
79 bool replaceIsInfAndIsNan(Module &M);
80 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +000081 bool replaceUpsample(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +000082 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +000083 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +000084 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040085 bool replaceSignbit(Module &M);
86 bool replaceMadandMad24andMul24(Module &M);
87 bool replaceVloadHalf(Module &M);
88 bool replaceVloadHalf2(Module &M);
89 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -070090 bool replaceClspvVloadaHalf2(Module &M);
91 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040092 bool replaceVstoreHalf(Module &M);
93 bool replaceVstoreHalf2(Module &M);
94 bool replaceVstoreHalf4(Module &M);
95 bool replaceReadImageF(Module &M);
96 bool replaceAtomics(Module &M);
97 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -040098 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -070099 bool replaceVload(Module &M);
100 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400101};
102}
103
104char ReplaceOpenCLBuiltinPass::ID = 0;
105static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
106 "Replace OpenCL Builtins Pass");
107
108namespace clspv {
109ModulePass *createReplaceOpenCLBuiltinPass() {
110 return new ReplaceOpenCLBuiltinPass();
111}
112}
113
114bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
115 bool Changed = false;
116
Kévin Petit2444e9b2018-11-09 14:14:37 +0000117 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400118 Changed |= replaceRecip(M);
119 Changed |= replaceDivide(M);
120 Changed |= replaceExp10(M);
121 Changed |= replaceLog10(M);
122 Changed |= replaceBarrier(M);
123 Changed |= replaceMemFence(M);
124 Changed |= replaceRelational(M);
125 Changed |= replaceIsInfAndIsNan(M);
126 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000127 Changed |= replaceUpsample(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000128 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000129 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000130 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400131 Changed |= replaceSignbit(M);
132 Changed |= replaceMadandMad24andMul24(M);
133 Changed |= replaceVloadHalf(M);
134 Changed |= replaceVloadHalf2(M);
135 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700136 Changed |= replaceClspvVloadaHalf2(M);
137 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400138 Changed |= replaceVstoreHalf(M);
139 Changed |= replaceVstoreHalf2(M);
140 Changed |= replaceVstoreHalf4(M);
141 Changed |= replaceReadImageF(M);
142 Changed |= replaceAtomics(M);
143 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400144 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700145 Changed |= replaceVload(M);
146 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400147
148 return Changed;
149}
150
Kévin Petit2444e9b2018-11-09 14:14:37 +0000151bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
152 bool Changed = false;
153
154 const char *Names[] = {
155 "_Z3abst",
156 "_Z3absDv2_t",
157 "_Z3absDv3_t",
158 "_Z3absDv4_t",
159 "_Z3absj",
160 "_Z3absDv2_j",
161 "_Z3absDv3_j",
162 "_Z3absDv4_j",
163 "_Z3absm",
164 "_Z3absDv2_m",
165 "_Z3absDv3_m",
166 "_Z3absDv4_m",
167 };
168
169 for (auto Name : Names) {
170 // If we find a function with the matching name.
171 if (auto F = M.getFunction(Name)) {
172 SmallVector<Instruction *, 4> ToRemoves;
173
174 // Walk the users of the function.
175 for (auto &U : F->uses()) {
176 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
177 // Abs has one arg.
178 auto Arg = CI->getOperand(0);
179
180 // Use the argument unchanged, we know it's unsigned
181 CI->replaceAllUsesWith(Arg);
182
183 // Lastly, remember to remove the user.
184 ToRemoves.push_back(CI);
185 }
186 }
187
188 Changed = !ToRemoves.empty();
189
190 // And cleanup the calls we don't use anymore.
191 for (auto V : ToRemoves) {
192 V->eraseFromParent();
193 }
194
195 // And remove the function we don't need either too.
196 F->eraseFromParent();
197 }
198 }
199
200 return Changed;
201}
202
David Neto22f144c2017-06-12 14:26:21 -0400203bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
204 bool Changed = false;
205
206 const char *Names[] = {
207 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
208 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
209 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
210 };
211
212 for (auto Name : Names) {
213 // If we find a function with the matching name.
214 if (auto F = M.getFunction(Name)) {
215 SmallVector<Instruction *, 4> ToRemoves;
216
217 // Walk the users of the function.
218 for (auto &U : F->uses()) {
219 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
220 // Recip has one arg.
221 auto Arg = CI->getOperand(0);
222
223 auto Div = BinaryOperator::Create(
224 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
225 CI);
226
227 CI->replaceAllUsesWith(Div);
228
229 // Lastly, remember to remove the user.
230 ToRemoves.push_back(CI);
231 }
232 }
233
234 Changed = !ToRemoves.empty();
235
236 // And cleanup the calls we don't use anymore.
237 for (auto V : ToRemoves) {
238 V->eraseFromParent();
239 }
240
241 // And remove the function we don't need either too.
242 F->eraseFromParent();
243 }
244 }
245
246 return Changed;
247}
248
249bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
250 bool Changed = false;
251
252 const char *Names[] = {
253 "_Z11half_divideff", "_Z13native_divideff",
254 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
255 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
256 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
257 };
258
259 for (auto Name : Names) {
260 // If we find a function with the matching name.
261 if (auto F = M.getFunction(Name)) {
262 SmallVector<Instruction *, 4> ToRemoves;
263
264 // Walk the users of the function.
265 for (auto &U : F->uses()) {
266 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
267 auto Div = BinaryOperator::Create(
268 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
269
270 CI->replaceAllUsesWith(Div);
271
272 // Lastly, remember to remove the user.
273 ToRemoves.push_back(CI);
274 }
275 }
276
277 Changed = !ToRemoves.empty();
278
279 // And cleanup the calls we don't use anymore.
280 for (auto V : ToRemoves) {
281 V->eraseFromParent();
282 }
283
284 // And remove the function we don't need either too.
285 F->eraseFromParent();
286 }
287 }
288
289 return Changed;
290}
291
292bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
293 bool Changed = false;
294
295 const std::map<const char *, const char *> Map = {
296 {"_Z5exp10f", "_Z3expf"},
297 {"_Z10half_exp10f", "_Z8half_expf"},
298 {"_Z12native_exp10f", "_Z10native_expf"},
299 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
300 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
301 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
302 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
303 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
304 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
305 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
306 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
307 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
308
309 for (auto Pair : Map) {
310 // If we find a function with the matching name.
311 if (auto F = M.getFunction(Pair.first)) {
312 SmallVector<Instruction *, 4> ToRemoves;
313
314 // Walk the users of the function.
315 for (auto &U : F->uses()) {
316 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
317 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
318
319 auto Arg = CI->getOperand(0);
320
321 // Constant of the natural log of 10 (ln(10)).
322 const double Ln10 =
323 2.302585092994045684017991454684364207601101488628772976033;
324
325 auto Mul = BinaryOperator::Create(
326 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
327 CI);
328
329 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
330
331 CI->replaceAllUsesWith(NewCI);
332
333 // Lastly, remember to remove the user.
334 ToRemoves.push_back(CI);
335 }
336 }
337
338 Changed = !ToRemoves.empty();
339
340 // And cleanup the calls we don't use anymore.
341 for (auto V : ToRemoves) {
342 V->eraseFromParent();
343 }
344
345 // And remove the function we don't need either too.
346 F->eraseFromParent();
347 }
348 }
349
350 return Changed;
351}
352
353bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
354 bool Changed = false;
355
356 const std::map<const char *, const char *> Map = {
357 {"_Z5log10f", "_Z3logf"},
358 {"_Z10half_log10f", "_Z8half_logf"},
359 {"_Z12native_log10f", "_Z10native_logf"},
360 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
361 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
362 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
363 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
364 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
365 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
366 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
367 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
368 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
369
370 for (auto Pair : Map) {
371 // If we find a function with the matching name.
372 if (auto F = M.getFunction(Pair.first)) {
373 SmallVector<Instruction *, 4> ToRemoves;
374
375 // Walk the users of the function.
376 for (auto &U : F->uses()) {
377 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
378 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
379
380 auto Arg = CI->getOperand(0);
381
382 // Constant of the reciprocal of the natural log of 10 (ln(10)).
383 const double Ln10 =
384 0.434294481903251827651128918916605082294397005803666566114;
385
386 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
387
388 auto Mul = BinaryOperator::Create(
389 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
390 "", CI);
391
392 CI->replaceAllUsesWith(Mul);
393
394 // Lastly, remember to remove the user.
395 ToRemoves.push_back(CI);
396 }
397 }
398
399 Changed = !ToRemoves.empty();
400
401 // And cleanup the calls we don't use anymore.
402 for (auto V : ToRemoves) {
403 V->eraseFromParent();
404 }
405
406 // And remove the function we don't need either too.
407 F->eraseFromParent();
408 }
409 }
410
411 return Changed;
412}
413
414bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
415 bool Changed = false;
416
417 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
418
419 const std::map<const char *, const char *> Map = {
420 {"_Z7barrierj", "__spirv_control_barrier"}};
421
422 for (auto Pair : Map) {
423 // If we find a function with the matching name.
424 if (auto F = M.getFunction(Pair.first)) {
425 SmallVector<Instruction *, 4> ToRemoves;
426
427 // Walk the users of the function.
428 for (auto &U : F->uses()) {
429 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
430 auto FType = F->getFunctionType();
431 SmallVector<Type *, 3> Params;
432 for (unsigned i = 0; i < 3; i++) {
433 Params.push_back(FType->getParamType(0));
434 }
435 auto NewFType =
436 FunctionType::get(FType->getReturnType(), Params, false);
437 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
438
439 auto Arg = CI->getOperand(0);
440
441 // We need to map the OpenCL constants to the SPIR-V equivalents.
442 const auto LocalMemFence =
443 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
444 const auto GlobalMemFence =
445 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
446 const auto ConstantSequentiallyConsistent = ConstantInt::get(
447 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
448 const auto ConstantScopeDevice =
449 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
450 const auto ConstantScopeWorkgroup =
451 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
452
453 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
454 const auto LocalMemFenceMask = BinaryOperator::Create(
455 Instruction::And, LocalMemFence, Arg, "", CI);
456 const auto WorkgroupShiftAmount =
457 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
458 clz(CLK_LOCAL_MEM_FENCE);
459 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
460 Instruction::Shl, LocalMemFenceMask,
461 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
462
463 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
464 const auto GlobalMemFenceMask = BinaryOperator::Create(
465 Instruction::And, GlobalMemFence, Arg, "", CI);
466 const auto UniformShiftAmount =
467 clz(spv::MemorySemanticsUniformMemoryMask) -
468 clz(CLK_GLOBAL_MEM_FENCE);
469 const auto MemorySemanticsUniform = BinaryOperator::Create(
470 Instruction::Shl, GlobalMemFenceMask,
471 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
472
473 // And combine the above together, also adding in
474 // MemorySemanticsSequentiallyConsistentMask.
475 auto MemorySemantics =
476 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
477 ConstantSequentiallyConsistent, "", CI);
478 MemorySemantics = BinaryOperator::Create(
479 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
480
481 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
482 // Device Scope, otherwise Workgroup Scope.
483 const auto Cmp =
484 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
485 GlobalMemFenceMask, GlobalMemFence, "", CI);
486 const auto MemoryScope = SelectInst::Create(
487 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
488
489 // Lastly, the Execution Scope is always Workgroup Scope.
490 const auto ExecutionScope = ConstantScopeWorkgroup;
491
492 auto NewCI = CallInst::Create(
493 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
494
495 CI->replaceAllUsesWith(NewCI);
496
497 // Lastly, remember to remove the user.
498 ToRemoves.push_back(CI);
499 }
500 }
501
502 Changed = !ToRemoves.empty();
503
504 // And cleanup the calls we don't use anymore.
505 for (auto V : ToRemoves) {
506 V->eraseFromParent();
507 }
508
509 // And remove the function we don't need either too.
510 F->eraseFromParent();
511 }
512 }
513
514 return Changed;
515}
516
517bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
518 bool Changed = false;
519
520 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
521
Neil Henning39672102017-09-29 14:33:13 +0100522 using Tuple = std::tuple<const char *, unsigned>;
523 const std::map<const char *, Tuple> Map = {
524 {"_Z9mem_fencej",
525 Tuple("__spirv_memory_barrier",
526 spv::MemorySemanticsSequentiallyConsistentMask)},
527 {"_Z14read_mem_fencej",
528 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
529 {"_Z15write_mem_fencej",
530 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400531
532 for (auto Pair : Map) {
533 // If we find a function with the matching name.
534 if (auto F = M.getFunction(Pair.first)) {
535 SmallVector<Instruction *, 4> ToRemoves;
536
537 // Walk the users of the function.
538 for (auto &U : F->uses()) {
539 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
540 auto FType = F->getFunctionType();
541 SmallVector<Type *, 2> Params;
542 for (unsigned i = 0; i < 2; i++) {
543 Params.push_back(FType->getParamType(0));
544 }
545 auto NewFType =
546 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100547 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400548
549 auto Arg = CI->getOperand(0);
550
551 // We need to map the OpenCL constants to the SPIR-V equivalents.
552 const auto LocalMemFence =
553 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
554 const auto GlobalMemFence =
555 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
556 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100557 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400558 const auto ConstantScopeDevice =
559 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
560
561 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
562 const auto LocalMemFenceMask = BinaryOperator::Create(
563 Instruction::And, LocalMemFence, Arg, "", CI);
564 const auto WorkgroupShiftAmount =
565 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
566 clz(CLK_LOCAL_MEM_FENCE);
567 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
568 Instruction::Shl, LocalMemFenceMask,
569 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
570
571 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
572 const auto GlobalMemFenceMask = BinaryOperator::Create(
573 Instruction::And, GlobalMemFence, Arg, "", CI);
574 const auto UniformShiftAmount =
575 clz(spv::MemorySemanticsUniformMemoryMask) -
576 clz(CLK_GLOBAL_MEM_FENCE);
577 const auto MemorySemanticsUniform = BinaryOperator::Create(
578 Instruction::Shl, GlobalMemFenceMask,
579 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
580
581 // And combine the above together, also adding in
582 // MemorySemanticsSequentiallyConsistentMask.
583 auto MemorySemantics =
584 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
585 ConstantMemorySemantics, "", CI);
586 MemorySemantics = BinaryOperator::Create(
587 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
588
589 // Memory Scope is always device.
590 const auto MemoryScope = ConstantScopeDevice;
591
592 auto NewCI =
593 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
594
595 CI->replaceAllUsesWith(NewCI);
596
597 // Lastly, remember to remove the user.
598 ToRemoves.push_back(CI);
599 }
600 }
601
602 Changed = !ToRemoves.empty();
603
604 // And cleanup the calls we don't use anymore.
605 for (auto V : ToRemoves) {
606 V->eraseFromParent();
607 }
608
609 // And remove the function we don't need either too.
610 F->eraseFromParent();
611 }
612 }
613
614 return Changed;
615}
616
617bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
618 bool Changed = false;
619
620 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
621 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
622 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
623 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
624 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
625 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
626 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
627 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
628 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
629 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
630 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
631 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
632 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
633 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
634 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
635 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
636 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
637 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
638 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
639 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
640 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
641 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
642 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
643 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
644 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
645 };
646
647 for (auto Pair : Map) {
648 // If we find a function with the matching name.
649 if (auto F = M.getFunction(Pair.first)) {
650 SmallVector<Instruction *, 4> ToRemoves;
651
652 // Walk the users of the function.
653 for (auto &U : F->uses()) {
654 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
655 // The predicate to use in the CmpInst.
656 auto Predicate = Pair.second.first;
657
658 // The value to return for true.
659 auto TrueValue =
660 ConstantInt::getSigned(CI->getType(), Pair.second.second);
661
662 // The value to return for false.
663 auto FalseValue = Constant::getNullValue(CI->getType());
664
665 auto Arg1 = CI->getOperand(0);
666 auto Arg2 = CI->getOperand(1);
667
668 const auto Cmp =
669 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
670
671 const auto Select =
672 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
673
674 CI->replaceAllUsesWith(Select);
675
676 // Lastly, remember to remove the user.
677 ToRemoves.push_back(CI);
678 }
679 }
680
681 Changed = !ToRemoves.empty();
682
683 // And cleanup the calls we don't use anymore.
684 for (auto V : ToRemoves) {
685 V->eraseFromParent();
686 }
687
688 // And remove the function we don't need either too.
689 F->eraseFromParent();
690 }
691 }
692
693 return Changed;
694}
695
696bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
697 bool Changed = false;
698
699 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
700 {"_Z5isinff", {"__spirv_isinff", 1}},
701 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
702 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
703 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
704 {"_Z5isnanf", {"__spirv_isnanf", 1}},
705 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
706 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
707 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
708 };
709
710 for (auto Pair : Map) {
711 // If we find a function with the matching name.
712 if (auto F = M.getFunction(Pair.first)) {
713 SmallVector<Instruction *, 4> ToRemoves;
714
715 // Walk the users of the function.
716 for (auto &U : F->uses()) {
717 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
718 const auto CITy = CI->getType();
719
720 // The fake SPIR-V intrinsic to generate.
721 auto SPIRVIntrinsic = Pair.second.first;
722
723 // The value to return for true.
724 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
725
726 // The value to return for false.
727 auto FalseValue = Constant::getNullValue(CITy);
728
729 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
730 M.getContext(),
731 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
732
733 auto NewFType =
734 FunctionType::get(CorrespondingBoolTy,
735 F->getFunctionType()->getParamType(0), false);
736
737 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
738
739 auto Arg = CI->getOperand(0);
740
741 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
742
743 const auto Select =
744 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
745
746 CI->replaceAllUsesWith(Select);
747
748 // Lastly, remember to remove the user.
749 ToRemoves.push_back(CI);
750 }
751 }
752
753 Changed = !ToRemoves.empty();
754
755 // And cleanup the calls we don't use anymore.
756 for (auto V : ToRemoves) {
757 V->eraseFromParent();
758 }
759
760 // And remove the function we don't need either too.
761 F->eraseFromParent();
762 }
763 }
764
765 return Changed;
766}
767
768bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
769 bool Changed = false;
770
771 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000772 // all
773 {"_Z3alls", ""},
774 {"_Z3allDv2_s", "__spirv_allDv2_s"},
775 {"_Z3allDv3_s", "__spirv_allDv3_s"},
776 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400777 {"_Z3alli", ""},
778 {"_Z3allDv2_i", "__spirv_allDv2_i"},
779 {"_Z3allDv3_i", "__spirv_allDv3_i"},
780 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000781 {"_Z3alll", ""},
782 {"_Z3allDv2_l", "__spirv_allDv2_l"},
783 {"_Z3allDv3_l", "__spirv_allDv3_l"},
784 {"_Z3allDv4_l", "__spirv_allDv4_l"},
785
786 // any
787 {"_Z3anys", ""},
788 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
789 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
790 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400791 {"_Z3anyi", ""},
792 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
793 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
794 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000795 {"_Z3anyl", ""},
796 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
797 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
798 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400799 };
800
801 for (auto Pair : Map) {
802 // If we find a function with the matching name.
803 if (auto F = M.getFunction(Pair.first)) {
804 SmallVector<Instruction *, 4> ToRemoves;
805
806 // Walk the users of the function.
807 for (auto &U : F->uses()) {
808 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
809 // The fake SPIR-V intrinsic to generate.
810 auto SPIRVIntrinsic = Pair.second;
811
812 auto Arg = CI->getOperand(0);
813
814 Value *V;
815
Kévin Petitfd27cca2018-10-31 13:00:17 +0000816 // If the argument is a 32-bit int, just use a shift
817 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
818 V = BinaryOperator::Create(Instruction::LShr, Arg,
819 ConstantInt::get(Arg->getType(), 31), "",
820 CI);
821 } else {
David Neto22f144c2017-06-12 14:26:21 -0400822 // The value for zero to compare against.
823 const auto ZeroValue = Constant::getNullValue(Arg->getType());
824
David Neto22f144c2017-06-12 14:26:21 -0400825 // The value to return for true.
826 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
827
828 // The value to return for false.
829 const auto FalseValue = Constant::getNullValue(CI->getType());
830
Kévin Petitfd27cca2018-10-31 13:00:17 +0000831 const auto Cmp = CmpInst::Create(
832 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
833
834 Value* SelectSource;
835
836 // If we have a function to call, call it!
837 if (0 < strlen(SPIRVIntrinsic)) {
838
839 const auto NewFType = FunctionType::get(
840 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
841
842 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
843
844 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
845
846 SelectSource = NewCI;
847
848 } else {
849 SelectSource = Cmp;
850 }
851
852 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400853 }
854
855 CI->replaceAllUsesWith(V);
856
857 // Lastly, remember to remove the user.
858 ToRemoves.push_back(CI);
859 }
860 }
861
862 Changed = !ToRemoves.empty();
863
864 // And cleanup the calls we don't use anymore.
865 for (auto V : ToRemoves) {
866 V->eraseFromParent();
867 }
868
869 // And remove the function we don't need either too.
870 F->eraseFromParent();
871 }
872 }
873
874 return Changed;
875}
876
Kévin Petitbf0036c2019-03-06 13:57:10 +0000877bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
878 bool Changed = false;
879
880 for (auto const &SymVal : M.getValueSymbolTable()) {
881 // Skip symbols whose name doesn't match
882 if (!SymVal.getKey().startswith("_Z8upsample")) {
883 continue;
884 }
885 // Is there a function going by that name?
886 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
887
888 SmallVector<Instruction *, 4> ToRemoves;
889
890 // Walk the users of the function.
891 for (auto &U : F->uses()) {
892 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
893
894 // Get arguments
895 auto HiValue = CI->getOperand(0);
896 auto LoValue = CI->getOperand(1);
897
898 // Don't touch overloads that aren't in OpenCL C
899 auto HiType = HiValue->getType();
900 auto LoType = LoValue->getType();
901
902 if (HiType != LoType) {
903 continue;
904 }
905
906 if (!HiType->isIntOrIntVectorTy()) {
907 continue;
908 }
909
910 if (HiType->getScalarSizeInBits() * 2 !=
911 CI->getType()->getScalarSizeInBits()) {
912 continue;
913 }
914
915 if ((HiType->getScalarSizeInBits() != 8) &&
916 (HiType->getScalarSizeInBits() != 16) &&
917 (HiType->getScalarSizeInBits() != 32)) {
918 continue;
919 }
920
921 if (HiType->isVectorTy()) {
922 if ((HiType->getVectorNumElements() != 2) &&
923 (HiType->getVectorNumElements() != 3) &&
924 (HiType->getVectorNumElements() != 4) &&
925 (HiType->getVectorNumElements() != 8) &&
926 (HiType->getVectorNumElements() != 16)) {
927 continue;
928 }
929 }
930
931 // Convert both operands to the result type
932 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
933 "", CI);
934 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
935 "", CI);
936
937 // Shift high operand
938 auto ShiftAmount = ConstantInt::get(CI->getType(),
939 HiType->getScalarSizeInBits());
940 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
941 ShiftAmount, "", CI);
942
943 // OR both results
944 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
945 "", CI);
946
947 // Replace call with the expression
948 CI->replaceAllUsesWith(V);
949
950 // Lastly, remember to remove the user.
951 ToRemoves.push_back(CI);
952 }
953 }
954
955 Changed = !ToRemoves.empty();
956
957 // And cleanup the calls we don't use anymore.
958 for (auto V : ToRemoves) {
959 V->eraseFromParent();
960 }
961
962 // And remove the function we don't need either too.
963 F->eraseFromParent();
964 }
965 }
966
967 return Changed;
968}
969
Kévin Petitf5b78a22018-10-25 14:32:17 +0000970bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
971 bool Changed = false;
972
973 for (auto const &SymVal : M.getValueSymbolTable()) {
974 // Skip symbols whose name doesn't match
975 if (!SymVal.getKey().startswith("_Z6select")) {
976 continue;
977 }
978 // Is there a function going by that name?
979 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
980
981 SmallVector<Instruction *, 4> ToRemoves;
982
983 // Walk the users of the function.
984 for (auto &U : F->uses()) {
985 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
986
987 // Get arguments
988 auto FalseValue = CI->getOperand(0);
989 auto TrueValue = CI->getOperand(1);
990 auto PredicateValue = CI->getOperand(2);
991
992 // Don't touch overloads that aren't in OpenCL C
993 auto FalseType = FalseValue->getType();
994 auto TrueType = TrueValue->getType();
995 auto PredicateType = PredicateValue->getType();
996
997 if (FalseType != TrueType) {
998 continue;
999 }
1000
1001 if (!PredicateType->isIntOrIntVectorTy()) {
1002 continue;
1003 }
1004
1005 if (!FalseType->isIntOrIntVectorTy() &&
1006 !FalseType->getScalarType()->isFloatingPointTy()) {
1007 continue;
1008 }
1009
1010 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1011 continue;
1012 }
1013
1014 if (FalseType->getScalarSizeInBits() !=
1015 PredicateType->getScalarSizeInBits()) {
1016 continue;
1017 }
1018
1019 if (FalseType->isVectorTy()) {
1020 if (FalseType->getVectorNumElements() !=
1021 PredicateType->getVectorNumElements()) {
1022 continue;
1023 }
1024
1025 if ((FalseType->getVectorNumElements() != 2) &&
1026 (FalseType->getVectorNumElements() != 3) &&
1027 (FalseType->getVectorNumElements() != 4) &&
1028 (FalseType->getVectorNumElements() != 8) &&
1029 (FalseType->getVectorNumElements() != 16)) {
1030 continue;
1031 }
1032 }
1033
1034 // Create constant
1035 const auto ZeroValue = Constant::getNullValue(PredicateType);
1036
1037 // Scalar and vector are to be treated differently
1038 CmpInst::Predicate Pred;
1039 if (PredicateType->isVectorTy()) {
1040 Pred = CmpInst::ICMP_SLT;
1041 } else {
1042 Pred = CmpInst::ICMP_NE;
1043 }
1044
1045 // Create comparison instruction
1046 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1047 ZeroValue, "", CI);
1048
1049 // Create select
1050 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1051
1052 // Replace call with the selection
1053 CI->replaceAllUsesWith(V);
1054
1055 // Lastly, remember to remove the user.
1056 ToRemoves.push_back(CI);
1057 }
1058 }
1059
1060 Changed = !ToRemoves.empty();
1061
1062 // And cleanup the calls we don't use anymore.
1063 for (auto V : ToRemoves) {
1064 V->eraseFromParent();
1065 }
1066
1067 // And remove the function we don't need either too.
1068 F->eraseFromParent();
1069 }
1070 }
1071
1072 return Changed;
1073}
1074
Kévin Petite7d0cce2018-10-31 12:38:56 +00001075bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1076 bool Changed = false;
1077
1078 for (auto const &SymVal : M.getValueSymbolTable()) {
1079 // Skip symbols whose name doesn't match
1080 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1081 continue;
1082 }
1083 // Is there a function going by that name?
1084 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1085
1086 SmallVector<Instruction *, 4> ToRemoves;
1087
1088 // Walk the users of the function.
1089 for (auto &U : F->uses()) {
1090 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1091
1092 if (CI->getNumOperands() != 4) {
1093 continue;
1094 }
1095
1096 // Get arguments
1097 auto FalseValue = CI->getOperand(0);
1098 auto TrueValue = CI->getOperand(1);
1099 auto PredicateValue = CI->getOperand(2);
1100
1101 // Don't touch overloads that aren't in OpenCL C
1102 auto FalseType = FalseValue->getType();
1103 auto TrueType = TrueValue->getType();
1104 auto PredicateType = PredicateValue->getType();
1105
1106 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1107 continue;
1108 }
1109
1110 if (TrueType->isVectorTy()) {
1111 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1112 !TrueType->getScalarType()->isIntegerTy()) {
1113 continue;
1114 }
1115 if ((TrueType->getVectorNumElements() != 2) &&
1116 (TrueType->getVectorNumElements() != 3) &&
1117 (TrueType->getVectorNumElements() != 4) &&
1118 (TrueType->getVectorNumElements() != 8) &&
1119 (TrueType->getVectorNumElements() != 16)) {
1120 continue;
1121 }
1122 }
1123
1124 // Remember the type of the operands
1125 auto OpType = TrueType;
1126
1127 // The actual bit selection will always be done on an integer type,
1128 // declare it here
1129 Type *BitType;
1130
1131 // If the operands are float, then bitcast them to int
1132 if (OpType->getScalarType()->isFloatingPointTy()) {
1133
1134 // First create the new type
1135 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1136 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1137 if (OpType->isVectorTy()) {
1138 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1139 }
1140
1141 // Then bitcast all operands
1142 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1143 BitType, "", CI);
1144 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1145 BitType, "", CI);
1146 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1147
1148 } else {
1149 // The operands have an integer type, use it directly
1150 BitType = OpType;
1151 }
1152
1153 // All the operands are now always integers
1154 // implement as (c & b) | (~c & a)
1155
1156 // Create our negated predicate value
1157 auto AllOnes = Constant::getAllOnesValue(BitType);
1158 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1159 PredicateValue,
1160 AllOnes, "", CI);
1161
1162 // Then put everything together
1163 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1164 NotPredicateValue,
1165 FalseValue, "", CI);
1166 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1167 PredicateValue,
1168 TrueValue, "", CI);
1169
1170 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1171 BitsTrue, "", CI);
1172
1173 // If we were dealing with a floating point type, we must bitcast
1174 // the result back to that
1175 if (OpType->getScalarType()->isFloatingPointTy()) {
1176 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1177 }
1178
1179 // Replace call with our new code
1180 CI->replaceAllUsesWith(V);
1181
1182 // Lastly, remember to remove the user.
1183 ToRemoves.push_back(CI);
1184 }
1185 }
1186
1187 Changed = !ToRemoves.empty();
1188
1189 // And cleanup the calls we don't use anymore.
1190 for (auto V : ToRemoves) {
1191 V->eraseFromParent();
1192 }
1193
1194 // And remove the function we don't need either too.
1195 F->eraseFromParent();
1196 }
1197 }
1198
1199 return Changed;
1200}
1201
Kévin Petit6b0a9532018-10-30 20:00:39 +00001202bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1203 bool Changed = false;
1204
1205 const std::map<const char *, const char *> Map = {
1206 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1207 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1208 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1209 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1210 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1211 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1212 };
1213
1214 for (auto Pair : Map) {
1215 // If we find a function with the matching name.
1216 if (auto F = M.getFunction(Pair.first)) {
1217 SmallVector<Instruction *, 4> ToRemoves;
1218
1219 // Walk the users of the function.
1220 for (auto &U : F->uses()) {
1221 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1222
1223 auto ReplacementFn = Pair.second;
1224
1225 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1226 Value *VectorArg;
1227
1228 // First figure out which function we're dealing with
1229 if (F->getName().startswith("_Z10smoothstep")) {
1230 ArgsToSplat.push_back(CI->getOperand(1));
1231 VectorArg = CI->getOperand(2);
1232 } else {
1233 VectorArg = CI->getOperand(1);
1234 }
1235
1236 // Splat arguments that need to be
1237 SmallVector<Value*, 2> SplatArgs;
1238 auto VecType = VectorArg->getType();
1239
1240 for (auto arg : ArgsToSplat) {
1241 Value* NewVectorArg = UndefValue::get(VecType);
1242 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1243 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1244 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1245 }
1246 SplatArgs.push_back(NewVectorArg);
1247 }
1248
1249 // Replace the call with the vector/vector flavour
1250 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1251 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1252
1253 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1254
1255 SmallVector<Value*, 3> NewArgs;
1256 for (auto arg : SplatArgs) {
1257 NewArgs.push_back(arg);
1258 }
1259 NewArgs.push_back(VectorArg);
1260
1261 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1262
1263 CI->replaceAllUsesWith(NewCI);
1264
1265 // Lastly, remember to remove the user.
1266 ToRemoves.push_back(CI);
1267 }
1268 }
1269
1270 Changed = !ToRemoves.empty();
1271
1272 // And cleanup the calls we don't use anymore.
1273 for (auto V : ToRemoves) {
1274 V->eraseFromParent();
1275 }
1276
1277 // And remove the function we don't need either too.
1278 F->eraseFromParent();
1279 }
1280 }
1281
1282 return Changed;
1283}
1284
David Neto22f144c2017-06-12 14:26:21 -04001285bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1286 bool Changed = false;
1287
1288 const std::map<const char *, Instruction::BinaryOps> Map = {
1289 {"_Z7signbitf", Instruction::LShr},
1290 {"_Z7signbitDv2_f", Instruction::AShr},
1291 {"_Z7signbitDv3_f", Instruction::AShr},
1292 {"_Z7signbitDv4_f", Instruction::AShr},
1293 };
1294
1295 for (auto Pair : Map) {
1296 // If we find a function with the matching name.
1297 if (auto F = M.getFunction(Pair.first)) {
1298 SmallVector<Instruction *, 4> ToRemoves;
1299
1300 // Walk the users of the function.
1301 for (auto &U : F->uses()) {
1302 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1303 auto Arg = CI->getOperand(0);
1304
1305 auto Bitcast =
1306 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1307
1308 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1309 ConstantInt::get(CI->getType(), 31),
1310 "", CI);
1311
1312 CI->replaceAllUsesWith(Shr);
1313
1314 // Lastly, remember to remove the user.
1315 ToRemoves.push_back(CI);
1316 }
1317 }
1318
1319 Changed = !ToRemoves.empty();
1320
1321 // And cleanup the calls we don't use anymore.
1322 for (auto V : ToRemoves) {
1323 V->eraseFromParent();
1324 }
1325
1326 // And remove the function we don't need either too.
1327 F->eraseFromParent();
1328 }
1329 }
1330
1331 return Changed;
1332}
1333
1334bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1335 bool Changed = false;
1336
1337 const std::map<const char *,
1338 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1339 Map = {
1340 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1341 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1342 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1343 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1344 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1345 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1346 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1347 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1348 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1349 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1350 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1351 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1352 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1353 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1354 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1355 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1356 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1357 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1358 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1359 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1360 };
1361
1362 for (auto Pair : Map) {
1363 // If we find a function with the matching name.
1364 if (auto F = M.getFunction(Pair.first)) {
1365 SmallVector<Instruction *, 4> ToRemoves;
1366
1367 // Walk the users of the function.
1368 for (auto &U : F->uses()) {
1369 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1370 // The multiply instruction to use.
1371 auto MulInst = Pair.second.first;
1372
1373 // The add instruction to use.
1374 auto AddInst = Pair.second.second;
1375
1376 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1377
1378 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1379 CI->getArgOperand(1), "", CI);
1380
1381 if (Instruction::BinaryOpsEnd != AddInst) {
1382 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1383 CI);
1384 }
1385
1386 CI->replaceAllUsesWith(I);
1387
1388 // Lastly, remember to remove the user.
1389 ToRemoves.push_back(CI);
1390 }
1391 }
1392
1393 Changed = !ToRemoves.empty();
1394
1395 // And cleanup the calls we don't use anymore.
1396 for (auto V : ToRemoves) {
1397 V->eraseFromParent();
1398 }
1399
1400 // And remove the function we don't need either too.
1401 F->eraseFromParent();
1402 }
1403 }
1404
1405 return Changed;
1406}
1407
Derek Chowcfd368b2017-10-19 20:58:45 -07001408bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1409 bool Changed = false;
1410
1411 struct VectorStoreOps {
1412 const char* name;
1413 int n;
1414 Type* (*get_scalar_type_function)(LLVMContext&);
1415 } vector_store_ops[] = {
1416 // TODO(derekjchow): Expand this list.
1417 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1418 };
1419
David Neto544fffc2017-11-16 18:35:14 -05001420 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001421 auto Name = Op.name;
1422 auto N = Op.n;
1423 auto TypeFn = Op.get_scalar_type_function;
1424 if (auto F = M.getFunction(Name)) {
1425 SmallVector<Instruction *, 4> ToRemoves;
1426
1427 // Walk the users of the function.
1428 for (auto &U : F->uses()) {
1429 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1430 // The value argument from vstoren.
1431 auto Arg0 = CI->getOperand(0);
1432
1433 // The index argument from vstoren.
1434 auto Arg1 = CI->getOperand(1);
1435
1436 // The pointer argument from vstoren.
1437 auto Arg2 = CI->getOperand(2);
1438
1439 // Get types.
1440 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1441 auto ScalarNPointerTy = PointerType::get(
1442 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1443
1444 // Cast to scalarn
1445 auto Cast = CastInst::CreatePointerCast(
1446 Arg2, ScalarNPointerTy, "", CI);
1447 // Index to correct address
1448 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1449 // Store
1450 auto Store = new StoreInst(Arg0, Index, CI);
1451
1452 CI->replaceAllUsesWith(Store);
1453 ToRemoves.push_back(CI);
1454 }
1455 }
1456
1457 Changed = !ToRemoves.empty();
1458
1459 // And cleanup the calls we don't use anymore.
1460 for (auto V : ToRemoves) {
1461 V->eraseFromParent();
1462 }
1463
1464 // And remove the function we don't need either too.
1465 F->eraseFromParent();
1466 }
1467 }
1468
1469 return Changed;
1470}
1471
1472bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1473 bool Changed = false;
1474
1475 struct VectorLoadOps {
1476 const char* name;
1477 int n;
1478 Type* (*get_scalar_type_function)(LLVMContext&);
1479 } vector_load_ops[] = {
1480 // TODO(derekjchow): Expand this list.
1481 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1482 };
1483
David Neto544fffc2017-11-16 18:35:14 -05001484 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001485 auto Name = Op.name;
1486 auto N = Op.n;
1487 auto TypeFn = Op.get_scalar_type_function;
1488 // If we find a function with the matching name.
1489 if (auto F = M.getFunction(Name)) {
1490 SmallVector<Instruction *, 4> ToRemoves;
1491
1492 // Walk the users of the function.
1493 for (auto &U : F->uses()) {
1494 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1495 // The index argument from vloadn.
1496 auto Arg0 = CI->getOperand(0);
1497
1498 // The pointer argument from vloadn.
1499 auto Arg1 = CI->getOperand(1);
1500
1501 // Get types.
1502 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1503 auto ScalarNPointerTy = PointerType::get(
1504 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1505
1506 // Cast to scalarn
1507 auto Cast = CastInst::CreatePointerCast(
1508 Arg1, ScalarNPointerTy, "", CI);
1509 // Index to correct address
1510 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1511 // Load
1512 auto Load = new LoadInst(Index, "", CI);
1513
1514 CI->replaceAllUsesWith(Load);
1515 ToRemoves.push_back(CI);
1516 }
1517 }
1518
1519 Changed = !ToRemoves.empty();
1520
1521 // And cleanup the calls we don't use anymore.
1522 for (auto V : ToRemoves) {
1523 V->eraseFromParent();
1524 }
1525
1526 // And remove the function we don't need either too.
1527 F->eraseFromParent();
1528
1529 }
1530 }
1531
1532 return Changed;
1533}
1534
David Neto22f144c2017-06-12 14:26:21 -04001535bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1536 bool Changed = false;
1537
1538 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
1539 "_Z10vload_halfjPU3AS2KDh"};
1540
1541 for (auto Name : Map) {
1542 // If we find a function with the matching name.
1543 if (auto F = M.getFunction(Name)) {
1544 SmallVector<Instruction *, 4> ToRemoves;
1545
1546 // Walk the users of the function.
1547 for (auto &U : F->uses()) {
1548 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1549 // The index argument from vload_half.
1550 auto Arg0 = CI->getOperand(0);
1551
1552 // The pointer argument from vload_half.
1553 auto Arg1 = CI->getOperand(1);
1554
David Neto22f144c2017-06-12 14:26:21 -04001555 auto IntTy = Type::getInt32Ty(M.getContext());
1556 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001557 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1558
David Neto22f144c2017-06-12 14:26:21 -04001559 // Our intrinsic to unpack a float2 from an int.
1560 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1561
1562 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1563
David Neto482550a2018-03-24 05:21:07 -07001564 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04001565 auto ShortTy = Type::getInt16Ty(M.getContext());
1566 auto ShortPointerTy = PointerType::get(
1567 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001568
David Netoac825b82017-05-30 12:49:01 -04001569 // Cast the half* pointer to short*.
1570 auto Cast =
1571 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001572
David Netoac825b82017-05-30 12:49:01 -04001573 // Index into the correct address of the casted pointer.
1574 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1575
1576 // Load from the short* we casted to.
1577 auto Load = new LoadInst(Index, "", CI);
1578
1579 // ZExt the short -> int.
1580 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1581
1582 // Get our float2.
1583 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1584
1585 // Extract out the bottom element which is our float result.
1586 auto Extract = ExtractElementInst::Create(
1587 Call, ConstantInt::get(IntTy, 0), "", CI);
1588
1589 CI->replaceAllUsesWith(Extract);
1590 } else {
1591 // Assume the pointer argument points to storage aligned to 32bits
1592 // or more.
1593 // TODO(dneto): Do more analysis to make sure this is true?
1594 //
1595 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1596 // with:
1597 //
1598 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1599 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1600 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1601 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1602 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1603 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1604 // x float> %converted, %index_is_odd32
1605
1606 auto IntPointerTy = PointerType::get(
1607 IntTy, Arg1->getType()->getPointerAddressSpace());
1608
David Neto973e6a82017-05-30 13:48:18 -04001609 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04001610 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04001611 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04001612 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1613
1614 auto One = ConstantInt::get(IntTy, 1);
1615 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1616 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1617
1618 // Index into the correct address of the casted pointer.
1619 auto Ptr =
1620 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1621
1622 // Load from the int* we casted to.
1623 auto Load = new LoadInst(Ptr, "", CI);
1624
1625 // Get our float2.
1626 auto Call = CallInst::Create(NewF, Load, "", CI);
1627
1628 // Extract out the float result, where the element number is
1629 // determined by whether the original index was even or odd.
1630 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1631
1632 CI->replaceAllUsesWith(Extract);
1633 }
David Neto22f144c2017-06-12 14:26:21 -04001634
1635 // Lastly, remember to remove the user.
1636 ToRemoves.push_back(CI);
1637 }
1638 }
1639
1640 Changed = !ToRemoves.empty();
1641
1642 // And cleanup the calls we don't use anymore.
1643 for (auto V : ToRemoves) {
1644 V->eraseFromParent();
1645 }
1646
1647 // And remove the function we don't need either too.
1648 F->eraseFromParent();
1649 }
1650 }
1651
1652 return Changed;
1653}
1654
1655bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
1656 bool Changed = false;
1657
David Neto556c7e62018-06-08 13:45:55 -07001658 const std::vector<const char *> Map = {
1659 "_Z11vload_half2jPU3AS1KDh",
1660 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
1661 "_Z11vload_half2jPU3AS2KDh",
1662 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
1663 };
David Neto22f144c2017-06-12 14:26:21 -04001664
1665 for (auto Name : Map) {
1666 // If we find a function with the matching name.
1667 if (auto F = M.getFunction(Name)) {
1668 SmallVector<Instruction *, 4> ToRemoves;
1669
1670 // Walk the users of the function.
1671 for (auto &U : F->uses()) {
1672 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1673 // The index argument from vload_half.
1674 auto Arg0 = CI->getOperand(0);
1675
1676 // The pointer argument from vload_half.
1677 auto Arg1 = CI->getOperand(1);
1678
1679 auto IntTy = Type::getInt32Ty(M.getContext());
1680 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1681 auto NewPointerTy = PointerType::get(
1682 IntTy, Arg1->getType()->getPointerAddressSpace());
1683 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1684
1685 // Cast the half* pointer to int*.
1686 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1687
1688 // Index into the correct address of the casted pointer.
1689 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
1690
1691 // Load from the int* we casted to.
1692 auto Load = new LoadInst(Index, "", CI);
1693
1694 // Our intrinsic to unpack a float2 from an int.
1695 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1696
1697 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1698
1699 // Get our float2.
1700 auto Call = CallInst::Create(NewF, Load, "", CI);
1701
1702 CI->replaceAllUsesWith(Call);
1703
1704 // Lastly, remember to remove the user.
1705 ToRemoves.push_back(CI);
1706 }
1707 }
1708
1709 Changed = !ToRemoves.empty();
1710
1711 // And cleanup the calls we don't use anymore.
1712 for (auto V : ToRemoves) {
1713 V->eraseFromParent();
1714 }
1715
1716 // And remove the function we don't need either too.
1717 F->eraseFromParent();
1718 }
1719 }
1720
1721 return Changed;
1722}
1723
1724bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
1725 bool Changed = false;
1726
David Neto556c7e62018-06-08 13:45:55 -07001727 const std::vector<const char *> Map = {
1728 "_Z11vload_half4jPU3AS1KDh",
1729 "_Z12vloada_half4jPU3AS1KDh",
1730 "_Z11vload_half4jPU3AS2KDh",
1731 "_Z12vloada_half4jPU3AS2KDh",
1732 };
David Neto22f144c2017-06-12 14:26:21 -04001733
1734 for (auto Name : Map) {
1735 // If we find a function with the matching name.
1736 if (auto F = M.getFunction(Name)) {
1737 SmallVector<Instruction *, 4> ToRemoves;
1738
1739 // Walk the users of the function.
1740 for (auto &U : F->uses()) {
1741 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1742 // The index argument from vload_half.
1743 auto Arg0 = CI->getOperand(0);
1744
1745 // The pointer argument from vload_half.
1746 auto Arg1 = CI->getOperand(1);
1747
1748 auto IntTy = Type::getInt32Ty(M.getContext());
1749 auto Int2Ty = VectorType::get(IntTy, 2);
1750 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1751 auto NewPointerTy = PointerType::get(
1752 Int2Ty, Arg1->getType()->getPointerAddressSpace());
1753 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1754
1755 // Cast the half* pointer to int2*.
1756 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1757
1758 // Index into the correct address of the casted pointer.
1759 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
1760
1761 // Load from the int2* we casted to.
1762 auto Load = new LoadInst(Index, "", CI);
1763
1764 // Extract each element from the loaded int2.
1765 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1766 "", CI);
1767 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1768 "", CI);
1769
1770 // Our intrinsic to unpack a float2 from an int.
1771 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1772
1773 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1774
1775 // Get the lower (x & y) components of our final float4.
1776 auto Lo = CallInst::Create(NewF, X, "", CI);
1777
1778 // Get the higher (z & w) components of our final float4.
1779 auto Hi = CallInst::Create(NewF, Y, "", CI);
1780
1781 Constant *ShuffleMask[4] = {
1782 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1783 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1784
1785 // Combine our two float2's into one float4.
1786 auto Combine = new ShuffleVectorInst(
1787 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1788
1789 CI->replaceAllUsesWith(Combine);
1790
1791 // Lastly, remember to remove the user.
1792 ToRemoves.push_back(CI);
1793 }
1794 }
1795
1796 Changed = !ToRemoves.empty();
1797
1798 // And cleanup the calls we don't use anymore.
1799 for (auto V : ToRemoves) {
1800 V->eraseFromParent();
1801 }
1802
1803 // And remove the function we don't need either too.
1804 F->eraseFromParent();
1805 }
1806 }
1807
1808 return Changed;
1809}
1810
David Neto6ad93232018-06-07 15:42:58 -07001811bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
1812 bool Changed = false;
1813
1814 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1815 //
1816 // %u = load i32 %ptr
1817 // %fxy = call <2 x float> Unpack2xHalf(u)
1818 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1819 const std::vector<const char *> Map = {
1820 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
1821 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
1822 "_Z20__clspv_vloada_half2jPKj", // private
1823 };
1824
1825 for (auto Name : Map) {
1826 // If we find a function with the matching name.
1827 if (auto F = M.getFunction(Name)) {
1828 SmallVector<Instruction *, 4> ToRemoves;
1829
1830 // Walk the users of the function.
1831 for (auto &U : F->uses()) {
1832 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
1833 auto Index = CI->getOperand(0);
1834 auto Ptr = CI->getOperand(1);
1835
1836 auto IntTy = Type::getInt32Ty(M.getContext());
1837 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1838 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1839
1840 auto IndexedPtr =
1841 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
1842 auto Load = new LoadInst(IndexedPtr, "", CI);
1843
1844 // Our intrinsic to unpack a float2 from an int.
1845 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1846
1847 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1848
1849 // Get our final float2.
1850 auto Result = CallInst::Create(NewF, Load, "", CI);
1851
1852 CI->replaceAllUsesWith(Result);
1853
1854 // Lastly, remember to remove the user.
1855 ToRemoves.push_back(CI);
1856 }
1857 }
1858
1859 Changed = true;
1860
1861 // And cleanup the calls we don't use anymore.
1862 for (auto V : ToRemoves) {
1863 V->eraseFromParent();
1864 }
1865
1866 // And remove the function we don't need either too.
1867 F->eraseFromParent();
1868 }
1869 }
1870
1871 return Changed;
1872}
1873
1874bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
1875 bool Changed = false;
1876
1877 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1878 //
1879 // %u2 = load <2 x i32> %ptr
1880 // %u2xy = extractelement %u2, 0
1881 // %u2zw = extractelement %u2, 1
1882 // %fxy = call <2 x float> Unpack2xHalf(uint)
1883 // %fzw = call <2 x float> Unpack2xHalf(uint)
1884 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1885 const std::vector<const char *> Map = {
1886 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
1887 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
1888 "_Z20__clspv_vloada_half4jPKDv2_j", // private
1889 };
1890
1891 for (auto Name : Map) {
1892 // If we find a function with the matching name.
1893 if (auto F = M.getFunction(Name)) {
1894 SmallVector<Instruction *, 4> ToRemoves;
1895
1896 // Walk the users of the function.
1897 for (auto &U : F->uses()) {
1898 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1899 auto Index = CI->getOperand(0);
1900 auto Ptr = CI->getOperand(1);
1901
1902 auto IntTy = Type::getInt32Ty(M.getContext());
1903 auto Int2Ty = VectorType::get(IntTy, 2);
1904 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1905 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1906
1907 auto IndexedPtr =
1908 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
1909 auto Load = new LoadInst(IndexedPtr, "", CI);
1910
1911 // Extract each element from the loaded int2.
1912 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1913 "", CI);
1914 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1915 "", CI);
1916
1917 // Our intrinsic to unpack a float2 from an int.
1918 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1919
1920 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1921
1922 // Get the lower (x & y) components of our final float4.
1923 auto Lo = CallInst::Create(NewF, X, "", CI);
1924
1925 // Get the higher (z & w) components of our final float4.
1926 auto Hi = CallInst::Create(NewF, Y, "", CI);
1927
1928 Constant *ShuffleMask[4] = {
1929 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1930 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1931
1932 // Combine our two float2's into one float4.
1933 auto Combine = new ShuffleVectorInst(
1934 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1935
1936 CI->replaceAllUsesWith(Combine);
1937
1938 // Lastly, remember to remove the user.
1939 ToRemoves.push_back(CI);
1940 }
1941 }
1942
1943 Changed = true;
1944
1945 // And cleanup the calls we don't use anymore.
1946 for (auto V : ToRemoves) {
1947 V->eraseFromParent();
1948 }
1949
1950 // And remove the function we don't need either too.
1951 F->eraseFromParent();
1952 }
1953 }
1954
1955 return Changed;
1956}
1957
David Neto22f144c2017-06-12 14:26:21 -04001958bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
1959 bool Changed = false;
1960
1961 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
1962 "_Z15vstore_half_rtefjPU3AS1Dh",
1963 "_Z15vstore_half_rtzfjPU3AS1Dh"};
1964
1965 for (auto Name : Map) {
1966 // If we find a function with the matching name.
1967 if (auto F = M.getFunction(Name)) {
1968 SmallVector<Instruction *, 4> ToRemoves;
1969
1970 // Walk the users of the function.
1971 for (auto &U : F->uses()) {
1972 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1973 // The value to store.
1974 auto Arg0 = CI->getOperand(0);
1975
1976 // The index argument from vstore_half.
1977 auto Arg1 = CI->getOperand(1);
1978
1979 // The pointer argument from vstore_half.
1980 auto Arg2 = CI->getOperand(2);
1981
David Neto22f144c2017-06-12 14:26:21 -04001982 auto IntTy = Type::getInt32Ty(M.getContext());
1983 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001984 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04001985 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04001986
1987 // Our intrinsic to pack a float2 to an int.
1988 auto SPIRVIntrinsic = "spirv.pack.v2f16";
1989
1990 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1991
1992 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04001993 auto TempVec =
1994 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
1995 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001996
1997 // Pack the float2 -> half2 (in an int).
1998 auto X = CallInst::Create(NewF, TempVec, "", CI);
1999
David Neto482550a2018-03-24 05:21:07 -07002000 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002001 auto ShortTy = Type::getInt16Ty(M.getContext());
2002 auto ShortPointerTy = PointerType::get(
2003 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002004
David Neto17852de2017-05-29 17:29:31 -04002005 // Truncate our i32 to an i16.
2006 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002007
David Neto17852de2017-05-29 17:29:31 -04002008 // Cast the half* pointer to short*.
2009 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002010
David Neto17852de2017-05-29 17:29:31 -04002011 // Index into the correct address of the casted pointer.
2012 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002013
David Neto17852de2017-05-29 17:29:31 -04002014 // Store to the int* we casted to.
2015 auto Store = new StoreInst(Trunc, Index, CI);
2016
2017 CI->replaceAllUsesWith(Store);
2018 } else {
2019 // We can only write to 32-bit aligned words.
2020 //
2021 // Assuming base is aligned to 32-bits, replace the equivalent of
2022 // vstore_half(value, index, base)
2023 // with:
2024 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2025 // uint32_t write_to_upper_half = index & 1u;
2026 // uint32_t shift = write_to_upper_half << 4;
2027 //
2028 // // Pack the float value as a half number in bottom 16 bits
2029 // // of an i32.
2030 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2031 //
2032 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2033 // ^ ((packed & 0xffff) << shift)
2034 // // We only need relaxed consistency, but OpenCL 1.2 only has
2035 // // sequentially consistent atomics.
2036 // // TODO(dneto): Use relaxed consistency.
2037 // atomic_xor(target_ptr, xor_value)
2038 auto IntPointerTy = PointerType::get(
2039 IntTy, Arg2->getType()->getPointerAddressSpace());
2040
2041 auto Four = ConstantInt::get(IntTy, 4);
2042 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2043
2044 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2045 // Compute index / 2
2046 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2047 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2048 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2049 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2050 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2051 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2052 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2053
2054 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2055 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2056 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2057
2058 // Generate the call to atomi_xor.
2059 SmallVector<Type *, 5> ParamTypes;
2060 // The pointer type.
2061 ParamTypes.push_back(IntPointerTy);
2062 // The Types for memory scope, semantics, and value.
2063 ParamTypes.push_back(IntTy);
2064 ParamTypes.push_back(IntTy);
2065 ParamTypes.push_back(IntTy);
2066 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2067 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2068
2069 const auto ConstantScopeDevice =
2070 ConstantInt::get(IntTy, spv::ScopeDevice);
2071 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2072 // (SPIR-V Workgroup).
2073 const auto AddrSpaceSemanticsBits =
2074 IntPointerTy->getPointerAddressSpace() == 1
2075 ? spv::MemorySemanticsUniformMemoryMask
2076 : spv::MemorySemanticsWorkgroupMemoryMask;
2077
2078 // We're using relaxed consistency here.
2079 const auto ConstantMemorySemantics =
2080 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2081 AddrSpaceSemanticsBits);
2082
2083 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2084 ConstantMemorySemantics, ValueToXor};
2085 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2086 }
David Neto22f144c2017-06-12 14:26:21 -04002087
2088 // Lastly, remember to remove the user.
2089 ToRemoves.push_back(CI);
2090 }
2091 }
2092
2093 Changed = !ToRemoves.empty();
2094
2095 // And cleanup the calls we don't use anymore.
2096 for (auto V : ToRemoves) {
2097 V->eraseFromParent();
2098 }
2099
2100 // And remove the function we don't need either too.
2101 F->eraseFromParent();
2102 }
2103 }
2104
2105 return Changed;
2106}
2107
2108bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2109 bool Changed = false;
2110
David Netoe2871522018-06-08 11:09:54 -07002111 const std::vector<const char *> Map = {
2112 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2113 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2114 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2115 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2116 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2117 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2118 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2119 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2120 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2121 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2122 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2123 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2124 };
David Neto22f144c2017-06-12 14:26:21 -04002125
2126 for (auto Name : Map) {
2127 // If we find a function with the matching name.
2128 if (auto F = M.getFunction(Name)) {
2129 SmallVector<Instruction *, 4> ToRemoves;
2130
2131 // Walk the users of the function.
2132 for (auto &U : F->uses()) {
2133 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2134 // The value to store.
2135 auto Arg0 = CI->getOperand(0);
2136
2137 // The index argument from vstore_half.
2138 auto Arg1 = CI->getOperand(1);
2139
2140 // The pointer argument from vstore_half.
2141 auto Arg2 = CI->getOperand(2);
2142
2143 auto IntTy = Type::getInt32Ty(M.getContext());
2144 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2145 auto NewPointerTy = PointerType::get(
2146 IntTy, Arg2->getType()->getPointerAddressSpace());
2147 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2148
2149 // Our intrinsic to pack a float2 to an int.
2150 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2151
2152 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2153
2154 // Turn the packed x & y into the final packing.
2155 auto X = CallInst::Create(NewF, Arg0, "", CI);
2156
2157 // Cast the half* pointer to int*.
2158 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2159
2160 // Index into the correct address of the casted pointer.
2161 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2162
2163 // Store to the int* we casted to.
2164 auto Store = new StoreInst(X, Index, CI);
2165
2166 CI->replaceAllUsesWith(Store);
2167
2168 // Lastly, remember to remove the user.
2169 ToRemoves.push_back(CI);
2170 }
2171 }
2172
2173 Changed = !ToRemoves.empty();
2174
2175 // And cleanup the calls we don't use anymore.
2176 for (auto V : ToRemoves) {
2177 V->eraseFromParent();
2178 }
2179
2180 // And remove the function we don't need either too.
2181 F->eraseFromParent();
2182 }
2183 }
2184
2185 return Changed;
2186}
2187
2188bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2189 bool Changed = false;
2190
David Netoe2871522018-06-08 11:09:54 -07002191 const std::vector<const char *> Map = {
2192 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2193 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2194 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2195 "_Z13vstorea_half4Dv4_fjPDh", // private
2196 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2197 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2198 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2199 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2200 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2201 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2202 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2203 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2204 };
David Neto22f144c2017-06-12 14:26:21 -04002205
2206 for (auto Name : Map) {
2207 // If we find a function with the matching name.
2208 if (auto F = M.getFunction(Name)) {
2209 SmallVector<Instruction *, 4> ToRemoves;
2210
2211 // Walk the users of the function.
2212 for (auto &U : F->uses()) {
2213 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2214 // The value to store.
2215 auto Arg0 = CI->getOperand(0);
2216
2217 // The index argument from vstore_half.
2218 auto Arg1 = CI->getOperand(1);
2219
2220 // The pointer argument from vstore_half.
2221 auto Arg2 = CI->getOperand(2);
2222
2223 auto IntTy = Type::getInt32Ty(M.getContext());
2224 auto Int2Ty = VectorType::get(IntTy, 2);
2225 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2226 auto NewPointerTy = PointerType::get(
2227 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2228 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2229
2230 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2231 ConstantInt::get(IntTy, 1)};
2232
2233 // Extract out the x & y components of our to store value.
2234 auto Lo =
2235 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2236 ConstantVector::get(LoShuffleMask), "", CI);
2237
2238 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2239 ConstantInt::get(IntTy, 3)};
2240
2241 // Extract out the z & w components of our to store value.
2242 auto Hi =
2243 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2244 ConstantVector::get(HiShuffleMask), "", CI);
2245
2246 // Our intrinsic to pack a float2 to an int.
2247 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2248
2249 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2250
2251 // Turn the packed x & y into the final component of our int2.
2252 auto X = CallInst::Create(NewF, Lo, "", CI);
2253
2254 // Turn the packed z & w into the final component of our int2.
2255 auto Y = CallInst::Create(NewF, Hi, "", CI);
2256
2257 auto Combine = InsertElementInst::Create(
2258 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2259 Combine = InsertElementInst::Create(
2260 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2261
2262 // Cast the half* pointer to int2*.
2263 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2264
2265 // Index into the correct address of the casted pointer.
2266 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2267
2268 // Store to the int2* we casted to.
2269 auto Store = new StoreInst(Combine, Index, CI);
2270
2271 CI->replaceAllUsesWith(Store);
2272
2273 // Lastly, remember to remove the user.
2274 ToRemoves.push_back(CI);
2275 }
2276 }
2277
2278 Changed = !ToRemoves.empty();
2279
2280 // And cleanup the calls we don't use anymore.
2281 for (auto V : ToRemoves) {
2282 V->eraseFromParent();
2283 }
2284
2285 // And remove the function we don't need either too.
2286 F->eraseFromParent();
2287 }
2288 }
2289
2290 return Changed;
2291}
2292
2293bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2294 bool Changed = false;
2295
2296 const std::map<const char *, const char*> Map = {
2297 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2298 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2299 };
2300
2301 for (auto Pair : Map) {
2302 // If we find a function with the matching name.
2303 if (auto F = M.getFunction(Pair.first)) {
2304 SmallVector<Instruction *, 4> ToRemoves;
2305
2306 // Walk the users of the function.
2307 for (auto &U : F->uses()) {
2308 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2309 // The image.
2310 auto Arg0 = CI->getOperand(0);
2311
2312 // The sampler.
2313 auto Arg1 = CI->getOperand(1);
2314
2315 // The coordinate (integer type that we can't handle).
2316 auto Arg2 = CI->getOperand(2);
2317
2318 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2319
2320 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2321
2322 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2323
2324 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2325
2326 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2327
2328 CI->replaceAllUsesWith(NewCI);
2329
2330 // Lastly, remember to remove the user.
2331 ToRemoves.push_back(CI);
2332 }
2333 }
2334
2335 Changed = !ToRemoves.empty();
2336
2337 // And cleanup the calls we don't use anymore.
2338 for (auto V : ToRemoves) {
2339 V->eraseFromParent();
2340 }
2341
2342 // And remove the function we don't need either too.
2343 F->eraseFromParent();
2344 }
2345 }
2346
2347 return Changed;
2348}
2349
2350bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2351 bool Changed = false;
2352
2353 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002354 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2355 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2356 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2357 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2358 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2359 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002360 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2361 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2362 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2363 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2364 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002365 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002366
2367 for (auto Pair : Map) {
2368 // If we find a function with the matching name.
2369 if (auto F = M.getFunction(Pair.first)) {
2370 SmallVector<Instruction *, 4> ToRemoves;
2371
2372 // Walk the users of the function.
2373 for (auto &U : F->uses()) {
2374 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2375 auto FType = F->getFunctionType();
2376 SmallVector<Type *, 5> ParamTypes;
2377
2378 // The pointer type.
2379 ParamTypes.push_back(FType->getParamType(0));
2380
2381 auto IntTy = Type::getInt32Ty(M.getContext());
2382
2383 // The memory scope type.
2384 ParamTypes.push_back(IntTy);
2385
2386 // The memory semantics type.
2387 ParamTypes.push_back(IntTy);
2388
2389 if (2 < CI->getNumArgOperands()) {
2390 // The unequal memory semantics type.
2391 ParamTypes.push_back(IntTy);
2392
2393 // The value type.
2394 ParamTypes.push_back(FType->getParamType(2));
2395
2396 // The comparator type.
2397 ParamTypes.push_back(FType->getParamType(1));
2398 } else if (1 < CI->getNumArgOperands()) {
2399 // The value type.
2400 ParamTypes.push_back(FType->getParamType(1));
2401 }
2402
2403 auto NewFType =
2404 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2405 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2406
2407 // We need to map the OpenCL constants to the SPIR-V equivalents.
2408 const auto ConstantScopeDevice =
2409 ConstantInt::get(IntTy, spv::ScopeDevice);
2410 const auto ConstantMemorySemantics = ConstantInt::get(
2411 IntTy, spv::MemorySemanticsUniformMemoryMask |
2412 spv::MemorySemanticsSequentiallyConsistentMask);
2413
2414 SmallVector<Value *, 5> Params;
2415
2416 // The pointer.
2417 Params.push_back(CI->getArgOperand(0));
2418
2419 // The memory scope.
2420 Params.push_back(ConstantScopeDevice);
2421
2422 // The memory semantics.
2423 Params.push_back(ConstantMemorySemantics);
2424
2425 if (2 < CI->getNumArgOperands()) {
2426 // The unequal memory semantics.
2427 Params.push_back(ConstantMemorySemantics);
2428
2429 // The value.
2430 Params.push_back(CI->getArgOperand(2));
2431
2432 // The comparator.
2433 Params.push_back(CI->getArgOperand(1));
2434 } else if (1 < CI->getNumArgOperands()) {
2435 // The value.
2436 Params.push_back(CI->getArgOperand(1));
2437 }
2438
2439 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2440
2441 CI->replaceAllUsesWith(NewCI);
2442
2443 // Lastly, remember to remove the user.
2444 ToRemoves.push_back(CI);
2445 }
2446 }
2447
2448 Changed = !ToRemoves.empty();
2449
2450 // And cleanup the calls we don't use anymore.
2451 for (auto V : ToRemoves) {
2452 V->eraseFromParent();
2453 }
2454
2455 // And remove the function we don't need either too.
2456 F->eraseFromParent();
2457 }
2458 }
2459
Neil Henning39672102017-09-29 14:33:13 +01002460 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002461 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2462 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2463 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2464 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2465 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2466 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2467 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2468 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2469 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2470 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2471 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2472 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2473 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2474 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2475 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2476 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002477 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2478 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2479 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2480 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2481 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2482 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2483 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2484 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2485 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2486 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2487 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2488 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2489 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2490 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2491 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2492 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2493
2494 for (auto Pair : Map2) {
2495 // If we find a function with the matching name.
2496 if (auto F = M.getFunction(Pair.first)) {
2497 SmallVector<Instruction *, 4> ToRemoves;
2498
2499 // Walk the users of the function.
2500 for (auto &U : F->uses()) {
2501 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2502 auto AtomicOp = new AtomicRMWInst(
2503 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2504 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2505
2506 CI->replaceAllUsesWith(AtomicOp);
2507
2508 // Lastly, remember to remove the user.
2509 ToRemoves.push_back(CI);
2510 }
2511 }
2512
2513 Changed = !ToRemoves.empty();
2514
2515 // And cleanup the calls we don't use anymore.
2516 for (auto V : ToRemoves) {
2517 V->eraseFromParent();
2518 }
2519
2520 // And remove the function we don't need either too.
2521 F->eraseFromParent();
2522 }
2523 }
2524
David Neto22f144c2017-06-12 14:26:21 -04002525 return Changed;
2526}
2527
2528bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2529 bool Changed = false;
2530
2531 // If we find a function with the matching name.
2532 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2533 SmallVector<Instruction *, 4> ToRemoves;
2534
2535 auto IntTy = Type::getInt32Ty(M.getContext());
2536 auto FloatTy = Type::getFloatTy(M.getContext());
2537
2538 Constant *DownShuffleMask[3] = {
2539 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2540 ConstantInt::get(IntTy, 2)};
2541
2542 Constant *UpShuffleMask[4] = {
2543 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2544 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2545
2546 Constant *FloatVec[3] = {
2547 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
2548 };
2549
2550 // Walk the users of the function.
2551 for (auto &U : F->uses()) {
2552 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2553 auto Vec4Ty = CI->getArgOperand(0)->getType();
2554 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2555 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2556 auto Vec3Ty = Arg0->getType();
2557
2558 auto NewFType =
2559 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
2560
2561 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
2562
2563 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
2564
2565 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
2566
2567 CI->replaceAllUsesWith(Result);
2568
2569 // Lastly, remember to remove the user.
2570 ToRemoves.push_back(CI);
2571 }
2572 }
2573
2574 Changed = !ToRemoves.empty();
2575
2576 // And cleanup the calls we don't use anymore.
2577 for (auto V : ToRemoves) {
2578 V->eraseFromParent();
2579 }
2580
2581 // And remove the function we don't need either too.
2582 F->eraseFromParent();
2583 }
2584
2585 return Changed;
2586}
David Neto62653202017-10-16 19:05:18 -04002587
2588bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2589 bool Changed = false;
2590
2591 // OpenCL's float result = fract(float x, float* ptr)
2592 //
2593 // In the LLVM domain:
2594 //
2595 // %floor_result = call spir_func float @floor(float %x)
2596 // store float %floor_result, float * %ptr
2597 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2598 // %result = call spir_func float
2599 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2600 //
2601 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2602 // and clspv.fract occur in the SPIR-V generator pass:
2603 //
2604 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2605 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2606 // ...
2607 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2608 // OpStore %ptr %floor_result
2609 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2610 // %fract_result = OpExtInst %float
2611 // %glsl_ext Fmin %fract_intermediate %just_under_1
2612
2613
2614 using std::string;
2615
2616 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2617 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
2618 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
2619 auto make_quad = [](const char *a, const char *b, const char *c,
2620 const char *d) {
2621 return std::tuple<const char *, const char *, const char *, const char *>(
2622 a, b, c, d);
2623 };
2624 const std::vector<QuadType> Functions = {
2625 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
2626 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
2627 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
2628 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
2629 };
2630
2631 for (auto& quad : Functions) {
2632 const StringRef fract_name(std::get<0>(quad));
2633
2634 // If we find a function with the matching name.
2635 if (auto F = M.getFunction(fract_name)) {
2636 if (F->use_begin() == F->use_end())
2637 continue;
2638
2639 // We have some uses.
2640 Changed = true;
2641
2642 auto& Context = M.getContext();
2643
2644 const StringRef floor_name(std::get<1>(quad));
2645 const StringRef fmin_name(std::get<2>(quad));
2646 const StringRef clspv_fract_name(std::get<3>(quad));
2647
2648 // This is either float or a float vector. All the float-like
2649 // types are this type.
2650 auto result_ty = F->getReturnType();
2651
2652 Function* fmin_fn = M.getFunction(fmin_name);
2653 if (!fmin_fn) {
2654 // Make the fmin function.
2655 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
2656 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002657 fmin_fn->addFnAttr(Attribute::ReadNone);
2658 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2659 }
2660
2661 Function* floor_fn = M.getFunction(floor_name);
2662 if (!floor_fn) {
2663 // Make the floor function.
2664 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2665 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002666 floor_fn->addFnAttr(Attribute::ReadNone);
2667 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2668 }
2669
2670 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
2671 if (!clspv_fract_fn) {
2672 // Make the clspv_fract function.
2673 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2674 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002675 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2676 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2677 }
2678
2679 // Number of significant significand bits, whether represented or not.
2680 unsigned num_significand_bits;
2681 switch (result_ty->getScalarType()->getTypeID()) {
2682 case Type::HalfTyID:
2683 num_significand_bits = 11;
2684 break;
2685 case Type::FloatTyID:
2686 num_significand_bits = 24;
2687 break;
2688 case Type::DoubleTyID:
2689 num_significand_bits = 53;
2690 break;
2691 default:
2692 assert(false && "Unhandled float type when processing fract builtin");
2693 break;
2694 }
2695 // Beware that the disassembler displays this value as
2696 // OpConstant %float 1
2697 // which is not quite right.
2698 const double kJustUnderOneScalar =
2699 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2700
2701 Constant *just_under_one =
2702 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2703 if (result_ty->isVectorTy()) {
2704 just_under_one = ConstantVector::getSplat(
2705 result_ty->getVectorNumElements(), just_under_one);
2706 }
2707
2708 IRBuilder<> Builder(Context);
2709
2710 SmallVector<Instruction *, 4> ToRemoves;
2711
2712 // Walk the users of the function.
2713 for (auto &U : F->uses()) {
2714 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2715
2716 Builder.SetInsertPoint(CI);
2717 auto arg = CI->getArgOperand(0);
2718 auto ptr = CI->getArgOperand(1);
2719
2720 // Compute floor result and store it.
2721 auto floor = Builder.CreateCall(floor_fn, {arg});
2722 Builder.CreateStore(floor, ptr);
2723
2724 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2725 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2726
2727 CI->replaceAllUsesWith(fract_result);
2728
2729 // Lastly, remember to remove the user.
2730 ToRemoves.push_back(CI);
2731 }
2732 }
2733
2734 // And cleanup the calls we don't use anymore.
2735 for (auto V : ToRemoves) {
2736 V->eraseFromParent();
2737 }
2738
2739 // And remove the function we don't need either too.
2740 F->eraseFromParent();
2741 }
2742 }
2743
2744 return Changed;
2745}