blob: fbeb0467ca4ef75f8eef2c7666cda6434507a729 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
David Neto118188e2018-08-24 11:27:54 -040019#include "llvm/IR/Constants.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000023#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/Pass.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/raw_ostream.h"
27#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040028
David Neto118188e2018-08-24 11:27:54 -040029#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040030
David Neto482550a2018-03-24 05:21:07 -070031#include "clspv/Option.h"
32
David Neto22f144c2017-06-12 14:26:21 -040033using namespace llvm;
34
35#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
36
37namespace {
// Returns the index of the most significant set bit of `v`, i.e. floor(log2(v)),
// and 0 when `v` is 0.
//
// NOTE: despite the name this is *not* a count-leading-zeros. The barrier and
// mem_fence lowering subtract two results of this function to obtain the
// distance (in bits) between two single-bit masks, which is exactly what the
// bit index provides.
uint32_t clz(uint32_t v) {
  // Binary search over halving widths: whenever something is set at or above
  // bit `Step`, shift it down and remember that offset. The steps are disjoint
  // powers of two, so summing them yields the bit index.
  static const uint32_t Steps[] = {16, 8, 4, 2, 1};
  uint32_t MsbIndex = 0;
  for (uint32_t Step : Steps) {
    if ((v >> Step) != 0) {
      v >>= Step;
      MsbIndex += Step;
    }
  }
  return MsbIndex;
}
57
58Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
59 if (1 == elements) {
60 return Type::getInt1Ty(C);
61 } else {
62 return VectorType::get(Type::getInt1Ty(C), elements);
63 }
64}
65
66struct ReplaceOpenCLBuiltinPass final : public ModulePass {
67 static char ID;
68 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
69
70 bool runOnModule(Module &M) override;
71 bool replaceRecip(Module &M);
72 bool replaceDivide(Module &M);
73 bool replaceExp10(Module &M);
74 bool replaceLog10(Module &M);
75 bool replaceBarrier(Module &M);
76 bool replaceMemFence(Module &M);
77 bool replaceRelational(Module &M);
78 bool replaceIsInfAndIsNan(Module &M);
79 bool replaceAllAndAny(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +000080 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +000081 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +000082 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040083 bool replaceSignbit(Module &M);
84 bool replaceMadandMad24andMul24(Module &M);
85 bool replaceVloadHalf(Module &M);
86 bool replaceVloadHalf2(Module &M);
87 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -070088 bool replaceClspvVloadaHalf2(Module &M);
89 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040090 bool replaceVstoreHalf(Module &M);
91 bool replaceVstoreHalf2(Module &M);
92 bool replaceVstoreHalf4(Module &M);
93 bool replaceReadImageF(Module &M);
94 bool replaceAtomics(Module &M);
95 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -040096 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -070097 bool replaceVload(Module &M);
98 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040099};
100}
101
102char ReplaceOpenCLBuiltinPass::ID = 0;
103static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
104 "Replace OpenCL Builtins Pass");
105
106namespace clspv {
107ModulePass *createReplaceOpenCLBuiltinPass() {
108 return new ReplaceOpenCLBuiltinPass();
109}
110}
111
112bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
113 bool Changed = false;
114
115 Changed |= replaceRecip(M);
116 Changed |= replaceDivide(M);
117 Changed |= replaceExp10(M);
118 Changed |= replaceLog10(M);
119 Changed |= replaceBarrier(M);
120 Changed |= replaceMemFence(M);
121 Changed |= replaceRelational(M);
122 Changed |= replaceIsInfAndIsNan(M);
123 Changed |= replaceAllAndAny(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000124 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000125 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000126 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400127 Changed |= replaceSignbit(M);
128 Changed |= replaceMadandMad24andMul24(M);
129 Changed |= replaceVloadHalf(M);
130 Changed |= replaceVloadHalf2(M);
131 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700132 Changed |= replaceClspvVloadaHalf2(M);
133 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400134 Changed |= replaceVstoreHalf(M);
135 Changed |= replaceVstoreHalf2(M);
136 Changed |= replaceVstoreHalf4(M);
137 Changed |= replaceReadImageF(M);
138 Changed |= replaceAtomics(M);
139 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400140 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700141 Changed |= replaceVload(M);
142 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400143
144 return Changed;
145}
146
147bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
148 bool Changed = false;
149
150 const char *Names[] = {
151 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
152 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
153 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
154 };
155
156 for (auto Name : Names) {
157 // If we find a function with the matching name.
158 if (auto F = M.getFunction(Name)) {
159 SmallVector<Instruction *, 4> ToRemoves;
160
161 // Walk the users of the function.
162 for (auto &U : F->uses()) {
163 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
164 // Recip has one arg.
165 auto Arg = CI->getOperand(0);
166
167 auto Div = BinaryOperator::Create(
168 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
169 CI);
170
171 CI->replaceAllUsesWith(Div);
172
173 // Lastly, remember to remove the user.
174 ToRemoves.push_back(CI);
175 }
176 }
177
178 Changed = !ToRemoves.empty();
179
180 // And cleanup the calls we don't use anymore.
181 for (auto V : ToRemoves) {
182 V->eraseFromParent();
183 }
184
185 // And remove the function we don't need either too.
186 F->eraseFromParent();
187 }
188 }
189
190 return Changed;
191}
192
193bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
194 bool Changed = false;
195
196 const char *Names[] = {
197 "_Z11half_divideff", "_Z13native_divideff",
198 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
199 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
200 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
201 };
202
203 for (auto Name : Names) {
204 // If we find a function with the matching name.
205 if (auto F = M.getFunction(Name)) {
206 SmallVector<Instruction *, 4> ToRemoves;
207
208 // Walk the users of the function.
209 for (auto &U : F->uses()) {
210 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
211 auto Div = BinaryOperator::Create(
212 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
213
214 CI->replaceAllUsesWith(Div);
215
216 // Lastly, remember to remove the user.
217 ToRemoves.push_back(CI);
218 }
219 }
220
221 Changed = !ToRemoves.empty();
222
223 // And cleanup the calls we don't use anymore.
224 for (auto V : ToRemoves) {
225 V->eraseFromParent();
226 }
227
228 // And remove the function we don't need either too.
229 F->eraseFromParent();
230 }
231 }
232
233 return Changed;
234}
235
236bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
237 bool Changed = false;
238
239 const std::map<const char *, const char *> Map = {
240 {"_Z5exp10f", "_Z3expf"},
241 {"_Z10half_exp10f", "_Z8half_expf"},
242 {"_Z12native_exp10f", "_Z10native_expf"},
243 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
244 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
245 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
246 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
247 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
248 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
249 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
250 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
251 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
252
253 for (auto Pair : Map) {
254 // If we find a function with the matching name.
255 if (auto F = M.getFunction(Pair.first)) {
256 SmallVector<Instruction *, 4> ToRemoves;
257
258 // Walk the users of the function.
259 for (auto &U : F->uses()) {
260 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
261 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
262
263 auto Arg = CI->getOperand(0);
264
265 // Constant of the natural log of 10 (ln(10)).
266 const double Ln10 =
267 2.302585092994045684017991454684364207601101488628772976033;
268
269 auto Mul = BinaryOperator::Create(
270 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
271 CI);
272
273 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
274
275 CI->replaceAllUsesWith(NewCI);
276
277 // Lastly, remember to remove the user.
278 ToRemoves.push_back(CI);
279 }
280 }
281
282 Changed = !ToRemoves.empty();
283
284 // And cleanup the calls we don't use anymore.
285 for (auto V : ToRemoves) {
286 V->eraseFromParent();
287 }
288
289 // And remove the function we don't need either too.
290 F->eraseFromParent();
291 }
292 }
293
294 return Changed;
295}
296
297bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
298 bool Changed = false;
299
300 const std::map<const char *, const char *> Map = {
301 {"_Z5log10f", "_Z3logf"},
302 {"_Z10half_log10f", "_Z8half_logf"},
303 {"_Z12native_log10f", "_Z10native_logf"},
304 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
305 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
306 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
307 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
308 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
309 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
310 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
311 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
312 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
313
314 for (auto Pair : Map) {
315 // If we find a function with the matching name.
316 if (auto F = M.getFunction(Pair.first)) {
317 SmallVector<Instruction *, 4> ToRemoves;
318
319 // Walk the users of the function.
320 for (auto &U : F->uses()) {
321 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
322 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
323
324 auto Arg = CI->getOperand(0);
325
326 // Constant of the reciprocal of the natural log of 10 (ln(10)).
327 const double Ln10 =
328 0.434294481903251827651128918916605082294397005803666566114;
329
330 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
331
332 auto Mul = BinaryOperator::Create(
333 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
334 "", CI);
335
336 CI->replaceAllUsesWith(Mul);
337
338 // Lastly, remember to remove the user.
339 ToRemoves.push_back(CI);
340 }
341 }
342
343 Changed = !ToRemoves.empty();
344
345 // And cleanup the calls we don't use anymore.
346 for (auto V : ToRemoves) {
347 V->eraseFromParent();
348 }
349
350 // And remove the function we don't need either too.
351 F->eraseFromParent();
352 }
353 }
354
355 return Changed;
356}
357
358bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
359 bool Changed = false;
360
361 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
362
363 const std::map<const char *, const char *> Map = {
364 {"_Z7barrierj", "__spirv_control_barrier"}};
365
366 for (auto Pair : Map) {
367 // If we find a function with the matching name.
368 if (auto F = M.getFunction(Pair.first)) {
369 SmallVector<Instruction *, 4> ToRemoves;
370
371 // Walk the users of the function.
372 for (auto &U : F->uses()) {
373 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
374 auto FType = F->getFunctionType();
375 SmallVector<Type *, 3> Params;
376 for (unsigned i = 0; i < 3; i++) {
377 Params.push_back(FType->getParamType(0));
378 }
379 auto NewFType =
380 FunctionType::get(FType->getReturnType(), Params, false);
381 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
382
383 auto Arg = CI->getOperand(0);
384
385 // We need to map the OpenCL constants to the SPIR-V equivalents.
386 const auto LocalMemFence =
387 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
388 const auto GlobalMemFence =
389 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
390 const auto ConstantSequentiallyConsistent = ConstantInt::get(
391 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
392 const auto ConstantScopeDevice =
393 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
394 const auto ConstantScopeWorkgroup =
395 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
396
397 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
398 const auto LocalMemFenceMask = BinaryOperator::Create(
399 Instruction::And, LocalMemFence, Arg, "", CI);
400 const auto WorkgroupShiftAmount =
401 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
402 clz(CLK_LOCAL_MEM_FENCE);
403 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
404 Instruction::Shl, LocalMemFenceMask,
405 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
406
407 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
408 const auto GlobalMemFenceMask = BinaryOperator::Create(
409 Instruction::And, GlobalMemFence, Arg, "", CI);
410 const auto UniformShiftAmount =
411 clz(spv::MemorySemanticsUniformMemoryMask) -
412 clz(CLK_GLOBAL_MEM_FENCE);
413 const auto MemorySemanticsUniform = BinaryOperator::Create(
414 Instruction::Shl, GlobalMemFenceMask,
415 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
416
417 // And combine the above together, also adding in
418 // MemorySemanticsSequentiallyConsistentMask.
419 auto MemorySemantics =
420 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
421 ConstantSequentiallyConsistent, "", CI);
422 MemorySemantics = BinaryOperator::Create(
423 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
424
425 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
426 // Device Scope, otherwise Workgroup Scope.
427 const auto Cmp =
428 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
429 GlobalMemFenceMask, GlobalMemFence, "", CI);
430 const auto MemoryScope = SelectInst::Create(
431 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
432
433 // Lastly, the Execution Scope is always Workgroup Scope.
434 const auto ExecutionScope = ConstantScopeWorkgroup;
435
436 auto NewCI = CallInst::Create(
437 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
438
439 CI->replaceAllUsesWith(NewCI);
440
441 // Lastly, remember to remove the user.
442 ToRemoves.push_back(CI);
443 }
444 }
445
446 Changed = !ToRemoves.empty();
447
448 // And cleanup the calls we don't use anymore.
449 for (auto V : ToRemoves) {
450 V->eraseFromParent();
451 }
452
453 // And remove the function we don't need either too.
454 F->eraseFromParent();
455 }
456 }
457
458 return Changed;
459}
460
461bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
462 bool Changed = false;
463
464 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
465
Neil Henning39672102017-09-29 14:33:13 +0100466 using Tuple = std::tuple<const char *, unsigned>;
467 const std::map<const char *, Tuple> Map = {
468 {"_Z9mem_fencej",
469 Tuple("__spirv_memory_barrier",
470 spv::MemorySemanticsSequentiallyConsistentMask)},
471 {"_Z14read_mem_fencej",
472 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
473 {"_Z15write_mem_fencej",
474 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400475
476 for (auto Pair : Map) {
477 // If we find a function with the matching name.
478 if (auto F = M.getFunction(Pair.first)) {
479 SmallVector<Instruction *, 4> ToRemoves;
480
481 // Walk the users of the function.
482 for (auto &U : F->uses()) {
483 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
484 auto FType = F->getFunctionType();
485 SmallVector<Type *, 2> Params;
486 for (unsigned i = 0; i < 2; i++) {
487 Params.push_back(FType->getParamType(0));
488 }
489 auto NewFType =
490 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100491 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400492
493 auto Arg = CI->getOperand(0);
494
495 // We need to map the OpenCL constants to the SPIR-V equivalents.
496 const auto LocalMemFence =
497 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
498 const auto GlobalMemFence =
499 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
500 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100501 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400502 const auto ConstantScopeDevice =
503 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
504
505 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
506 const auto LocalMemFenceMask = BinaryOperator::Create(
507 Instruction::And, LocalMemFence, Arg, "", CI);
508 const auto WorkgroupShiftAmount =
509 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
510 clz(CLK_LOCAL_MEM_FENCE);
511 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
512 Instruction::Shl, LocalMemFenceMask,
513 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
514
515 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
516 const auto GlobalMemFenceMask = BinaryOperator::Create(
517 Instruction::And, GlobalMemFence, Arg, "", CI);
518 const auto UniformShiftAmount =
519 clz(spv::MemorySemanticsUniformMemoryMask) -
520 clz(CLK_GLOBAL_MEM_FENCE);
521 const auto MemorySemanticsUniform = BinaryOperator::Create(
522 Instruction::Shl, GlobalMemFenceMask,
523 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
524
525 // And combine the above together, also adding in
526 // MemorySemanticsSequentiallyConsistentMask.
527 auto MemorySemantics =
528 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
529 ConstantMemorySemantics, "", CI);
530 MemorySemantics = BinaryOperator::Create(
531 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
532
533 // Memory Scope is always device.
534 const auto MemoryScope = ConstantScopeDevice;
535
536 auto NewCI =
537 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
538
539 CI->replaceAllUsesWith(NewCI);
540
541 // Lastly, remember to remove the user.
542 ToRemoves.push_back(CI);
543 }
544 }
545
546 Changed = !ToRemoves.empty();
547
548 // And cleanup the calls we don't use anymore.
549 for (auto V : ToRemoves) {
550 V->eraseFromParent();
551 }
552
553 // And remove the function we don't need either too.
554 F->eraseFromParent();
555 }
556 }
557
558 return Changed;
559}
560
561bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
562 bool Changed = false;
563
564 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
565 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
566 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
567 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
568 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
569 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
570 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
571 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
572 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
573 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
574 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
575 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
576 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
577 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
578 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
579 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
580 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
581 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
582 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
583 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
584 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
585 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
586 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
587 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
588 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
589 };
590
591 for (auto Pair : Map) {
592 // If we find a function with the matching name.
593 if (auto F = M.getFunction(Pair.first)) {
594 SmallVector<Instruction *, 4> ToRemoves;
595
596 // Walk the users of the function.
597 for (auto &U : F->uses()) {
598 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
599 // The predicate to use in the CmpInst.
600 auto Predicate = Pair.second.first;
601
602 // The value to return for true.
603 auto TrueValue =
604 ConstantInt::getSigned(CI->getType(), Pair.second.second);
605
606 // The value to return for false.
607 auto FalseValue = Constant::getNullValue(CI->getType());
608
609 auto Arg1 = CI->getOperand(0);
610 auto Arg2 = CI->getOperand(1);
611
612 const auto Cmp =
613 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
614
615 const auto Select =
616 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
617
618 CI->replaceAllUsesWith(Select);
619
620 // Lastly, remember to remove the user.
621 ToRemoves.push_back(CI);
622 }
623 }
624
625 Changed = !ToRemoves.empty();
626
627 // And cleanup the calls we don't use anymore.
628 for (auto V : ToRemoves) {
629 V->eraseFromParent();
630 }
631
632 // And remove the function we don't need either too.
633 F->eraseFromParent();
634 }
635 }
636
637 return Changed;
638}
639
640bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
641 bool Changed = false;
642
643 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
644 {"_Z5isinff", {"__spirv_isinff", 1}},
645 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
646 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
647 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
648 {"_Z5isnanf", {"__spirv_isnanf", 1}},
649 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
650 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
651 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
652 };
653
654 for (auto Pair : Map) {
655 // If we find a function with the matching name.
656 if (auto F = M.getFunction(Pair.first)) {
657 SmallVector<Instruction *, 4> ToRemoves;
658
659 // Walk the users of the function.
660 for (auto &U : F->uses()) {
661 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
662 const auto CITy = CI->getType();
663
664 // The fake SPIR-V intrinsic to generate.
665 auto SPIRVIntrinsic = Pair.second.first;
666
667 // The value to return for true.
668 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
669
670 // The value to return for false.
671 auto FalseValue = Constant::getNullValue(CITy);
672
673 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
674 M.getContext(),
675 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
676
677 auto NewFType =
678 FunctionType::get(CorrespondingBoolTy,
679 F->getFunctionType()->getParamType(0), false);
680
681 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
682
683 auto Arg = CI->getOperand(0);
684
685 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
686
687 const auto Select =
688 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
689
690 CI->replaceAllUsesWith(Select);
691
692 // Lastly, remember to remove the user.
693 ToRemoves.push_back(CI);
694 }
695 }
696
697 Changed = !ToRemoves.empty();
698
699 // And cleanup the calls we don't use anymore.
700 for (auto V : ToRemoves) {
701 V->eraseFromParent();
702 }
703
704 // And remove the function we don't need either too.
705 F->eraseFromParent();
706 }
707 }
708
709 return Changed;
710}
711
712bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
713 bool Changed = false;
714
715 const std::map<const char *, const char *> Map = {
716 {"_Z3alli", ""},
717 {"_Z3allDv2_i", "__spirv_allDv2_i"},
718 {"_Z3allDv3_i", "__spirv_allDv3_i"},
719 {"_Z3allDv4_i", "__spirv_allDv4_i"},
720 {"_Z3anyi", ""},
721 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
722 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
723 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
724 };
725
726 for (auto Pair : Map) {
727 // If we find a function with the matching name.
728 if (auto F = M.getFunction(Pair.first)) {
729 SmallVector<Instruction *, 4> ToRemoves;
730
731 // Walk the users of the function.
732 for (auto &U : F->uses()) {
733 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
734 // The fake SPIR-V intrinsic to generate.
735 auto SPIRVIntrinsic = Pair.second;
736
737 auto Arg = CI->getOperand(0);
738
739 Value *V;
740
741 // If we have a function to call, call it!
742 if (0 < strlen(SPIRVIntrinsic)) {
743 // The value for zero to compare against.
744 const auto ZeroValue = Constant::getNullValue(Arg->getType());
745
746 const auto Cmp = CmpInst::Create(
747 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
748 const auto NewFType = FunctionType::get(
749 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
750
751 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
752
753 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
754
755 // The value to return for true.
756 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
757
758 // The value to return for false.
759 const auto FalseValue = Constant::getNullValue(CI->getType());
760
761 V = SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
762 } else {
763 V = BinaryOperator::Create(Instruction::LShr, Arg,
764 ConstantInt::get(CI->getType(), 31), "",
765 CI);
766 }
767
768 CI->replaceAllUsesWith(V);
769
770 // Lastly, remember to remove the user.
771 ToRemoves.push_back(CI);
772 }
773 }
774
775 Changed = !ToRemoves.empty();
776
777 // And cleanup the calls we don't use anymore.
778 for (auto V : ToRemoves) {
779 V->eraseFromParent();
780 }
781
782 // And remove the function we don't need either too.
783 F->eraseFromParent();
784 }
785 }
786
787 return Changed;
788}
789
Kévin Petitf5b78a22018-10-25 14:32:17 +0000790bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
791 bool Changed = false;
792
793 for (auto const &SymVal : M.getValueSymbolTable()) {
794 // Skip symbols whose name doesn't match
795 if (!SymVal.getKey().startswith("_Z6select")) {
796 continue;
797 }
798 // Is there a function going by that name?
799 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
800
801 SmallVector<Instruction *, 4> ToRemoves;
802
803 // Walk the users of the function.
804 for (auto &U : F->uses()) {
805 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
806
807 // Get arguments
808 auto FalseValue = CI->getOperand(0);
809 auto TrueValue = CI->getOperand(1);
810 auto PredicateValue = CI->getOperand(2);
811
812 // Don't touch overloads that aren't in OpenCL C
813 auto FalseType = FalseValue->getType();
814 auto TrueType = TrueValue->getType();
815 auto PredicateType = PredicateValue->getType();
816
817 if (FalseType != TrueType) {
818 continue;
819 }
820
821 if (!PredicateType->isIntOrIntVectorTy()) {
822 continue;
823 }
824
825 if (!FalseType->isIntOrIntVectorTy() &&
826 !FalseType->getScalarType()->isFloatingPointTy()) {
827 continue;
828 }
829
830 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
831 continue;
832 }
833
834 if (FalseType->getScalarSizeInBits() !=
835 PredicateType->getScalarSizeInBits()) {
836 continue;
837 }
838
839 if (FalseType->isVectorTy()) {
840 if (FalseType->getVectorNumElements() !=
841 PredicateType->getVectorNumElements()) {
842 continue;
843 }
844
845 if ((FalseType->getVectorNumElements() != 2) &&
846 (FalseType->getVectorNumElements() != 3) &&
847 (FalseType->getVectorNumElements() != 4) &&
848 (FalseType->getVectorNumElements() != 8) &&
849 (FalseType->getVectorNumElements() != 16)) {
850 continue;
851 }
852 }
853
854 // Create constant
855 const auto ZeroValue = Constant::getNullValue(PredicateType);
856
857 // Scalar and vector are to be treated differently
858 CmpInst::Predicate Pred;
859 if (PredicateType->isVectorTy()) {
860 Pred = CmpInst::ICMP_SLT;
861 } else {
862 Pred = CmpInst::ICMP_NE;
863 }
864
865 // Create comparison instruction
866 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
867 ZeroValue, "", CI);
868
869 // Create select
870 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
871
872 // Replace call with the selection
873 CI->replaceAllUsesWith(V);
874
875 // Lastly, remember to remove the user.
876 ToRemoves.push_back(CI);
877 }
878 }
879
880 Changed = !ToRemoves.empty();
881
882 // And cleanup the calls we don't use anymore.
883 for (auto V : ToRemoves) {
884 V->eraseFromParent();
885 }
886
887 // And remove the function we don't need either too.
888 F->eraseFromParent();
889 }
890 }
891
892 return Changed;
893}
894
Kévin Petite7d0cce2018-10-31 12:38:56 +0000895bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
896 bool Changed = false;
897
898 for (auto const &SymVal : M.getValueSymbolTable()) {
899 // Skip symbols whose name doesn't match
900 if (!SymVal.getKey().startswith("_Z9bitselect")) {
901 continue;
902 }
903 // Is there a function going by that name?
904 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
905
906 SmallVector<Instruction *, 4> ToRemoves;
907
908 // Walk the users of the function.
909 for (auto &U : F->uses()) {
910 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
911
912 if (CI->getNumOperands() != 4) {
913 continue;
914 }
915
916 // Get arguments
917 auto FalseValue = CI->getOperand(0);
918 auto TrueValue = CI->getOperand(1);
919 auto PredicateValue = CI->getOperand(2);
920
921 // Don't touch overloads that aren't in OpenCL C
922 auto FalseType = FalseValue->getType();
923 auto TrueType = TrueValue->getType();
924 auto PredicateType = PredicateValue->getType();
925
926 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
927 continue;
928 }
929
930 if (TrueType->isVectorTy()) {
931 if (!TrueType->getScalarType()->isFloatingPointTy() &&
932 !TrueType->getScalarType()->isIntegerTy()) {
933 continue;
934 }
935 if ((TrueType->getVectorNumElements() != 2) &&
936 (TrueType->getVectorNumElements() != 3) &&
937 (TrueType->getVectorNumElements() != 4) &&
938 (TrueType->getVectorNumElements() != 8) &&
939 (TrueType->getVectorNumElements() != 16)) {
940 continue;
941 }
942 }
943
944 // Remember the type of the operands
945 auto OpType = TrueType;
946
947 // The actual bit selection will always be done on an integer type,
948 // declare it here
949 Type *BitType;
950
951 // If the operands are float, then bitcast them to int
952 if (OpType->getScalarType()->isFloatingPointTy()) {
953
954 // First create the new type
955 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
956 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
957 if (OpType->isVectorTy()) {
958 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
959 }
960
961 // Then bitcast all operands
962 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
963 BitType, "", CI);
964 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
965 BitType, "", CI);
966 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
967
968 } else {
969 // The operands have an integer type, use it directly
970 BitType = OpType;
971 }
972
973 // All the operands are now always integers
974 // implement as (c & b) | (~c & a)
975
976 // Create our negated predicate value
977 auto AllOnes = Constant::getAllOnesValue(BitType);
978 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
979 PredicateValue,
980 AllOnes, "", CI);
981
982 // Then put everything together
983 auto BitsFalse = BinaryOperator::Create(Instruction::And,
984 NotPredicateValue,
985 FalseValue, "", CI);
986 auto BitsTrue = BinaryOperator::Create(Instruction::And,
987 PredicateValue,
988 TrueValue, "", CI);
989
990 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
991 BitsTrue, "", CI);
992
993 // If we were dealing with a floating point type, we must bitcast
994 // the result back to that
995 if (OpType->getScalarType()->isFloatingPointTy()) {
996 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
997 }
998
999 // Replace call with our new code
1000 CI->replaceAllUsesWith(V);
1001
1002 // Lastly, remember to remove the user.
1003 ToRemoves.push_back(CI);
1004 }
1005 }
1006
1007 Changed = !ToRemoves.empty();
1008
1009 // And cleanup the calls we don't use anymore.
1010 for (auto V : ToRemoves) {
1011 V->eraseFromParent();
1012 }
1013
1014 // And remove the function we don't need either too.
1015 F->eraseFromParent();
1016 }
1017 }
1018
1019 return Changed;
1020}
1021
Kévin Petit6b0a9532018-10-30 20:00:39 +00001022bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1023 bool Changed = false;
1024
1025 const std::map<const char *, const char *> Map = {
1026 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1027 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1028 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1029 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1030 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1031 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1032 };
1033
1034 for (auto Pair : Map) {
1035 // If we find a function with the matching name.
1036 if (auto F = M.getFunction(Pair.first)) {
1037 SmallVector<Instruction *, 4> ToRemoves;
1038
1039 // Walk the users of the function.
1040 for (auto &U : F->uses()) {
1041 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1042
1043 auto ReplacementFn = Pair.second;
1044
1045 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1046 Value *VectorArg;
1047
1048 // First figure out which function we're dealing with
1049 if (F->getName().startswith("_Z10smoothstep")) {
1050 ArgsToSplat.push_back(CI->getOperand(1));
1051 VectorArg = CI->getOperand(2);
1052 } else {
1053 VectorArg = CI->getOperand(1);
1054 }
1055
1056 // Splat arguments that need to be
1057 SmallVector<Value*, 2> SplatArgs;
1058 auto VecType = VectorArg->getType();
1059
1060 for (auto arg : ArgsToSplat) {
1061 Value* NewVectorArg = UndefValue::get(VecType);
1062 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1063 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1064 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1065 }
1066 SplatArgs.push_back(NewVectorArg);
1067 }
1068
1069 // Replace the call with the vector/vector flavour
1070 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1071 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1072
1073 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1074
1075 SmallVector<Value*, 3> NewArgs;
1076 for (auto arg : SplatArgs) {
1077 NewArgs.push_back(arg);
1078 }
1079 NewArgs.push_back(VectorArg);
1080
1081 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1082
1083 CI->replaceAllUsesWith(NewCI);
1084
1085 // Lastly, remember to remove the user.
1086 ToRemoves.push_back(CI);
1087 }
1088 }
1089
1090 Changed = !ToRemoves.empty();
1091
1092 // And cleanup the calls we don't use anymore.
1093 for (auto V : ToRemoves) {
1094 V->eraseFromParent();
1095 }
1096
1097 // And remove the function we don't need either too.
1098 F->eraseFromParent();
1099 }
1100 }
1101
1102 return Changed;
1103}
1104
David Neto22f144c2017-06-12 14:26:21 -04001105bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1106 bool Changed = false;
1107
1108 const std::map<const char *, Instruction::BinaryOps> Map = {
1109 {"_Z7signbitf", Instruction::LShr},
1110 {"_Z7signbitDv2_f", Instruction::AShr},
1111 {"_Z7signbitDv3_f", Instruction::AShr},
1112 {"_Z7signbitDv4_f", Instruction::AShr},
1113 };
1114
1115 for (auto Pair : Map) {
1116 // If we find a function with the matching name.
1117 if (auto F = M.getFunction(Pair.first)) {
1118 SmallVector<Instruction *, 4> ToRemoves;
1119
1120 // Walk the users of the function.
1121 for (auto &U : F->uses()) {
1122 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1123 auto Arg = CI->getOperand(0);
1124
1125 auto Bitcast =
1126 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1127
1128 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1129 ConstantInt::get(CI->getType(), 31),
1130 "", CI);
1131
1132 CI->replaceAllUsesWith(Shr);
1133
1134 // Lastly, remember to remove the user.
1135 ToRemoves.push_back(CI);
1136 }
1137 }
1138
1139 Changed = !ToRemoves.empty();
1140
1141 // And cleanup the calls we don't use anymore.
1142 for (auto V : ToRemoves) {
1143 V->eraseFromParent();
1144 }
1145
1146 // And remove the function we don't need either too.
1147 F->eraseFromParent();
1148 }
1149 }
1150
1151 return Changed;
1152}
1153
1154bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1155 bool Changed = false;
1156
1157 const std::map<const char *,
1158 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1159 Map = {
1160 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1161 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1162 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1163 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1164 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1165 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1166 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1167 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1168 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1169 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1170 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1171 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1172 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1173 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1174 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1175 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1176 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1177 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1178 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1179 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1180 };
1181
1182 for (auto Pair : Map) {
1183 // If we find a function with the matching name.
1184 if (auto F = M.getFunction(Pair.first)) {
1185 SmallVector<Instruction *, 4> ToRemoves;
1186
1187 // Walk the users of the function.
1188 for (auto &U : F->uses()) {
1189 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1190 // The multiply instruction to use.
1191 auto MulInst = Pair.second.first;
1192
1193 // The add instruction to use.
1194 auto AddInst = Pair.second.second;
1195
1196 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1197
1198 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1199 CI->getArgOperand(1), "", CI);
1200
1201 if (Instruction::BinaryOpsEnd != AddInst) {
1202 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1203 CI);
1204 }
1205
1206 CI->replaceAllUsesWith(I);
1207
1208 // Lastly, remember to remove the user.
1209 ToRemoves.push_back(CI);
1210 }
1211 }
1212
1213 Changed = !ToRemoves.empty();
1214
1215 // And cleanup the calls we don't use anymore.
1216 for (auto V : ToRemoves) {
1217 V->eraseFromParent();
1218 }
1219
1220 // And remove the function we don't need either too.
1221 F->eraseFromParent();
1222 }
1223 }
1224
1225 return Changed;
1226}
1227
Derek Chowcfd368b2017-10-19 20:58:45 -07001228bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1229 bool Changed = false;
1230
1231 struct VectorStoreOps {
1232 const char* name;
1233 int n;
1234 Type* (*get_scalar_type_function)(LLVMContext&);
1235 } vector_store_ops[] = {
1236 // TODO(derekjchow): Expand this list.
1237 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1238 };
1239
David Neto544fffc2017-11-16 18:35:14 -05001240 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001241 auto Name = Op.name;
1242 auto N = Op.n;
1243 auto TypeFn = Op.get_scalar_type_function;
1244 if (auto F = M.getFunction(Name)) {
1245 SmallVector<Instruction *, 4> ToRemoves;
1246
1247 // Walk the users of the function.
1248 for (auto &U : F->uses()) {
1249 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1250 // The value argument from vstoren.
1251 auto Arg0 = CI->getOperand(0);
1252
1253 // The index argument from vstoren.
1254 auto Arg1 = CI->getOperand(1);
1255
1256 // The pointer argument from vstoren.
1257 auto Arg2 = CI->getOperand(2);
1258
1259 // Get types.
1260 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1261 auto ScalarNPointerTy = PointerType::get(
1262 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1263
1264 // Cast to scalarn
1265 auto Cast = CastInst::CreatePointerCast(
1266 Arg2, ScalarNPointerTy, "", CI);
1267 // Index to correct address
1268 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1269 // Store
1270 auto Store = new StoreInst(Arg0, Index, CI);
1271
1272 CI->replaceAllUsesWith(Store);
1273 ToRemoves.push_back(CI);
1274 }
1275 }
1276
1277 Changed = !ToRemoves.empty();
1278
1279 // And cleanup the calls we don't use anymore.
1280 for (auto V : ToRemoves) {
1281 V->eraseFromParent();
1282 }
1283
1284 // And remove the function we don't need either too.
1285 F->eraseFromParent();
1286 }
1287 }
1288
1289 return Changed;
1290}
1291
1292bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1293 bool Changed = false;
1294
1295 struct VectorLoadOps {
1296 const char* name;
1297 int n;
1298 Type* (*get_scalar_type_function)(LLVMContext&);
1299 } vector_load_ops[] = {
1300 // TODO(derekjchow): Expand this list.
1301 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1302 };
1303
David Neto544fffc2017-11-16 18:35:14 -05001304 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001305 auto Name = Op.name;
1306 auto N = Op.n;
1307 auto TypeFn = Op.get_scalar_type_function;
1308 // If we find a function with the matching name.
1309 if (auto F = M.getFunction(Name)) {
1310 SmallVector<Instruction *, 4> ToRemoves;
1311
1312 // Walk the users of the function.
1313 for (auto &U : F->uses()) {
1314 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1315 // The index argument from vloadn.
1316 auto Arg0 = CI->getOperand(0);
1317
1318 // The pointer argument from vloadn.
1319 auto Arg1 = CI->getOperand(1);
1320
1321 // Get types.
1322 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1323 auto ScalarNPointerTy = PointerType::get(
1324 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1325
1326 // Cast to scalarn
1327 auto Cast = CastInst::CreatePointerCast(
1328 Arg1, ScalarNPointerTy, "", CI);
1329 // Index to correct address
1330 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1331 // Load
1332 auto Load = new LoadInst(Index, "", CI);
1333
1334 CI->replaceAllUsesWith(Load);
1335 ToRemoves.push_back(CI);
1336 }
1337 }
1338
1339 Changed = !ToRemoves.empty();
1340
1341 // And cleanup the calls we don't use anymore.
1342 for (auto V : ToRemoves) {
1343 V->eraseFromParent();
1344 }
1345
1346 // And remove the function we don't need either too.
1347 F->eraseFromParent();
1348
1349 }
1350 }
1351
1352 return Changed;
1353}
1354
David Neto22f144c2017-06-12 14:26:21 -04001355bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1356 bool Changed = false;
1357
1358 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
1359 "_Z10vload_halfjPU3AS2KDh"};
1360
1361 for (auto Name : Map) {
1362 // If we find a function with the matching name.
1363 if (auto F = M.getFunction(Name)) {
1364 SmallVector<Instruction *, 4> ToRemoves;
1365
1366 // Walk the users of the function.
1367 for (auto &U : F->uses()) {
1368 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1369 // The index argument from vload_half.
1370 auto Arg0 = CI->getOperand(0);
1371
1372 // The pointer argument from vload_half.
1373 auto Arg1 = CI->getOperand(1);
1374
David Neto22f144c2017-06-12 14:26:21 -04001375 auto IntTy = Type::getInt32Ty(M.getContext());
1376 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001377 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1378
David Neto22f144c2017-06-12 14:26:21 -04001379 // Our intrinsic to unpack a float2 from an int.
1380 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1381
1382 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1383
David Neto482550a2018-03-24 05:21:07 -07001384 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04001385 auto ShortTy = Type::getInt16Ty(M.getContext());
1386 auto ShortPointerTy = PointerType::get(
1387 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001388
David Netoac825b82017-05-30 12:49:01 -04001389 // Cast the half* pointer to short*.
1390 auto Cast =
1391 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001392
David Netoac825b82017-05-30 12:49:01 -04001393 // Index into the correct address of the casted pointer.
1394 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1395
1396 // Load from the short* we casted to.
1397 auto Load = new LoadInst(Index, "", CI);
1398
1399 // ZExt the short -> int.
1400 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1401
1402 // Get our float2.
1403 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1404
1405 // Extract out the bottom element which is our float result.
1406 auto Extract = ExtractElementInst::Create(
1407 Call, ConstantInt::get(IntTy, 0), "", CI);
1408
1409 CI->replaceAllUsesWith(Extract);
1410 } else {
1411 // Assume the pointer argument points to storage aligned to 32bits
1412 // or more.
1413 // TODO(dneto): Do more analysis to make sure this is true?
1414 //
1415 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1416 // with:
1417 //
1418 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1419 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1420 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1421 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1422 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1423 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1424 // x float> %converted, %index_is_odd32
1425
1426 auto IntPointerTy = PointerType::get(
1427 IntTy, Arg1->getType()->getPointerAddressSpace());
1428
David Neto973e6a82017-05-30 13:48:18 -04001429 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04001430 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04001431 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04001432 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1433
1434 auto One = ConstantInt::get(IntTy, 1);
1435 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1436 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1437
1438 // Index into the correct address of the casted pointer.
1439 auto Ptr =
1440 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1441
1442 // Load from the int* we casted to.
1443 auto Load = new LoadInst(Ptr, "", CI);
1444
1445 // Get our float2.
1446 auto Call = CallInst::Create(NewF, Load, "", CI);
1447
1448 // Extract out the float result, where the element number is
1449 // determined by whether the original index was even or odd.
1450 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1451
1452 CI->replaceAllUsesWith(Extract);
1453 }
David Neto22f144c2017-06-12 14:26:21 -04001454
1455 // Lastly, remember to remove the user.
1456 ToRemoves.push_back(CI);
1457 }
1458 }
1459
1460 Changed = !ToRemoves.empty();
1461
1462 // And cleanup the calls we don't use anymore.
1463 for (auto V : ToRemoves) {
1464 V->eraseFromParent();
1465 }
1466
1467 // And remove the function we don't need either too.
1468 F->eraseFromParent();
1469 }
1470 }
1471
1472 return Changed;
1473}
1474
1475bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
1476 bool Changed = false;
1477
David Neto556c7e62018-06-08 13:45:55 -07001478 const std::vector<const char *> Map = {
1479 "_Z11vload_half2jPU3AS1KDh",
1480 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
1481 "_Z11vload_half2jPU3AS2KDh",
1482 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
1483 };
David Neto22f144c2017-06-12 14:26:21 -04001484
1485 for (auto Name : Map) {
1486 // If we find a function with the matching name.
1487 if (auto F = M.getFunction(Name)) {
1488 SmallVector<Instruction *, 4> ToRemoves;
1489
1490 // Walk the users of the function.
1491 for (auto &U : F->uses()) {
1492 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1493 // The index argument from vload_half.
1494 auto Arg0 = CI->getOperand(0);
1495
1496 // The pointer argument from vload_half.
1497 auto Arg1 = CI->getOperand(1);
1498
1499 auto IntTy = Type::getInt32Ty(M.getContext());
1500 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1501 auto NewPointerTy = PointerType::get(
1502 IntTy, Arg1->getType()->getPointerAddressSpace());
1503 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1504
1505 // Cast the half* pointer to int*.
1506 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1507
1508 // Index into the correct address of the casted pointer.
1509 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
1510
1511 // Load from the int* we casted to.
1512 auto Load = new LoadInst(Index, "", CI);
1513
1514 // Our intrinsic to unpack a float2 from an int.
1515 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1516
1517 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1518
1519 // Get our float2.
1520 auto Call = CallInst::Create(NewF, Load, "", CI);
1521
1522 CI->replaceAllUsesWith(Call);
1523
1524 // Lastly, remember to remove the user.
1525 ToRemoves.push_back(CI);
1526 }
1527 }
1528
1529 Changed = !ToRemoves.empty();
1530
1531 // And cleanup the calls we don't use anymore.
1532 for (auto V : ToRemoves) {
1533 V->eraseFromParent();
1534 }
1535
1536 // And remove the function we don't need either too.
1537 F->eraseFromParent();
1538 }
1539 }
1540
1541 return Changed;
1542}
1543
1544bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
1545 bool Changed = false;
1546
David Neto556c7e62018-06-08 13:45:55 -07001547 const std::vector<const char *> Map = {
1548 "_Z11vload_half4jPU3AS1KDh",
1549 "_Z12vloada_half4jPU3AS1KDh",
1550 "_Z11vload_half4jPU3AS2KDh",
1551 "_Z12vloada_half4jPU3AS2KDh",
1552 };
David Neto22f144c2017-06-12 14:26:21 -04001553
1554 for (auto Name : Map) {
1555 // If we find a function with the matching name.
1556 if (auto F = M.getFunction(Name)) {
1557 SmallVector<Instruction *, 4> ToRemoves;
1558
1559 // Walk the users of the function.
1560 for (auto &U : F->uses()) {
1561 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1562 // The index argument from vload_half.
1563 auto Arg0 = CI->getOperand(0);
1564
1565 // The pointer argument from vload_half.
1566 auto Arg1 = CI->getOperand(1);
1567
1568 auto IntTy = Type::getInt32Ty(M.getContext());
1569 auto Int2Ty = VectorType::get(IntTy, 2);
1570 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1571 auto NewPointerTy = PointerType::get(
1572 Int2Ty, Arg1->getType()->getPointerAddressSpace());
1573 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1574
1575 // Cast the half* pointer to int2*.
1576 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1577
1578 // Index into the correct address of the casted pointer.
1579 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
1580
1581 // Load from the int2* we casted to.
1582 auto Load = new LoadInst(Index, "", CI);
1583
1584 // Extract each element from the loaded int2.
1585 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1586 "", CI);
1587 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1588 "", CI);
1589
1590 // Our intrinsic to unpack a float2 from an int.
1591 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1592
1593 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1594
1595 // Get the lower (x & y) components of our final float4.
1596 auto Lo = CallInst::Create(NewF, X, "", CI);
1597
1598 // Get the higher (z & w) components of our final float4.
1599 auto Hi = CallInst::Create(NewF, Y, "", CI);
1600
1601 Constant *ShuffleMask[4] = {
1602 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1603 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1604
1605 // Combine our two float2's into one float4.
1606 auto Combine = new ShuffleVectorInst(
1607 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1608
1609 CI->replaceAllUsesWith(Combine);
1610
1611 // Lastly, remember to remove the user.
1612 ToRemoves.push_back(CI);
1613 }
1614 }
1615
1616 Changed = !ToRemoves.empty();
1617
1618 // And cleanup the calls we don't use anymore.
1619 for (auto V : ToRemoves) {
1620 V->eraseFromParent();
1621 }
1622
1623 // And remove the function we don't need either too.
1624 F->eraseFromParent();
1625 }
1626 }
1627
1628 return Changed;
1629}
1630
David Neto6ad93232018-06-07 15:42:58 -07001631bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
1632 bool Changed = false;
1633
1634 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1635 //
1636 // %u = load i32 %ptr
1637 // %fxy = call <2 x float> Unpack2xHalf(u)
1638 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1639 const std::vector<const char *> Map = {
1640 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
1641 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
1642 "_Z20__clspv_vloada_half2jPKj", // private
1643 };
1644
1645 for (auto Name : Map) {
1646 // If we find a function with the matching name.
1647 if (auto F = M.getFunction(Name)) {
1648 SmallVector<Instruction *, 4> ToRemoves;
1649
1650 // Walk the users of the function.
1651 for (auto &U : F->uses()) {
1652 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
1653 auto Index = CI->getOperand(0);
1654 auto Ptr = CI->getOperand(1);
1655
1656 auto IntTy = Type::getInt32Ty(M.getContext());
1657 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1658 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1659
1660 auto IndexedPtr =
1661 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
1662 auto Load = new LoadInst(IndexedPtr, "", CI);
1663
1664 // Our intrinsic to unpack a float2 from an int.
1665 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1666
1667 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1668
1669 // Get our final float2.
1670 auto Result = CallInst::Create(NewF, Load, "", CI);
1671
1672 CI->replaceAllUsesWith(Result);
1673
1674 // Lastly, remember to remove the user.
1675 ToRemoves.push_back(CI);
1676 }
1677 }
1678
1679 Changed = true;
1680
1681 // And cleanup the calls we don't use anymore.
1682 for (auto V : ToRemoves) {
1683 V->eraseFromParent();
1684 }
1685
1686 // And remove the function we don't need either too.
1687 F->eraseFromParent();
1688 }
1689 }
1690
1691 return Changed;
1692}
1693
1694bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
1695 bool Changed = false;
1696
1697 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1698 //
1699 // %u2 = load <2 x i32> %ptr
1700 // %u2xy = extractelement %u2, 0
1701 // %u2zw = extractelement %u2, 1
1702 // %fxy = call <2 x float> Unpack2xHalf(uint)
1703 // %fzw = call <2 x float> Unpack2xHalf(uint)
1704 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1705 const std::vector<const char *> Map = {
1706 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
1707 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
1708 "_Z20__clspv_vloada_half4jPKDv2_j", // private
1709 };
1710
1711 for (auto Name : Map) {
1712 // If we find a function with the matching name.
1713 if (auto F = M.getFunction(Name)) {
1714 SmallVector<Instruction *, 4> ToRemoves;
1715
1716 // Walk the users of the function.
1717 for (auto &U : F->uses()) {
1718 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1719 auto Index = CI->getOperand(0);
1720 auto Ptr = CI->getOperand(1);
1721
1722 auto IntTy = Type::getInt32Ty(M.getContext());
1723 auto Int2Ty = VectorType::get(IntTy, 2);
1724 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1725 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1726
1727 auto IndexedPtr =
1728 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
1729 auto Load = new LoadInst(IndexedPtr, "", CI);
1730
1731 // Extract each element from the loaded int2.
1732 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1733 "", CI);
1734 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1735 "", CI);
1736
1737 // Our intrinsic to unpack a float2 from an int.
1738 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1739
1740 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1741
1742 // Get the lower (x & y) components of our final float4.
1743 auto Lo = CallInst::Create(NewF, X, "", CI);
1744
1745 // Get the higher (z & w) components of our final float4.
1746 auto Hi = CallInst::Create(NewF, Y, "", CI);
1747
1748 Constant *ShuffleMask[4] = {
1749 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1750 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1751
1752 // Combine our two float2's into one float4.
1753 auto Combine = new ShuffleVectorInst(
1754 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1755
1756 CI->replaceAllUsesWith(Combine);
1757
1758 // Lastly, remember to remove the user.
1759 ToRemoves.push_back(CI);
1760 }
1761 }
1762
1763 Changed = true;
1764
1765 // And cleanup the calls we don't use anymore.
1766 for (auto V : ToRemoves) {
1767 V->eraseFromParent();
1768 }
1769
1770 // And remove the function we don't need either too.
1771 F->eraseFromParent();
1772 }
1773 }
1774
1775 return Changed;
1776}
1777
David Neto22f144c2017-06-12 14:26:21 -04001778bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
1779 bool Changed = false;
1780
1781 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
1782 "_Z15vstore_half_rtefjPU3AS1Dh",
1783 "_Z15vstore_half_rtzfjPU3AS1Dh"};
1784
1785 for (auto Name : Map) {
1786 // If we find a function with the matching name.
1787 if (auto F = M.getFunction(Name)) {
1788 SmallVector<Instruction *, 4> ToRemoves;
1789
1790 // Walk the users of the function.
1791 for (auto &U : F->uses()) {
1792 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1793 // The value to store.
1794 auto Arg0 = CI->getOperand(0);
1795
1796 // The index argument from vstore_half.
1797 auto Arg1 = CI->getOperand(1);
1798
1799 // The pointer argument from vstore_half.
1800 auto Arg2 = CI->getOperand(2);
1801
David Neto22f144c2017-06-12 14:26:21 -04001802 auto IntTy = Type::getInt32Ty(M.getContext());
1803 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001804 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04001805 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04001806
1807 // Our intrinsic to pack a float2 to an int.
1808 auto SPIRVIntrinsic = "spirv.pack.v2f16";
1809
1810 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1811
1812 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04001813 auto TempVec =
1814 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
1815 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001816
1817 // Pack the float2 -> half2 (in an int).
1818 auto X = CallInst::Create(NewF, TempVec, "", CI);
1819
David Neto482550a2018-03-24 05:21:07 -07001820 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04001821 auto ShortTy = Type::getInt16Ty(M.getContext());
1822 auto ShortPointerTy = PointerType::get(
1823 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001824
David Neto17852de2017-05-29 17:29:31 -04001825 // Truncate our i32 to an i16.
1826 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001827
David Neto17852de2017-05-29 17:29:31 -04001828 // Cast the half* pointer to short*.
1829 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001830
David Neto17852de2017-05-29 17:29:31 -04001831 // Index into the correct address of the casted pointer.
1832 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001833
David Neto17852de2017-05-29 17:29:31 -04001834 // Store to the int* we casted to.
1835 auto Store = new StoreInst(Trunc, Index, CI);
1836
1837 CI->replaceAllUsesWith(Store);
1838 } else {
1839 // We can only write to 32-bit aligned words.
1840 //
1841 // Assuming base is aligned to 32-bits, replace the equivalent of
1842 // vstore_half(value, index, base)
1843 // with:
1844 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
1845 // uint32_t write_to_upper_half = index & 1u;
1846 // uint32_t shift = write_to_upper_half << 4;
1847 //
1848 // // Pack the float value as a half number in bottom 16 bits
1849 // // of an i32.
1850 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
1851 //
1852 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
1853 // ^ ((packed & 0xffff) << shift)
1854 // // We only need relaxed consistency, but OpenCL 1.2 only has
1855 // // sequentially consistent atomics.
1856 // // TODO(dneto): Use relaxed consistency.
1857 // atomic_xor(target_ptr, xor_value)
1858 auto IntPointerTy = PointerType::get(
1859 IntTy, Arg2->getType()->getPointerAddressSpace());
1860
1861 auto Four = ConstantInt::get(IntTy, 4);
1862 auto FFFF = ConstantInt::get(IntTy, 0xffff);
1863
1864 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
1865 // Compute index / 2
1866 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
1867 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
1868 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
1869 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
1870 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
1871 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
1872 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
1873
1874 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
1875 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
1876 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
1877
1878 // Generate the call to atomi_xor.
1879 SmallVector<Type *, 5> ParamTypes;
1880 // The pointer type.
1881 ParamTypes.push_back(IntPointerTy);
1882 // The Types for memory scope, semantics, and value.
1883 ParamTypes.push_back(IntTy);
1884 ParamTypes.push_back(IntTy);
1885 ParamTypes.push_back(IntTy);
1886 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
1887 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
1888
1889 const auto ConstantScopeDevice =
1890 ConstantInt::get(IntTy, spv::ScopeDevice);
1891 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
1892 // (SPIR-V Workgroup).
1893 const auto AddrSpaceSemanticsBits =
1894 IntPointerTy->getPointerAddressSpace() == 1
1895 ? spv::MemorySemanticsUniformMemoryMask
1896 : spv::MemorySemanticsWorkgroupMemoryMask;
1897
1898 // We're using relaxed consistency here.
1899 const auto ConstantMemorySemantics =
1900 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
1901 AddrSpaceSemanticsBits);
1902
1903 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
1904 ConstantMemorySemantics, ValueToXor};
1905 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
1906 }
David Neto22f144c2017-06-12 14:26:21 -04001907
1908 // Lastly, remember to remove the user.
1909 ToRemoves.push_back(CI);
1910 }
1911 }
1912
1913 Changed = !ToRemoves.empty();
1914
1915 // And cleanup the calls we don't use anymore.
1916 for (auto V : ToRemoves) {
1917 V->eraseFromParent();
1918 }
1919
1920 // And remove the function we don't need either too.
1921 F->eraseFromParent();
1922 }
1923 }
1924
1925 return Changed;
1926}
1927
1928bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
1929 bool Changed = false;
1930
David Netoe2871522018-06-08 11:09:54 -07001931 const std::vector<const char *> Map = {
1932 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
1933 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
1934 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
1935 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
1936 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
1937 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
1938 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
1939 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
1940 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
1941 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
1942 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
1943 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
1944 };
David Neto22f144c2017-06-12 14:26:21 -04001945
1946 for (auto Name : Map) {
1947 // If we find a function with the matching name.
1948 if (auto F = M.getFunction(Name)) {
1949 SmallVector<Instruction *, 4> ToRemoves;
1950
1951 // Walk the users of the function.
1952 for (auto &U : F->uses()) {
1953 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1954 // The value to store.
1955 auto Arg0 = CI->getOperand(0);
1956
1957 // The index argument from vstore_half.
1958 auto Arg1 = CI->getOperand(1);
1959
1960 // The pointer argument from vstore_half.
1961 auto Arg2 = CI->getOperand(2);
1962
1963 auto IntTy = Type::getInt32Ty(M.getContext());
1964 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1965 auto NewPointerTy = PointerType::get(
1966 IntTy, Arg2->getType()->getPointerAddressSpace());
1967 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
1968
1969 // Our intrinsic to pack a float2 to an int.
1970 auto SPIRVIntrinsic = "spirv.pack.v2f16";
1971
1972 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1973
1974 // Turn the packed x & y into the final packing.
1975 auto X = CallInst::Create(NewF, Arg0, "", CI);
1976
1977 // Cast the half* pointer to int*.
1978 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
1979
1980 // Index into the correct address of the casted pointer.
1981 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
1982
1983 // Store to the int* we casted to.
1984 auto Store = new StoreInst(X, Index, CI);
1985
1986 CI->replaceAllUsesWith(Store);
1987
1988 // Lastly, remember to remove the user.
1989 ToRemoves.push_back(CI);
1990 }
1991 }
1992
1993 Changed = !ToRemoves.empty();
1994
1995 // And cleanup the calls we don't use anymore.
1996 for (auto V : ToRemoves) {
1997 V->eraseFromParent();
1998 }
1999
2000 // And remove the function we don't need either too.
2001 F->eraseFromParent();
2002 }
2003 }
2004
2005 return Changed;
2006}
2007
2008bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2009 bool Changed = false;
2010
David Netoe2871522018-06-08 11:09:54 -07002011 const std::vector<const char *> Map = {
2012 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2013 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2014 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2015 "_Z13vstorea_half4Dv4_fjPDh", // private
2016 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2017 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2018 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2019 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2020 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2021 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2022 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2023 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2024 };
David Neto22f144c2017-06-12 14:26:21 -04002025
2026 for (auto Name : Map) {
2027 // If we find a function with the matching name.
2028 if (auto F = M.getFunction(Name)) {
2029 SmallVector<Instruction *, 4> ToRemoves;
2030
2031 // Walk the users of the function.
2032 for (auto &U : F->uses()) {
2033 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2034 // The value to store.
2035 auto Arg0 = CI->getOperand(0);
2036
2037 // The index argument from vstore_half.
2038 auto Arg1 = CI->getOperand(1);
2039
2040 // The pointer argument from vstore_half.
2041 auto Arg2 = CI->getOperand(2);
2042
2043 auto IntTy = Type::getInt32Ty(M.getContext());
2044 auto Int2Ty = VectorType::get(IntTy, 2);
2045 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2046 auto NewPointerTy = PointerType::get(
2047 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2048 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2049
2050 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2051 ConstantInt::get(IntTy, 1)};
2052
2053 // Extract out the x & y components of our to store value.
2054 auto Lo =
2055 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2056 ConstantVector::get(LoShuffleMask), "", CI);
2057
2058 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2059 ConstantInt::get(IntTy, 3)};
2060
2061 // Extract out the z & w components of our to store value.
2062 auto Hi =
2063 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2064 ConstantVector::get(HiShuffleMask), "", CI);
2065
2066 // Our intrinsic to pack a float2 to an int.
2067 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2068
2069 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2070
2071 // Turn the packed x & y into the final component of our int2.
2072 auto X = CallInst::Create(NewF, Lo, "", CI);
2073
2074 // Turn the packed z & w into the final component of our int2.
2075 auto Y = CallInst::Create(NewF, Hi, "", CI);
2076
2077 auto Combine = InsertElementInst::Create(
2078 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2079 Combine = InsertElementInst::Create(
2080 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2081
2082 // Cast the half* pointer to int2*.
2083 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2084
2085 // Index into the correct address of the casted pointer.
2086 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2087
2088 // Store to the int2* we casted to.
2089 auto Store = new StoreInst(Combine, Index, CI);
2090
2091 CI->replaceAllUsesWith(Store);
2092
2093 // Lastly, remember to remove the user.
2094 ToRemoves.push_back(CI);
2095 }
2096 }
2097
2098 Changed = !ToRemoves.empty();
2099
2100 // And cleanup the calls we don't use anymore.
2101 for (auto V : ToRemoves) {
2102 V->eraseFromParent();
2103 }
2104
2105 // And remove the function we don't need either too.
2106 F->eraseFromParent();
2107 }
2108 }
2109
2110 return Changed;
2111}
2112
2113bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2114 bool Changed = false;
2115
2116 const std::map<const char *, const char*> Map = {
2117 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2118 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2119 };
2120
2121 for (auto Pair : Map) {
2122 // If we find a function with the matching name.
2123 if (auto F = M.getFunction(Pair.first)) {
2124 SmallVector<Instruction *, 4> ToRemoves;
2125
2126 // Walk the users of the function.
2127 for (auto &U : F->uses()) {
2128 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2129 // The image.
2130 auto Arg0 = CI->getOperand(0);
2131
2132 // The sampler.
2133 auto Arg1 = CI->getOperand(1);
2134
2135 // The coordinate (integer type that we can't handle).
2136 auto Arg2 = CI->getOperand(2);
2137
2138 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2139
2140 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2141
2142 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2143
2144 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2145
2146 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2147
2148 CI->replaceAllUsesWith(NewCI);
2149
2150 // Lastly, remember to remove the user.
2151 ToRemoves.push_back(CI);
2152 }
2153 }
2154
2155 Changed = !ToRemoves.empty();
2156
2157 // And cleanup the calls we don't use anymore.
2158 for (auto V : ToRemoves) {
2159 V->eraseFromParent();
2160 }
2161
2162 // And remove the function we don't need either too.
2163 F->eraseFromParent();
2164 }
2165 }
2166
2167 return Changed;
2168}
2169
2170bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2171 bool Changed = false;
2172
2173 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002174 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2175 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2176 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2177 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2178 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2179 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002180 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2181 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2182 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2183 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2184 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002185 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002186
2187 for (auto Pair : Map) {
2188 // If we find a function with the matching name.
2189 if (auto F = M.getFunction(Pair.first)) {
2190 SmallVector<Instruction *, 4> ToRemoves;
2191
2192 // Walk the users of the function.
2193 for (auto &U : F->uses()) {
2194 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2195 auto FType = F->getFunctionType();
2196 SmallVector<Type *, 5> ParamTypes;
2197
2198 // The pointer type.
2199 ParamTypes.push_back(FType->getParamType(0));
2200
2201 auto IntTy = Type::getInt32Ty(M.getContext());
2202
2203 // The memory scope type.
2204 ParamTypes.push_back(IntTy);
2205
2206 // The memory semantics type.
2207 ParamTypes.push_back(IntTy);
2208
2209 if (2 < CI->getNumArgOperands()) {
2210 // The unequal memory semantics type.
2211 ParamTypes.push_back(IntTy);
2212
2213 // The value type.
2214 ParamTypes.push_back(FType->getParamType(2));
2215
2216 // The comparator type.
2217 ParamTypes.push_back(FType->getParamType(1));
2218 } else if (1 < CI->getNumArgOperands()) {
2219 // The value type.
2220 ParamTypes.push_back(FType->getParamType(1));
2221 }
2222
2223 auto NewFType =
2224 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2225 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2226
2227 // We need to map the OpenCL constants to the SPIR-V equivalents.
2228 const auto ConstantScopeDevice =
2229 ConstantInt::get(IntTy, spv::ScopeDevice);
2230 const auto ConstantMemorySemantics = ConstantInt::get(
2231 IntTy, spv::MemorySemanticsUniformMemoryMask |
2232 spv::MemorySemanticsSequentiallyConsistentMask);
2233
2234 SmallVector<Value *, 5> Params;
2235
2236 // The pointer.
2237 Params.push_back(CI->getArgOperand(0));
2238
2239 // The memory scope.
2240 Params.push_back(ConstantScopeDevice);
2241
2242 // The memory semantics.
2243 Params.push_back(ConstantMemorySemantics);
2244
2245 if (2 < CI->getNumArgOperands()) {
2246 // The unequal memory semantics.
2247 Params.push_back(ConstantMemorySemantics);
2248
2249 // The value.
2250 Params.push_back(CI->getArgOperand(2));
2251
2252 // The comparator.
2253 Params.push_back(CI->getArgOperand(1));
2254 } else if (1 < CI->getNumArgOperands()) {
2255 // The value.
2256 Params.push_back(CI->getArgOperand(1));
2257 }
2258
2259 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2260
2261 CI->replaceAllUsesWith(NewCI);
2262
2263 // Lastly, remember to remove the user.
2264 ToRemoves.push_back(CI);
2265 }
2266 }
2267
2268 Changed = !ToRemoves.empty();
2269
2270 // And cleanup the calls we don't use anymore.
2271 for (auto V : ToRemoves) {
2272 V->eraseFromParent();
2273 }
2274
2275 // And remove the function we don't need either too.
2276 F->eraseFromParent();
2277 }
2278 }
2279
Neil Henning39672102017-09-29 14:33:13 +01002280 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002281 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2282 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2283 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2284 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2285 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2286 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2287 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2288 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2289 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2290 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2291 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2292 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2293 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2294 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2295 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2296 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002297 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2298 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2299 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2300 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2301 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2302 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2303 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2304 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2305 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2306 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2307 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2308 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2309 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2310 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2311 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2312 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2313
2314 for (auto Pair : Map2) {
2315 // If we find a function with the matching name.
2316 if (auto F = M.getFunction(Pair.first)) {
2317 SmallVector<Instruction *, 4> ToRemoves;
2318
2319 // Walk the users of the function.
2320 for (auto &U : F->uses()) {
2321 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2322 auto AtomicOp = new AtomicRMWInst(
2323 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2324 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2325
2326 CI->replaceAllUsesWith(AtomicOp);
2327
2328 // Lastly, remember to remove the user.
2329 ToRemoves.push_back(CI);
2330 }
2331 }
2332
2333 Changed = !ToRemoves.empty();
2334
2335 // And cleanup the calls we don't use anymore.
2336 for (auto V : ToRemoves) {
2337 V->eraseFromParent();
2338 }
2339
2340 // And remove the function we don't need either too.
2341 F->eraseFromParent();
2342 }
2343 }
2344
David Neto22f144c2017-06-12 14:26:21 -04002345 return Changed;
2346}
2347
2348bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2349 bool Changed = false;
2350
2351 // If we find a function with the matching name.
2352 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2353 SmallVector<Instruction *, 4> ToRemoves;
2354
2355 auto IntTy = Type::getInt32Ty(M.getContext());
2356 auto FloatTy = Type::getFloatTy(M.getContext());
2357
2358 Constant *DownShuffleMask[3] = {
2359 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2360 ConstantInt::get(IntTy, 2)};
2361
2362 Constant *UpShuffleMask[4] = {
2363 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2364 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2365
2366 Constant *FloatVec[3] = {
2367 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
2368 };
2369
2370 // Walk the users of the function.
2371 for (auto &U : F->uses()) {
2372 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2373 auto Vec4Ty = CI->getArgOperand(0)->getType();
2374 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2375 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2376 auto Vec3Ty = Arg0->getType();
2377
2378 auto NewFType =
2379 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
2380
2381 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
2382
2383 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
2384
2385 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
2386
2387 CI->replaceAllUsesWith(Result);
2388
2389 // Lastly, remember to remove the user.
2390 ToRemoves.push_back(CI);
2391 }
2392 }
2393
2394 Changed = !ToRemoves.empty();
2395
2396 // And cleanup the calls we don't use anymore.
2397 for (auto V : ToRemoves) {
2398 V->eraseFromParent();
2399 }
2400
2401 // And remove the function we don't need either too.
2402 F->eraseFromParent();
2403 }
2404
2405 return Changed;
2406}
David Neto62653202017-10-16 19:05:18 -04002407
2408bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2409 bool Changed = false;
2410
2411 // OpenCL's float result = fract(float x, float* ptr)
2412 //
2413 // In the LLVM domain:
2414 //
2415 // %floor_result = call spir_func float @floor(float %x)
2416 // store float %floor_result, float * %ptr
2417 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2418 // %result = call spir_func float
2419 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2420 //
2421 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2422 // and clspv.fract occur in the SPIR-V generator pass:
2423 //
2424 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2425 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2426 // ...
2427 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2428 // OpStore %ptr %floor_result
2429 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2430 // %fract_result = OpExtInst %float
2431 // %glsl_ext Fmin %fract_intermediate %just_under_1
2432
2433
2434 using std::string;
2435
2436 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2437 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
2438 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
2439 auto make_quad = [](const char *a, const char *b, const char *c,
2440 const char *d) {
2441 return std::tuple<const char *, const char *, const char *, const char *>(
2442 a, b, c, d);
2443 };
2444 const std::vector<QuadType> Functions = {
2445 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
2446 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
2447 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
2448 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
2449 };
2450
2451 for (auto& quad : Functions) {
2452 const StringRef fract_name(std::get<0>(quad));
2453
2454 // If we find a function with the matching name.
2455 if (auto F = M.getFunction(fract_name)) {
2456 if (F->use_begin() == F->use_end())
2457 continue;
2458
2459 // We have some uses.
2460 Changed = true;
2461
2462 auto& Context = M.getContext();
2463
2464 const StringRef floor_name(std::get<1>(quad));
2465 const StringRef fmin_name(std::get<2>(quad));
2466 const StringRef clspv_fract_name(std::get<3>(quad));
2467
2468 // This is either float or a float vector. All the float-like
2469 // types are this type.
2470 auto result_ty = F->getReturnType();
2471
2472 Function* fmin_fn = M.getFunction(fmin_name);
2473 if (!fmin_fn) {
2474 // Make the fmin function.
2475 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
2476 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002477 fmin_fn->addFnAttr(Attribute::ReadNone);
2478 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2479 }
2480
2481 Function* floor_fn = M.getFunction(floor_name);
2482 if (!floor_fn) {
2483 // Make the floor function.
2484 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2485 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002486 floor_fn->addFnAttr(Attribute::ReadNone);
2487 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2488 }
2489
2490 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
2491 if (!clspv_fract_fn) {
2492 // Make the clspv_fract function.
2493 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2494 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002495 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2496 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2497 }
2498
2499 // Number of significant significand bits, whether represented or not.
2500 unsigned num_significand_bits;
2501 switch (result_ty->getScalarType()->getTypeID()) {
2502 case Type::HalfTyID:
2503 num_significand_bits = 11;
2504 break;
2505 case Type::FloatTyID:
2506 num_significand_bits = 24;
2507 break;
2508 case Type::DoubleTyID:
2509 num_significand_bits = 53;
2510 break;
2511 default:
2512 assert(false && "Unhandled float type when processing fract builtin");
2513 break;
2514 }
2515 // Beware that the disassembler displays this value as
2516 // OpConstant %float 1
2517 // which is not quite right.
2518 const double kJustUnderOneScalar =
2519 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2520
2521 Constant *just_under_one =
2522 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2523 if (result_ty->isVectorTy()) {
2524 just_under_one = ConstantVector::getSplat(
2525 result_ty->getVectorNumElements(), just_under_one);
2526 }
2527
2528 IRBuilder<> Builder(Context);
2529
2530 SmallVector<Instruction *, 4> ToRemoves;
2531
2532 // Walk the users of the function.
2533 for (auto &U : F->uses()) {
2534 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2535
2536 Builder.SetInsertPoint(CI);
2537 auto arg = CI->getArgOperand(0);
2538 auto ptr = CI->getArgOperand(1);
2539
2540 // Compute floor result and store it.
2541 auto floor = Builder.CreateCall(floor_fn, {arg});
2542 Builder.CreateStore(floor, ptr);
2543
2544 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2545 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2546
2547 CI->replaceAllUsesWith(fract_result);
2548
2549 // Lastly, remember to remove the user.
2550 ToRemoves.push_back(CI);
2551 }
2552 }
2553
2554 // And cleanup the calls we don't use anymore.
2555 for (auto V : ToRemoves) {
2556 V->eraseFromParent();
2557 }
2558
2559 // And remove the function we don't need either too.
2560 F->eraseFromParent();
2561 }
2562 }
2563
2564 return Changed;
2565}