blob: 50f7d2956a4dd32de1bba384711fe5834dd3622b [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
David Neto118188e2018-08-24 11:27:54 -040019#include "llvm/IR/Constants.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000023#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/Pass.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/raw_ostream.h"
27#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040028
David Neto118188e2018-08-24 11:27:54 -040029#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040030
David Neto482550a2018-03-24 05:21:07 -070031#include "clspv/Option.h"
32
David Neto22f144c2017-06-12 14:26:21 -040033using namespace llvm;
34
35#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
36
37namespace {
// Returns the zero-based index of the most significant set bit of `v`, i.e.
// floor(log2(v)). Returns 0 when `v` is 0 or 1.
//
// NOTE: despite its name this is NOT a count-leading-zeros. The barrier and
// mem_fence lowerings use `clz(A) - clz(B)` as the left-shift amount that
// moves single-bit mask B onto single-bit mask A, which relies on exactly
// this log2 behavior.
uint32_t clz(uint32_t v) {
  uint32_t HighestBit = 0;
  // Every halving moves the highest set bit down by one position; stop once
  // nothing above bit 0 remains.
  while (v >>= 1) {
    ++HighestBit;
  }
  return HighestBit;
}
57
58Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
59 if (1 == elements) {
60 return Type::getInt1Ty(C);
61 } else {
62 return VectorType::get(Type::getInt1Ty(C), elements);
63 }
64}
65
66struct ReplaceOpenCLBuiltinPass final : public ModulePass {
67 static char ID;
68 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
69
70 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +000071 bool replaceAbs(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040072 bool replaceRecip(Module &M);
73 bool replaceDivide(Module &M);
74 bool replaceExp10(Module &M);
75 bool replaceLog10(Module &M);
76 bool replaceBarrier(Module &M);
77 bool replaceMemFence(Module &M);
78 bool replaceRelational(Module &M);
79 bool replaceIsInfAndIsNan(Module &M);
80 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +000081 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +000082 bool replaceRotate(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +000083 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +000084 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +000085 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040086 bool replaceSignbit(Module &M);
87 bool replaceMadandMad24andMul24(Module &M);
88 bool replaceVloadHalf(Module &M);
89 bool replaceVloadHalf2(Module &M);
90 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -070091 bool replaceClspvVloadaHalf2(Module &M);
92 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040093 bool replaceVstoreHalf(Module &M);
94 bool replaceVstoreHalf2(Module &M);
95 bool replaceVstoreHalf4(Module &M);
96 bool replaceReadImageF(Module &M);
97 bool replaceAtomics(Module &M);
98 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -040099 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700100 bool replaceVload(Module &M);
101 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400102};
103}
104
105char ReplaceOpenCLBuiltinPass::ID = 0;
106static RegisterPass<ReplaceOpenCLBuiltinPass> X("ReplaceOpenCLBuiltin",
107 "Replace OpenCL Builtins Pass");
108
109namespace clspv {
110ModulePass *createReplaceOpenCLBuiltinPass() {
111 return new ReplaceOpenCLBuiltinPass();
112}
113}
114
115bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
116 bool Changed = false;
117
Kévin Petit2444e9b2018-11-09 14:14:37 +0000118 Changed |= replaceAbs(M);
David Neto22f144c2017-06-12 14:26:21 -0400119 Changed |= replaceRecip(M);
120 Changed |= replaceDivide(M);
121 Changed |= replaceExp10(M);
122 Changed |= replaceLog10(M);
123 Changed |= replaceBarrier(M);
124 Changed |= replaceMemFence(M);
125 Changed |= replaceRelational(M);
126 Changed |= replaceIsInfAndIsNan(M);
127 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000128 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000129 Changed |= replaceRotate(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000130 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000131 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000132 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400133 Changed |= replaceSignbit(M);
134 Changed |= replaceMadandMad24andMul24(M);
135 Changed |= replaceVloadHalf(M);
136 Changed |= replaceVloadHalf2(M);
137 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700138 Changed |= replaceClspvVloadaHalf2(M);
139 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400140 Changed |= replaceVstoreHalf(M);
141 Changed |= replaceVstoreHalf2(M);
142 Changed |= replaceVstoreHalf4(M);
143 Changed |= replaceReadImageF(M);
144 Changed |= replaceAtomics(M);
145 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400146 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700147 Changed |= replaceVload(M);
148 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400149
150 return Changed;
151}
152
Kévin Petit2444e9b2018-11-09 14:14:37 +0000153bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
154 bool Changed = false;
155
156 const char *Names[] = {
157 "_Z3abst",
158 "_Z3absDv2_t",
159 "_Z3absDv3_t",
160 "_Z3absDv4_t",
161 "_Z3absj",
162 "_Z3absDv2_j",
163 "_Z3absDv3_j",
164 "_Z3absDv4_j",
165 "_Z3absm",
166 "_Z3absDv2_m",
167 "_Z3absDv3_m",
168 "_Z3absDv4_m",
169 };
170
171 for (auto Name : Names) {
172 // If we find a function with the matching name.
173 if (auto F = M.getFunction(Name)) {
174 SmallVector<Instruction *, 4> ToRemoves;
175
176 // Walk the users of the function.
177 for (auto &U : F->uses()) {
178 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
179 // Abs has one arg.
180 auto Arg = CI->getOperand(0);
181
182 // Use the argument unchanged, we know it's unsigned
183 CI->replaceAllUsesWith(Arg);
184
185 // Lastly, remember to remove the user.
186 ToRemoves.push_back(CI);
187 }
188 }
189
190 Changed = !ToRemoves.empty();
191
192 // And cleanup the calls we don't use anymore.
193 for (auto V : ToRemoves) {
194 V->eraseFromParent();
195 }
196
197 // And remove the function we don't need either too.
198 F->eraseFromParent();
199 }
200 }
201
202 return Changed;
203}
204
David Neto22f144c2017-06-12 14:26:21 -0400205bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
206 bool Changed = false;
207
208 const char *Names[] = {
209 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
210 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
211 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
212 };
213
214 for (auto Name : Names) {
215 // If we find a function with the matching name.
216 if (auto F = M.getFunction(Name)) {
217 SmallVector<Instruction *, 4> ToRemoves;
218
219 // Walk the users of the function.
220 for (auto &U : F->uses()) {
221 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
222 // Recip has one arg.
223 auto Arg = CI->getOperand(0);
224
225 auto Div = BinaryOperator::Create(
226 Instruction::FDiv, ConstantFP::get(Arg->getType(), 1.0), Arg, "",
227 CI);
228
229 CI->replaceAllUsesWith(Div);
230
231 // Lastly, remember to remove the user.
232 ToRemoves.push_back(CI);
233 }
234 }
235
236 Changed = !ToRemoves.empty();
237
238 // And cleanup the calls we don't use anymore.
239 for (auto V : ToRemoves) {
240 V->eraseFromParent();
241 }
242
243 // And remove the function we don't need either too.
244 F->eraseFromParent();
245 }
246 }
247
248 return Changed;
249}
250
251bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
252 bool Changed = false;
253
254 const char *Names[] = {
255 "_Z11half_divideff", "_Z13native_divideff",
256 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
257 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
258 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
259 };
260
261 for (auto Name : Names) {
262 // If we find a function with the matching name.
263 if (auto F = M.getFunction(Name)) {
264 SmallVector<Instruction *, 4> ToRemoves;
265
266 // Walk the users of the function.
267 for (auto &U : F->uses()) {
268 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
269 auto Div = BinaryOperator::Create(
270 Instruction::FDiv, CI->getOperand(0), CI->getOperand(1), "", CI);
271
272 CI->replaceAllUsesWith(Div);
273
274 // Lastly, remember to remove the user.
275 ToRemoves.push_back(CI);
276 }
277 }
278
279 Changed = !ToRemoves.empty();
280
281 // And cleanup the calls we don't use anymore.
282 for (auto V : ToRemoves) {
283 V->eraseFromParent();
284 }
285
286 // And remove the function we don't need either too.
287 F->eraseFromParent();
288 }
289 }
290
291 return Changed;
292}
293
294bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
295 bool Changed = false;
296
297 const std::map<const char *, const char *> Map = {
298 {"_Z5exp10f", "_Z3expf"},
299 {"_Z10half_exp10f", "_Z8half_expf"},
300 {"_Z12native_exp10f", "_Z10native_expf"},
301 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
302 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
303 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
304 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
305 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
306 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
307 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
308 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
309 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
310
311 for (auto Pair : Map) {
312 // If we find a function with the matching name.
313 if (auto F = M.getFunction(Pair.first)) {
314 SmallVector<Instruction *, 4> ToRemoves;
315
316 // Walk the users of the function.
317 for (auto &U : F->uses()) {
318 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
319 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
320
321 auto Arg = CI->getOperand(0);
322
323 // Constant of the natural log of 10 (ln(10)).
324 const double Ln10 =
325 2.302585092994045684017991454684364207601101488628772976033;
326
327 auto Mul = BinaryOperator::Create(
328 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
329 CI);
330
331 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
332
333 CI->replaceAllUsesWith(NewCI);
334
335 // Lastly, remember to remove the user.
336 ToRemoves.push_back(CI);
337 }
338 }
339
340 Changed = !ToRemoves.empty();
341
342 // And cleanup the calls we don't use anymore.
343 for (auto V : ToRemoves) {
344 V->eraseFromParent();
345 }
346
347 // And remove the function we don't need either too.
348 F->eraseFromParent();
349 }
350 }
351
352 return Changed;
353}
354
355bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
356 bool Changed = false;
357
358 const std::map<const char *, const char *> Map = {
359 {"_Z5log10f", "_Z3logf"},
360 {"_Z10half_log10f", "_Z8half_logf"},
361 {"_Z12native_log10f", "_Z10native_logf"},
362 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
363 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
364 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
365 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
366 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
367 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
368 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
369 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
370 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
371
372 for (auto Pair : Map) {
373 // If we find a function with the matching name.
374 if (auto F = M.getFunction(Pair.first)) {
375 SmallVector<Instruction *, 4> ToRemoves;
376
377 // Walk the users of the function.
378 for (auto &U : F->uses()) {
379 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
380 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
381
382 auto Arg = CI->getOperand(0);
383
384 // Constant of the reciprocal of the natural log of 10 (ln(10)).
385 const double Ln10 =
386 0.434294481903251827651128918916605082294397005803666566114;
387
388 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
389
390 auto Mul = BinaryOperator::Create(
391 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
392 "", CI);
393
394 CI->replaceAllUsesWith(Mul);
395
396 // Lastly, remember to remove the user.
397 ToRemoves.push_back(CI);
398 }
399 }
400
401 Changed = !ToRemoves.empty();
402
403 // And cleanup the calls we don't use anymore.
404 for (auto V : ToRemoves) {
405 V->eraseFromParent();
406 }
407
408 // And remove the function we don't need either too.
409 F->eraseFromParent();
410 }
411 }
412
413 return Changed;
414}
415
416bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
417 bool Changed = false;
418
419 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
420
421 const std::map<const char *, const char *> Map = {
422 {"_Z7barrierj", "__spirv_control_barrier"}};
423
424 for (auto Pair : Map) {
425 // If we find a function with the matching name.
426 if (auto F = M.getFunction(Pair.first)) {
427 SmallVector<Instruction *, 4> ToRemoves;
428
429 // Walk the users of the function.
430 for (auto &U : F->uses()) {
431 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
432 auto FType = F->getFunctionType();
433 SmallVector<Type *, 3> Params;
434 for (unsigned i = 0; i < 3; i++) {
435 Params.push_back(FType->getParamType(0));
436 }
437 auto NewFType =
438 FunctionType::get(FType->getReturnType(), Params, false);
439 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
440
441 auto Arg = CI->getOperand(0);
442
443 // We need to map the OpenCL constants to the SPIR-V equivalents.
444 const auto LocalMemFence =
445 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
446 const auto GlobalMemFence =
447 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
448 const auto ConstantSequentiallyConsistent = ConstantInt::get(
449 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
450 const auto ConstantScopeDevice =
451 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
452 const auto ConstantScopeWorkgroup =
453 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
454
455 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
456 const auto LocalMemFenceMask = BinaryOperator::Create(
457 Instruction::And, LocalMemFence, Arg, "", CI);
458 const auto WorkgroupShiftAmount =
459 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
460 clz(CLK_LOCAL_MEM_FENCE);
461 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
462 Instruction::Shl, LocalMemFenceMask,
463 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
464
465 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
466 const auto GlobalMemFenceMask = BinaryOperator::Create(
467 Instruction::And, GlobalMemFence, Arg, "", CI);
468 const auto UniformShiftAmount =
469 clz(spv::MemorySemanticsUniformMemoryMask) -
470 clz(CLK_GLOBAL_MEM_FENCE);
471 const auto MemorySemanticsUniform = BinaryOperator::Create(
472 Instruction::Shl, GlobalMemFenceMask,
473 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
474
475 // And combine the above together, also adding in
476 // MemorySemanticsSequentiallyConsistentMask.
477 auto MemorySemantics =
478 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
479 ConstantSequentiallyConsistent, "", CI);
480 MemorySemantics = BinaryOperator::Create(
481 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
482
483 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
484 // Device Scope, otherwise Workgroup Scope.
485 const auto Cmp =
486 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
487 GlobalMemFenceMask, GlobalMemFence, "", CI);
488 const auto MemoryScope = SelectInst::Create(
489 Cmp, ConstantScopeDevice, ConstantScopeWorkgroup, "", CI);
490
491 // Lastly, the Execution Scope is always Workgroup Scope.
492 const auto ExecutionScope = ConstantScopeWorkgroup;
493
494 auto NewCI = CallInst::Create(
495 NewF, {ExecutionScope, MemoryScope, MemorySemantics}, "", CI);
496
497 CI->replaceAllUsesWith(NewCI);
498
499 // Lastly, remember to remove the user.
500 ToRemoves.push_back(CI);
501 }
502 }
503
504 Changed = !ToRemoves.empty();
505
506 // And cleanup the calls we don't use anymore.
507 for (auto V : ToRemoves) {
508 V->eraseFromParent();
509 }
510
511 // And remove the function we don't need either too.
512 F->eraseFromParent();
513 }
514 }
515
516 return Changed;
517}
518
519bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
520 bool Changed = false;
521
522 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
523
Neil Henning39672102017-09-29 14:33:13 +0100524 using Tuple = std::tuple<const char *, unsigned>;
525 const std::map<const char *, Tuple> Map = {
526 {"_Z9mem_fencej",
527 Tuple("__spirv_memory_barrier",
528 spv::MemorySemanticsSequentiallyConsistentMask)},
529 {"_Z14read_mem_fencej",
530 Tuple("__spirv_memory_barrier", spv::MemorySemanticsAcquireMask)},
531 {"_Z15write_mem_fencej",
532 Tuple("__spirv_memory_barrier", spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400533
534 for (auto Pair : Map) {
535 // If we find a function with the matching name.
536 if (auto F = M.getFunction(Pair.first)) {
537 SmallVector<Instruction *, 4> ToRemoves;
538
539 // Walk the users of the function.
540 for (auto &U : F->uses()) {
541 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
542 auto FType = F->getFunctionType();
543 SmallVector<Type *, 2> Params;
544 for (unsigned i = 0; i < 2; i++) {
545 Params.push_back(FType->getParamType(0));
546 }
547 auto NewFType =
548 FunctionType::get(FType->getReturnType(), Params, false);
Neil Henning39672102017-09-29 14:33:13 +0100549 auto NewF = M.getOrInsertFunction(std::get<0>(Pair.second), NewFType);
David Neto22f144c2017-06-12 14:26:21 -0400550
551 auto Arg = CI->getOperand(0);
552
553 // We need to map the OpenCL constants to the SPIR-V equivalents.
554 const auto LocalMemFence =
555 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
556 const auto GlobalMemFence =
557 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
558 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100559 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400560 const auto ConstantScopeDevice =
561 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
562
563 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
564 const auto LocalMemFenceMask = BinaryOperator::Create(
565 Instruction::And, LocalMemFence, Arg, "", CI);
566 const auto WorkgroupShiftAmount =
567 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
568 clz(CLK_LOCAL_MEM_FENCE);
569 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
570 Instruction::Shl, LocalMemFenceMask,
571 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
572
573 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
574 const auto GlobalMemFenceMask = BinaryOperator::Create(
575 Instruction::And, GlobalMemFence, Arg, "", CI);
576 const auto UniformShiftAmount =
577 clz(spv::MemorySemanticsUniformMemoryMask) -
578 clz(CLK_GLOBAL_MEM_FENCE);
579 const auto MemorySemanticsUniform = BinaryOperator::Create(
580 Instruction::Shl, GlobalMemFenceMask,
581 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
582
583 // And combine the above together, also adding in
584 // MemorySemanticsSequentiallyConsistentMask.
585 auto MemorySemantics =
586 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
587 ConstantMemorySemantics, "", CI);
588 MemorySemantics = BinaryOperator::Create(
589 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
590
591 // Memory Scope is always device.
592 const auto MemoryScope = ConstantScopeDevice;
593
594 auto NewCI =
595 CallInst::Create(NewF, {MemoryScope, MemorySemantics}, "", CI);
596
597 CI->replaceAllUsesWith(NewCI);
598
599 // Lastly, remember to remove the user.
600 ToRemoves.push_back(CI);
601 }
602 }
603
604 Changed = !ToRemoves.empty();
605
606 // And cleanup the calls we don't use anymore.
607 for (auto V : ToRemoves) {
608 V->eraseFromParent();
609 }
610
611 // And remove the function we don't need either too.
612 F->eraseFromParent();
613 }
614 }
615
616 return Changed;
617}
618
619bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
620 bool Changed = false;
621
622 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
623 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
624 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
625 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
626 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
627 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
628 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
629 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
630 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
631 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
632 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
633 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
634 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
635 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
636 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
637 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
638 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
639 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
640 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
641 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
642 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
643 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
644 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
645 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
646 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
647 };
648
649 for (auto Pair : Map) {
650 // If we find a function with the matching name.
651 if (auto F = M.getFunction(Pair.first)) {
652 SmallVector<Instruction *, 4> ToRemoves;
653
654 // Walk the users of the function.
655 for (auto &U : F->uses()) {
656 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
657 // The predicate to use in the CmpInst.
658 auto Predicate = Pair.second.first;
659
660 // The value to return for true.
661 auto TrueValue =
662 ConstantInt::getSigned(CI->getType(), Pair.second.second);
663
664 // The value to return for false.
665 auto FalseValue = Constant::getNullValue(CI->getType());
666
667 auto Arg1 = CI->getOperand(0);
668 auto Arg2 = CI->getOperand(1);
669
670 const auto Cmp =
671 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
672
673 const auto Select =
674 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
675
676 CI->replaceAllUsesWith(Select);
677
678 // Lastly, remember to remove the user.
679 ToRemoves.push_back(CI);
680 }
681 }
682
683 Changed = !ToRemoves.empty();
684
685 // And cleanup the calls we don't use anymore.
686 for (auto V : ToRemoves) {
687 V->eraseFromParent();
688 }
689
690 // And remove the function we don't need either too.
691 F->eraseFromParent();
692 }
693 }
694
695 return Changed;
696}
697
698bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
699 bool Changed = false;
700
701 const std::map<const char *, std::pair<const char *, int32_t>> Map = {
702 {"_Z5isinff", {"__spirv_isinff", 1}},
703 {"_Z5isinfDv2_f", {"__spirv_isinfDv2_f", -1}},
704 {"_Z5isinfDv3_f", {"__spirv_isinfDv3_f", -1}},
705 {"_Z5isinfDv4_f", {"__spirv_isinfDv4_f", -1}},
706 {"_Z5isnanf", {"__spirv_isnanf", 1}},
707 {"_Z5isnanDv2_f", {"__spirv_isnanDv2_f", -1}},
708 {"_Z5isnanDv3_f", {"__spirv_isnanDv3_f", -1}},
709 {"_Z5isnanDv4_f", {"__spirv_isnanDv4_f", -1}},
710 };
711
712 for (auto Pair : Map) {
713 // If we find a function with the matching name.
714 if (auto F = M.getFunction(Pair.first)) {
715 SmallVector<Instruction *, 4> ToRemoves;
716
717 // Walk the users of the function.
718 for (auto &U : F->uses()) {
719 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
720 const auto CITy = CI->getType();
721
722 // The fake SPIR-V intrinsic to generate.
723 auto SPIRVIntrinsic = Pair.second.first;
724
725 // The value to return for true.
726 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
727
728 // The value to return for false.
729 auto FalseValue = Constant::getNullValue(CITy);
730
731 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
732 M.getContext(),
733 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
734
735 auto NewFType =
736 FunctionType::get(CorrespondingBoolTy,
737 F->getFunctionType()->getParamType(0), false);
738
739 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
740
741 auto Arg = CI->getOperand(0);
742
743 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
744
745 const auto Select =
746 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
747
748 CI->replaceAllUsesWith(Select);
749
750 // Lastly, remember to remove the user.
751 ToRemoves.push_back(CI);
752 }
753 }
754
755 Changed = !ToRemoves.empty();
756
757 // And cleanup the calls we don't use anymore.
758 for (auto V : ToRemoves) {
759 V->eraseFromParent();
760 }
761
762 // And remove the function we don't need either too.
763 F->eraseFromParent();
764 }
765 }
766
767 return Changed;
768}
769
770bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
771 bool Changed = false;
772
773 const std::map<const char *, const char *> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000774 // all
775 {"_Z3alls", ""},
776 {"_Z3allDv2_s", "__spirv_allDv2_s"},
777 {"_Z3allDv3_s", "__spirv_allDv3_s"},
778 {"_Z3allDv4_s", "__spirv_allDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400779 {"_Z3alli", ""},
780 {"_Z3allDv2_i", "__spirv_allDv2_i"},
781 {"_Z3allDv3_i", "__spirv_allDv3_i"},
782 {"_Z3allDv4_i", "__spirv_allDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000783 {"_Z3alll", ""},
784 {"_Z3allDv2_l", "__spirv_allDv2_l"},
785 {"_Z3allDv3_l", "__spirv_allDv3_l"},
786 {"_Z3allDv4_l", "__spirv_allDv4_l"},
787
788 // any
789 {"_Z3anys", ""},
790 {"_Z3anyDv2_s", "__spirv_anyDv2_s"},
791 {"_Z3anyDv3_s", "__spirv_anyDv3_s"},
792 {"_Z3anyDv4_s", "__spirv_anyDv4_s"},
David Neto22f144c2017-06-12 14:26:21 -0400793 {"_Z3anyi", ""},
794 {"_Z3anyDv2_i", "__spirv_anyDv2_i"},
795 {"_Z3anyDv3_i", "__spirv_anyDv3_i"},
796 {"_Z3anyDv4_i", "__spirv_anyDv4_i"},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000797 {"_Z3anyl", ""},
798 {"_Z3anyDv2_l", "__spirv_anyDv2_l"},
799 {"_Z3anyDv3_l", "__spirv_anyDv3_l"},
800 {"_Z3anyDv4_l", "__spirv_anyDv4_l"},
David Neto22f144c2017-06-12 14:26:21 -0400801 };
802
803 for (auto Pair : Map) {
804 // If we find a function with the matching name.
805 if (auto F = M.getFunction(Pair.first)) {
806 SmallVector<Instruction *, 4> ToRemoves;
807
808 // Walk the users of the function.
809 for (auto &U : F->uses()) {
810 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
811 // The fake SPIR-V intrinsic to generate.
812 auto SPIRVIntrinsic = Pair.second;
813
814 auto Arg = CI->getOperand(0);
815
816 Value *V;
817
Kévin Petitfd27cca2018-10-31 13:00:17 +0000818 // If the argument is a 32-bit int, just use a shift
819 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
820 V = BinaryOperator::Create(Instruction::LShr, Arg,
821 ConstantInt::get(Arg->getType(), 31), "",
822 CI);
823 } else {
David Neto22f144c2017-06-12 14:26:21 -0400824 // The value for zero to compare against.
825 const auto ZeroValue = Constant::getNullValue(Arg->getType());
826
David Neto22f144c2017-06-12 14:26:21 -0400827 // The value to return for true.
828 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
829
830 // The value to return for false.
831 const auto FalseValue = Constant::getNullValue(CI->getType());
832
Kévin Petitfd27cca2018-10-31 13:00:17 +0000833 const auto Cmp = CmpInst::Create(
834 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
835
836 Value* SelectSource;
837
838 // If we have a function to call, call it!
839 if (0 < strlen(SPIRVIntrinsic)) {
840
841 const auto NewFType = FunctionType::get(
842 Type::getInt1Ty(M.getContext()), Cmp->getType(), false);
843
844 const auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
845
846 const auto NewCI = CallInst::Create(NewF, Cmp, "", CI);
847
848 SelectSource = NewCI;
849
850 } else {
851 SelectSource = Cmp;
852 }
853
854 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400855 }
856
857 CI->replaceAllUsesWith(V);
858
859 // Lastly, remember to remove the user.
860 ToRemoves.push_back(CI);
861 }
862 }
863
864 Changed = !ToRemoves.empty();
865
866 // And cleanup the calls we don't use anymore.
867 for (auto V : ToRemoves) {
868 V->eraseFromParent();
869 }
870
871 // And remove the function we don't need either too.
872 F->eraseFromParent();
873 }
874 }
875
876 return Changed;
877}
878
Kévin Petitbf0036c2019-03-06 13:57:10 +0000879bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
880 bool Changed = false;
881
882 for (auto const &SymVal : M.getValueSymbolTable()) {
883 // Skip symbols whose name doesn't match
884 if (!SymVal.getKey().startswith("_Z8upsample")) {
885 continue;
886 }
887 // Is there a function going by that name?
888 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
889
890 SmallVector<Instruction *, 4> ToRemoves;
891
892 // Walk the users of the function.
893 for (auto &U : F->uses()) {
894 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
895
896 // Get arguments
897 auto HiValue = CI->getOperand(0);
898 auto LoValue = CI->getOperand(1);
899
900 // Don't touch overloads that aren't in OpenCL C
901 auto HiType = HiValue->getType();
902 auto LoType = LoValue->getType();
903
904 if (HiType != LoType) {
905 continue;
906 }
907
908 if (!HiType->isIntOrIntVectorTy()) {
909 continue;
910 }
911
912 if (HiType->getScalarSizeInBits() * 2 !=
913 CI->getType()->getScalarSizeInBits()) {
914 continue;
915 }
916
917 if ((HiType->getScalarSizeInBits() != 8) &&
918 (HiType->getScalarSizeInBits() != 16) &&
919 (HiType->getScalarSizeInBits() != 32)) {
920 continue;
921 }
922
923 if (HiType->isVectorTy()) {
924 if ((HiType->getVectorNumElements() != 2) &&
925 (HiType->getVectorNumElements() != 3) &&
926 (HiType->getVectorNumElements() != 4) &&
927 (HiType->getVectorNumElements() != 8) &&
928 (HiType->getVectorNumElements() != 16)) {
929 continue;
930 }
931 }
932
933 // Convert both operands to the result type
934 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(),
935 "", CI);
936 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(),
937 "", CI);
938
939 // Shift high operand
940 auto ShiftAmount = ConstantInt::get(CI->getType(),
941 HiType->getScalarSizeInBits());
942 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
943 ShiftAmount, "", CI);
944
945 // OR both results
946 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
947 "", CI);
948
949 // Replace call with the expression
950 CI->replaceAllUsesWith(V);
951
952 // Lastly, remember to remove the user.
953 ToRemoves.push_back(CI);
954 }
955 }
956
957 Changed = !ToRemoves.empty();
958
959 // And cleanup the calls we don't use anymore.
960 for (auto V : ToRemoves) {
961 V->eraseFromParent();
962 }
963
964 // And remove the function we don't need either too.
965 F->eraseFromParent();
966 }
967 }
968
969 return Changed;
970}
971
Kévin Petitd44eef52019-03-08 13:22:14 +0000972bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
973 bool Changed = false;
974
975 for (auto const &SymVal : M.getValueSymbolTable()) {
976 // Skip symbols whose name doesn't match
977 if (!SymVal.getKey().startswith("_Z6rotate")) {
978 continue;
979 }
980 // Is there a function going by that name?
981 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
982
983 SmallVector<Instruction *, 4> ToRemoves;
984
985 // Walk the users of the function.
986 for (auto &U : F->uses()) {
987 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
988
989 // Get arguments
990 auto SrcValue = CI->getOperand(0);
991 auto RotAmount = CI->getOperand(1);
992
993 // Don't touch overloads that aren't in OpenCL C
994 auto SrcType = SrcValue->getType();
995 auto RotType = RotAmount->getType();
996
997 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
998 continue;
999 }
1000
1001 if (!SrcType->isIntOrIntVectorTy()) {
1002 continue;
1003 }
1004
1005 if ((SrcType->getScalarSizeInBits() != 8) &&
1006 (SrcType->getScalarSizeInBits() != 16) &&
1007 (SrcType->getScalarSizeInBits() != 32) &&
1008 (SrcType->getScalarSizeInBits() != 64)) {
1009 continue;
1010 }
1011
1012 if (SrcType->isVectorTy()) {
1013 if ((SrcType->getVectorNumElements() != 2) &&
1014 (SrcType->getVectorNumElements() != 3) &&
1015 (SrcType->getVectorNumElements() != 4) &&
1016 (SrcType->getVectorNumElements() != 8) &&
1017 (SrcType->getVectorNumElements() != 16)) {
1018 continue;
1019 }
1020 }
1021
1022 // The approach used is to shift the top bits down, the bottom bits up
1023 // and OR the two shifted values.
1024
1025 // The rotation amount is to be treated modulo the element size.
1026 // Since SPIR-V shift ops don't support this, let's apply the
1027 // modulo ahead of shifting. The element size is always a power of
1028 // two so we can just AND with a mask.
1029 auto ModMask = ConstantInt::get(SrcType,
1030 SrcType->getScalarSizeInBits() - 1);
1031 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1032 ModMask, "", CI);
1033
1034 // Let's calc the amount by which to shift top bits down
1035 auto ScalarSize = ConstantInt::get(SrcType,
1036 SrcType->getScalarSizeInBits());
1037 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1038 RotAmount, "", CI);
1039
1040 // Now shift the bottom bits up and the top bits down
1041 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1042 RotAmount, "", CI);
1043 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1044 DownAmount, "", CI);
1045
1046 // Finally OR the two shifted values
1047 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1048 HiRotated, "", CI);
1049
1050 // Replace call with the expression
1051 CI->replaceAllUsesWith(V);
1052
1053 // Lastly, remember to remove the user.
1054 ToRemoves.push_back(CI);
1055 }
1056 }
1057
1058 Changed = !ToRemoves.empty();
1059
1060 // And cleanup the calls we don't use anymore.
1061 for (auto V : ToRemoves) {
1062 V->eraseFromParent();
1063 }
1064
1065 // And remove the function we don't need either too.
1066 F->eraseFromParent();
1067 }
1068 }
1069
1070 return Changed;
1071}
1072
Kévin Petitf5b78a22018-10-25 14:32:17 +00001073bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1074 bool Changed = false;
1075
1076 for (auto const &SymVal : M.getValueSymbolTable()) {
1077 // Skip symbols whose name doesn't match
1078 if (!SymVal.getKey().startswith("_Z6select")) {
1079 continue;
1080 }
1081 // Is there a function going by that name?
1082 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1083
1084 SmallVector<Instruction *, 4> ToRemoves;
1085
1086 // Walk the users of the function.
1087 for (auto &U : F->uses()) {
1088 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1089
1090 // Get arguments
1091 auto FalseValue = CI->getOperand(0);
1092 auto TrueValue = CI->getOperand(1);
1093 auto PredicateValue = CI->getOperand(2);
1094
1095 // Don't touch overloads that aren't in OpenCL C
1096 auto FalseType = FalseValue->getType();
1097 auto TrueType = TrueValue->getType();
1098 auto PredicateType = PredicateValue->getType();
1099
1100 if (FalseType != TrueType) {
1101 continue;
1102 }
1103
1104 if (!PredicateType->isIntOrIntVectorTy()) {
1105 continue;
1106 }
1107
1108 if (!FalseType->isIntOrIntVectorTy() &&
1109 !FalseType->getScalarType()->isFloatingPointTy()) {
1110 continue;
1111 }
1112
1113 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1114 continue;
1115 }
1116
1117 if (FalseType->getScalarSizeInBits() !=
1118 PredicateType->getScalarSizeInBits()) {
1119 continue;
1120 }
1121
1122 if (FalseType->isVectorTy()) {
1123 if (FalseType->getVectorNumElements() !=
1124 PredicateType->getVectorNumElements()) {
1125 continue;
1126 }
1127
1128 if ((FalseType->getVectorNumElements() != 2) &&
1129 (FalseType->getVectorNumElements() != 3) &&
1130 (FalseType->getVectorNumElements() != 4) &&
1131 (FalseType->getVectorNumElements() != 8) &&
1132 (FalseType->getVectorNumElements() != 16)) {
1133 continue;
1134 }
1135 }
1136
1137 // Create constant
1138 const auto ZeroValue = Constant::getNullValue(PredicateType);
1139
1140 // Scalar and vector are to be treated differently
1141 CmpInst::Predicate Pred;
1142 if (PredicateType->isVectorTy()) {
1143 Pred = CmpInst::ICMP_SLT;
1144 } else {
1145 Pred = CmpInst::ICMP_NE;
1146 }
1147
1148 // Create comparison instruction
1149 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1150 ZeroValue, "", CI);
1151
1152 // Create select
1153 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1154
1155 // Replace call with the selection
1156 CI->replaceAllUsesWith(V);
1157
1158 // Lastly, remember to remove the user.
1159 ToRemoves.push_back(CI);
1160 }
1161 }
1162
1163 Changed = !ToRemoves.empty();
1164
1165 // And cleanup the calls we don't use anymore.
1166 for (auto V : ToRemoves) {
1167 V->eraseFromParent();
1168 }
1169
1170 // And remove the function we don't need either too.
1171 F->eraseFromParent();
1172 }
1173 }
1174
1175 return Changed;
1176}
1177
Kévin Petite7d0cce2018-10-31 12:38:56 +00001178bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1179 bool Changed = false;
1180
1181 for (auto const &SymVal : M.getValueSymbolTable()) {
1182 // Skip symbols whose name doesn't match
1183 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1184 continue;
1185 }
1186 // Is there a function going by that name?
1187 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1188
1189 SmallVector<Instruction *, 4> ToRemoves;
1190
1191 // Walk the users of the function.
1192 for (auto &U : F->uses()) {
1193 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1194
1195 if (CI->getNumOperands() != 4) {
1196 continue;
1197 }
1198
1199 // Get arguments
1200 auto FalseValue = CI->getOperand(0);
1201 auto TrueValue = CI->getOperand(1);
1202 auto PredicateValue = CI->getOperand(2);
1203
1204 // Don't touch overloads that aren't in OpenCL C
1205 auto FalseType = FalseValue->getType();
1206 auto TrueType = TrueValue->getType();
1207 auto PredicateType = PredicateValue->getType();
1208
1209 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1210 continue;
1211 }
1212
1213 if (TrueType->isVectorTy()) {
1214 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1215 !TrueType->getScalarType()->isIntegerTy()) {
1216 continue;
1217 }
1218 if ((TrueType->getVectorNumElements() != 2) &&
1219 (TrueType->getVectorNumElements() != 3) &&
1220 (TrueType->getVectorNumElements() != 4) &&
1221 (TrueType->getVectorNumElements() != 8) &&
1222 (TrueType->getVectorNumElements() != 16)) {
1223 continue;
1224 }
1225 }
1226
1227 // Remember the type of the operands
1228 auto OpType = TrueType;
1229
1230 // The actual bit selection will always be done on an integer type,
1231 // declare it here
1232 Type *BitType;
1233
1234 // If the operands are float, then bitcast them to int
1235 if (OpType->getScalarType()->isFloatingPointTy()) {
1236
1237 // First create the new type
1238 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1239 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1240 if (OpType->isVectorTy()) {
1241 BitType = VectorType::get(BitType, OpType->getVectorNumElements());
1242 }
1243
1244 // Then bitcast all operands
1245 PredicateValue = CastInst::CreateZExtOrBitCast(PredicateValue,
1246 BitType, "", CI);
1247 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue,
1248 BitType, "", CI);
1249 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1250
1251 } else {
1252 // The operands have an integer type, use it directly
1253 BitType = OpType;
1254 }
1255
1256 // All the operands are now always integers
1257 // implement as (c & b) | (~c & a)
1258
1259 // Create our negated predicate value
1260 auto AllOnes = Constant::getAllOnesValue(BitType);
1261 auto NotPredicateValue = BinaryOperator::Create(Instruction::Xor,
1262 PredicateValue,
1263 AllOnes, "", CI);
1264
1265 // Then put everything together
1266 auto BitsFalse = BinaryOperator::Create(Instruction::And,
1267 NotPredicateValue,
1268 FalseValue, "", CI);
1269 auto BitsTrue = BinaryOperator::Create(Instruction::And,
1270 PredicateValue,
1271 TrueValue, "", CI);
1272
1273 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1274 BitsTrue, "", CI);
1275
1276 // If we were dealing with a floating point type, we must bitcast
1277 // the result back to that
1278 if (OpType->getScalarType()->isFloatingPointTy()) {
1279 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1280 }
1281
1282 // Replace call with our new code
1283 CI->replaceAllUsesWith(V);
1284
1285 // Lastly, remember to remove the user.
1286 ToRemoves.push_back(CI);
1287 }
1288 }
1289
1290 Changed = !ToRemoves.empty();
1291
1292 // And cleanup the calls we don't use anymore.
1293 for (auto V : ToRemoves) {
1294 V->eraseFromParent();
1295 }
1296
1297 // And remove the function we don't need either too.
1298 F->eraseFromParent();
1299 }
1300 }
1301
1302 return Changed;
1303}
1304
Kévin Petit6b0a9532018-10-30 20:00:39 +00001305bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1306 bool Changed = false;
1307
1308 const std::map<const char *, const char *> Map = {
1309 { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
1310 { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
1311 { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
1312 { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
1313 { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
1314 { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
1315 };
1316
1317 for (auto Pair : Map) {
1318 // If we find a function with the matching name.
1319 if (auto F = M.getFunction(Pair.first)) {
1320 SmallVector<Instruction *, 4> ToRemoves;
1321
1322 // Walk the users of the function.
1323 for (auto &U : F->uses()) {
1324 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1325
1326 auto ReplacementFn = Pair.second;
1327
1328 SmallVector<Value*, 2> ArgsToSplat = {CI->getOperand(0)};
1329 Value *VectorArg;
1330
1331 // First figure out which function we're dealing with
1332 if (F->getName().startswith("_Z10smoothstep")) {
1333 ArgsToSplat.push_back(CI->getOperand(1));
1334 VectorArg = CI->getOperand(2);
1335 } else {
1336 VectorArg = CI->getOperand(1);
1337 }
1338
1339 // Splat arguments that need to be
1340 SmallVector<Value*, 2> SplatArgs;
1341 auto VecType = VectorArg->getType();
1342
1343 for (auto arg : ArgsToSplat) {
1344 Value* NewVectorArg = UndefValue::get(VecType);
1345 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
1346 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1347 NewVectorArg = InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1348 }
1349 SplatArgs.push_back(NewVectorArg);
1350 }
1351
1352 // Replace the call with the vector/vector flavour
1353 SmallVector<Type*, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1354 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1355
1356 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1357
1358 SmallVector<Value*, 3> NewArgs;
1359 for (auto arg : SplatArgs) {
1360 NewArgs.push_back(arg);
1361 }
1362 NewArgs.push_back(VectorArg);
1363
1364 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1365
1366 CI->replaceAllUsesWith(NewCI);
1367
1368 // Lastly, remember to remove the user.
1369 ToRemoves.push_back(CI);
1370 }
1371 }
1372
1373 Changed = !ToRemoves.empty();
1374
1375 // And cleanup the calls we don't use anymore.
1376 for (auto V : ToRemoves) {
1377 V->eraseFromParent();
1378 }
1379
1380 // And remove the function we don't need either too.
1381 F->eraseFromParent();
1382 }
1383 }
1384
1385 return Changed;
1386}
1387
David Neto22f144c2017-06-12 14:26:21 -04001388bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1389 bool Changed = false;
1390
1391 const std::map<const char *, Instruction::BinaryOps> Map = {
1392 {"_Z7signbitf", Instruction::LShr},
1393 {"_Z7signbitDv2_f", Instruction::AShr},
1394 {"_Z7signbitDv3_f", Instruction::AShr},
1395 {"_Z7signbitDv4_f", Instruction::AShr},
1396 };
1397
1398 for (auto Pair : Map) {
1399 // If we find a function with the matching name.
1400 if (auto F = M.getFunction(Pair.first)) {
1401 SmallVector<Instruction *, 4> ToRemoves;
1402
1403 // Walk the users of the function.
1404 for (auto &U : F->uses()) {
1405 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1406 auto Arg = CI->getOperand(0);
1407
1408 auto Bitcast =
1409 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1410
1411 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1412 ConstantInt::get(CI->getType(), 31),
1413 "", CI);
1414
1415 CI->replaceAllUsesWith(Shr);
1416
1417 // Lastly, remember to remove the user.
1418 ToRemoves.push_back(CI);
1419 }
1420 }
1421
1422 Changed = !ToRemoves.empty();
1423
1424 // And cleanup the calls we don't use anymore.
1425 for (auto V : ToRemoves) {
1426 V->eraseFromParent();
1427 }
1428
1429 // And remove the function we don't need either too.
1430 F->eraseFromParent();
1431 }
1432 }
1433
1434 return Changed;
1435}
1436
1437bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1438 bool Changed = false;
1439
1440 const std::map<const char *,
1441 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1442 Map = {
1443 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1444 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1445 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1446 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1447 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1448 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1449 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1450 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1451 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1452 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1453 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1454 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1455 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1456 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1457 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1458 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1459 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1460 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1461 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1462 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1463 };
1464
1465 for (auto Pair : Map) {
1466 // If we find a function with the matching name.
1467 if (auto F = M.getFunction(Pair.first)) {
1468 SmallVector<Instruction *, 4> ToRemoves;
1469
1470 // Walk the users of the function.
1471 for (auto &U : F->uses()) {
1472 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1473 // The multiply instruction to use.
1474 auto MulInst = Pair.second.first;
1475
1476 // The add instruction to use.
1477 auto AddInst = Pair.second.second;
1478
1479 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1480
1481 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1482 CI->getArgOperand(1), "", CI);
1483
1484 if (Instruction::BinaryOpsEnd != AddInst) {
1485 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1486 CI);
1487 }
1488
1489 CI->replaceAllUsesWith(I);
1490
1491 // Lastly, remember to remove the user.
1492 ToRemoves.push_back(CI);
1493 }
1494 }
1495
1496 Changed = !ToRemoves.empty();
1497
1498 // And cleanup the calls we don't use anymore.
1499 for (auto V : ToRemoves) {
1500 V->eraseFromParent();
1501 }
1502
1503 // And remove the function we don't need either too.
1504 F->eraseFromParent();
1505 }
1506 }
1507
1508 return Changed;
1509}
1510
Derek Chowcfd368b2017-10-19 20:58:45 -07001511bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1512 bool Changed = false;
1513
1514 struct VectorStoreOps {
1515 const char* name;
1516 int n;
1517 Type* (*get_scalar_type_function)(LLVMContext&);
1518 } vector_store_ops[] = {
1519 // TODO(derekjchow): Expand this list.
1520 { "_Z7vstore4Dv4_fjPU3AS1f", 4, Type::getFloatTy }
1521 };
1522
David Neto544fffc2017-11-16 18:35:14 -05001523 for (const auto& Op : vector_store_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001524 auto Name = Op.name;
1525 auto N = Op.n;
1526 auto TypeFn = Op.get_scalar_type_function;
1527 if (auto F = M.getFunction(Name)) {
1528 SmallVector<Instruction *, 4> ToRemoves;
1529
1530 // Walk the users of the function.
1531 for (auto &U : F->uses()) {
1532 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1533 // The value argument from vstoren.
1534 auto Arg0 = CI->getOperand(0);
1535
1536 // The index argument from vstoren.
1537 auto Arg1 = CI->getOperand(1);
1538
1539 // The pointer argument from vstoren.
1540 auto Arg2 = CI->getOperand(2);
1541
1542 // Get types.
1543 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1544 auto ScalarNPointerTy = PointerType::get(
1545 ScalarNTy, Arg2->getType()->getPointerAddressSpace());
1546
1547 // Cast to scalarn
1548 auto Cast = CastInst::CreatePointerCast(
1549 Arg2, ScalarNPointerTy, "", CI);
1550 // Index to correct address
1551 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg1, "", CI);
1552 // Store
1553 auto Store = new StoreInst(Arg0, Index, CI);
1554
1555 CI->replaceAllUsesWith(Store);
1556 ToRemoves.push_back(CI);
1557 }
1558 }
1559
1560 Changed = !ToRemoves.empty();
1561
1562 // And cleanup the calls we don't use anymore.
1563 for (auto V : ToRemoves) {
1564 V->eraseFromParent();
1565 }
1566
1567 // And remove the function we don't need either too.
1568 F->eraseFromParent();
1569 }
1570 }
1571
1572 return Changed;
1573}
1574
1575bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1576 bool Changed = false;
1577
1578 struct VectorLoadOps {
1579 const char* name;
1580 int n;
1581 Type* (*get_scalar_type_function)(LLVMContext&);
1582 } vector_load_ops[] = {
1583 // TODO(derekjchow): Expand this list.
1584 { "_Z6vload4jPU3AS1Kf", 4, Type::getFloatTy }
1585 };
1586
David Neto544fffc2017-11-16 18:35:14 -05001587 for (const auto& Op : vector_load_ops) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001588 auto Name = Op.name;
1589 auto N = Op.n;
1590 auto TypeFn = Op.get_scalar_type_function;
1591 // If we find a function with the matching name.
1592 if (auto F = M.getFunction(Name)) {
1593 SmallVector<Instruction *, 4> ToRemoves;
1594
1595 // Walk the users of the function.
1596 for (auto &U : F->uses()) {
1597 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1598 // The index argument from vloadn.
1599 auto Arg0 = CI->getOperand(0);
1600
1601 // The pointer argument from vloadn.
1602 auto Arg1 = CI->getOperand(1);
1603
1604 // Get types.
1605 auto ScalarNTy = VectorType::get(TypeFn(M.getContext()), N);
1606 auto ScalarNPointerTy = PointerType::get(
1607 ScalarNTy, Arg1->getType()->getPointerAddressSpace());
1608
1609 // Cast to scalarn
1610 auto Cast = CastInst::CreatePointerCast(
1611 Arg1, ScalarNPointerTy, "", CI);
1612 // Index to correct address
1613 auto Index = GetElementPtrInst::Create(ScalarNTy, Cast, Arg0, "", CI);
1614 // Load
1615 auto Load = new LoadInst(Index, "", CI);
1616
1617 CI->replaceAllUsesWith(Load);
1618 ToRemoves.push_back(CI);
1619 }
1620 }
1621
1622 Changed = !ToRemoves.empty();
1623
1624 // And cleanup the calls we don't use anymore.
1625 for (auto V : ToRemoves) {
1626 V->eraseFromParent();
1627 }
1628
1629 // And remove the function we don't need either too.
1630 F->eraseFromParent();
1631
1632 }
1633 }
1634
1635 return Changed;
1636}
1637
David Neto22f144c2017-06-12 14:26:21 -04001638bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
1639 bool Changed = false;
1640
1641 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
1642 "_Z10vload_halfjPU3AS2KDh"};
1643
1644 for (auto Name : Map) {
1645 // If we find a function with the matching name.
1646 if (auto F = M.getFunction(Name)) {
1647 SmallVector<Instruction *, 4> ToRemoves;
1648
1649 // Walk the users of the function.
1650 for (auto &U : F->uses()) {
1651 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1652 // The index argument from vload_half.
1653 auto Arg0 = CI->getOperand(0);
1654
1655 // The pointer argument from vload_half.
1656 auto Arg1 = CI->getOperand(1);
1657
David Neto22f144c2017-06-12 14:26:21 -04001658 auto IntTy = Type::getInt32Ty(M.getContext());
1659 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04001660 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1661
David Neto22f144c2017-06-12 14:26:21 -04001662 // Our intrinsic to unpack a float2 from an int.
1663 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1664
1665 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1666
David Neto482550a2018-03-24 05:21:07 -07001667 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04001668 auto ShortTy = Type::getInt16Ty(M.getContext());
1669 auto ShortPointerTy = PointerType::get(
1670 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04001671
David Netoac825b82017-05-30 12:49:01 -04001672 // Cast the half* pointer to short*.
1673 auto Cast =
1674 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001675
David Netoac825b82017-05-30 12:49:01 -04001676 // Index into the correct address of the casted pointer.
1677 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1678
1679 // Load from the short* we casted to.
1680 auto Load = new LoadInst(Index, "", CI);
1681
1682 // ZExt the short -> int.
1683 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1684
1685 // Get our float2.
1686 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1687
1688 // Extract out the bottom element which is our float result.
1689 auto Extract = ExtractElementInst::Create(
1690 Call, ConstantInt::get(IntTy, 0), "", CI);
1691
1692 CI->replaceAllUsesWith(Extract);
1693 } else {
1694 // Assume the pointer argument points to storage aligned to 32bits
1695 // or more.
1696 // TODO(dneto): Do more analysis to make sure this is true?
1697 //
1698 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1699 // with:
1700 //
1701 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1702 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1703 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1704 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1705 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1706 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1707 // x float> %converted, %index_is_odd32
1708
1709 auto IntPointerTy = PointerType::get(
1710 IntTy, Arg1->getType()->getPointerAddressSpace());
1711
David Neto973e6a82017-05-30 13:48:18 -04001712 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04001713 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04001714 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04001715 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1716
1717 auto One = ConstantInt::get(IntTy, 1);
1718 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1719 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1720
1721 // Index into the correct address of the casted pointer.
1722 auto Ptr =
1723 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1724
1725 // Load from the int* we casted to.
1726 auto Load = new LoadInst(Ptr, "", CI);
1727
1728 // Get our float2.
1729 auto Call = CallInst::Create(NewF, Load, "", CI);
1730
1731 // Extract out the float result, where the element number is
1732 // determined by whether the original index was even or odd.
1733 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1734
1735 CI->replaceAllUsesWith(Extract);
1736 }
David Neto22f144c2017-06-12 14:26:21 -04001737
1738 // Lastly, remember to remove the user.
1739 ToRemoves.push_back(CI);
1740 }
1741 }
1742
1743 Changed = !ToRemoves.empty();
1744
1745 // And cleanup the calls we don't use anymore.
1746 for (auto V : ToRemoves) {
1747 V->eraseFromParent();
1748 }
1749
1750 // And remove the function we don't need either too.
1751 F->eraseFromParent();
1752 }
1753 }
1754
1755 return Changed;
1756}
1757
1758bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
1759 bool Changed = false;
1760
David Neto556c7e62018-06-08 13:45:55 -07001761 const std::vector<const char *> Map = {
1762 "_Z11vload_half2jPU3AS1KDh",
1763 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
1764 "_Z11vload_half2jPU3AS2KDh",
1765 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
1766 };
David Neto22f144c2017-06-12 14:26:21 -04001767
1768 for (auto Name : Map) {
1769 // If we find a function with the matching name.
1770 if (auto F = M.getFunction(Name)) {
1771 SmallVector<Instruction *, 4> ToRemoves;
1772
1773 // Walk the users of the function.
1774 for (auto &U : F->uses()) {
1775 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1776 // The index argument from vload_half.
1777 auto Arg0 = CI->getOperand(0);
1778
1779 // The pointer argument from vload_half.
1780 auto Arg1 = CI->getOperand(1);
1781
1782 auto IntTy = Type::getInt32Ty(M.getContext());
1783 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1784 auto NewPointerTy = PointerType::get(
1785 IntTy, Arg1->getType()->getPointerAddressSpace());
1786 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1787
1788 // Cast the half* pointer to int*.
1789 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1790
1791 // Index into the correct address of the casted pointer.
1792 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
1793
1794 // Load from the int* we casted to.
1795 auto Load = new LoadInst(Index, "", CI);
1796
1797 // Our intrinsic to unpack a float2 from an int.
1798 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1799
1800 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1801
1802 // Get our float2.
1803 auto Call = CallInst::Create(NewF, Load, "", CI);
1804
1805 CI->replaceAllUsesWith(Call);
1806
1807 // Lastly, remember to remove the user.
1808 ToRemoves.push_back(CI);
1809 }
1810 }
1811
1812 Changed = !ToRemoves.empty();
1813
1814 // And cleanup the calls we don't use anymore.
1815 for (auto V : ToRemoves) {
1816 V->eraseFromParent();
1817 }
1818
1819 // And remove the function we don't need either too.
1820 F->eraseFromParent();
1821 }
1822 }
1823
1824 return Changed;
1825}
1826
1827bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
1828 bool Changed = false;
1829
David Neto556c7e62018-06-08 13:45:55 -07001830 const std::vector<const char *> Map = {
1831 "_Z11vload_half4jPU3AS1KDh",
1832 "_Z12vloada_half4jPU3AS1KDh",
1833 "_Z11vload_half4jPU3AS2KDh",
1834 "_Z12vloada_half4jPU3AS2KDh",
1835 };
David Neto22f144c2017-06-12 14:26:21 -04001836
1837 for (auto Name : Map) {
1838 // If we find a function with the matching name.
1839 if (auto F = M.getFunction(Name)) {
1840 SmallVector<Instruction *, 4> ToRemoves;
1841
1842 // Walk the users of the function.
1843 for (auto &U : F->uses()) {
1844 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1845 // The index argument from vload_half.
1846 auto Arg0 = CI->getOperand(0);
1847
1848 // The pointer argument from vload_half.
1849 auto Arg1 = CI->getOperand(1);
1850
1851 auto IntTy = Type::getInt32Ty(M.getContext());
1852 auto Int2Ty = VectorType::get(IntTy, 2);
1853 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1854 auto NewPointerTy = PointerType::get(
1855 Int2Ty, Arg1->getType()->getPointerAddressSpace());
1856 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1857
1858 // Cast the half* pointer to int2*.
1859 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
1860
1861 // Index into the correct address of the casted pointer.
1862 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
1863
1864 // Load from the int2* we casted to.
1865 auto Load = new LoadInst(Index, "", CI);
1866
1867 // Extract each element from the loaded int2.
1868 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
1869 "", CI);
1870 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
1871 "", CI);
1872
1873 // Our intrinsic to unpack a float2 from an int.
1874 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1875
1876 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1877
1878 // Get the lower (x & y) components of our final float4.
1879 auto Lo = CallInst::Create(NewF, X, "", CI);
1880
1881 // Get the higher (z & w) components of our final float4.
1882 auto Hi = CallInst::Create(NewF, Y, "", CI);
1883
1884 Constant *ShuffleMask[4] = {
1885 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1886 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
1887
1888 // Combine our two float2's into one float4.
1889 auto Combine = new ShuffleVectorInst(
1890 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
1891
1892 CI->replaceAllUsesWith(Combine);
1893
1894 // Lastly, remember to remove the user.
1895 ToRemoves.push_back(CI);
1896 }
1897 }
1898
1899 Changed = !ToRemoves.empty();
1900
1901 // And cleanup the calls we don't use anymore.
1902 for (auto V : ToRemoves) {
1903 V->eraseFromParent();
1904 }
1905
1906 // And remove the function we don't need either too.
1907 F->eraseFromParent();
1908 }
1909 }
1910
1911 return Changed;
1912}
1913
David Neto6ad93232018-06-07 15:42:58 -07001914bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
1915 bool Changed = false;
1916
1917 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1918 //
1919 // %u = load i32 %ptr
1920 // %fxy = call <2 x float> Unpack2xHalf(u)
1921 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1922 const std::vector<const char *> Map = {
1923 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
1924 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
1925 "_Z20__clspv_vloada_half2jPKj", // private
1926 };
1927
1928 for (auto Name : Map) {
1929 // If we find a function with the matching name.
1930 if (auto F = M.getFunction(Name)) {
1931 SmallVector<Instruction *, 4> ToRemoves;
1932
1933 // Walk the users of the function.
1934 for (auto &U : F->uses()) {
1935 if (auto* CI = dyn_cast<CallInst>(U.getUser())) {
1936 auto Index = CI->getOperand(0);
1937 auto Ptr = CI->getOperand(1);
1938
1939 auto IntTy = Type::getInt32Ty(M.getContext());
1940 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
1941 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1942
1943 auto IndexedPtr =
1944 GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
1945 auto Load = new LoadInst(IndexedPtr, "", CI);
1946
1947 // Our intrinsic to unpack a float2 from an int.
1948 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
1949
1950 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1951
1952 // Get our final float2.
1953 auto Result = CallInst::Create(NewF, Load, "", CI);
1954
1955 CI->replaceAllUsesWith(Result);
1956
1957 // Lastly, remember to remove the user.
1958 ToRemoves.push_back(CI);
1959 }
1960 }
1961
1962 Changed = true;
1963
1964 // And cleanup the calls we don't use anymore.
1965 for (auto V : ToRemoves) {
1966 V->eraseFromParent();
1967 }
1968
1969 // And remove the function we don't need either too.
1970 F->eraseFromParent();
1971 }
1972 }
1973
1974 return Changed;
1975}
1976
1977bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
1978 bool Changed = false;
1979
1980 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1981 //
1982 // %u2 = load <2 x i32> %ptr
1983 // %u2xy = extractelement %u2, 0
1984 // %u2zw = extractelement %u2, 1
1985 // %fxy = call <2 x float> Unpack2xHalf(uint)
1986 // %fzw = call <2 x float> Unpack2xHalf(uint)
1987 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
1988 const std::vector<const char *> Map = {
1989 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
1990 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
1991 "_Z20__clspv_vloada_half4jPKDv2_j", // private
1992 };
1993
1994 for (auto Name : Map) {
1995 // If we find a function with the matching name.
1996 if (auto F = M.getFunction(Name)) {
1997 SmallVector<Instruction *, 4> ToRemoves;
1998
1999 // Walk the users of the function.
2000 for (auto &U : F->uses()) {
2001 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2002 auto Index = CI->getOperand(0);
2003 auto Ptr = CI->getOperand(1);
2004
2005 auto IntTy = Type::getInt32Ty(M.getContext());
2006 auto Int2Ty = VectorType::get(IntTy, 2);
2007 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2008 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2009
2010 auto IndexedPtr =
2011 GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
2012 auto Load = new LoadInst(IndexedPtr, "", CI);
2013
2014 // Extract each element from the loaded int2.
2015 auto X = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0),
2016 "", CI);
2017 auto Y = ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1),
2018 "", CI);
2019
2020 // Our intrinsic to unpack a float2 from an int.
2021 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2022
2023 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2024
2025 // Get the lower (x & y) components of our final float4.
2026 auto Lo = CallInst::Create(NewF, X, "", CI);
2027
2028 // Get the higher (z & w) components of our final float4.
2029 auto Hi = CallInst::Create(NewF, Y, "", CI);
2030
2031 Constant *ShuffleMask[4] = {
2032 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2033 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2034
2035 // Combine our two float2's into one float4.
2036 auto Combine = new ShuffleVectorInst(
2037 Lo, Hi, ConstantVector::get(ShuffleMask), "", CI);
2038
2039 CI->replaceAllUsesWith(Combine);
2040
2041 // Lastly, remember to remove the user.
2042 ToRemoves.push_back(CI);
2043 }
2044 }
2045
2046 Changed = true;
2047
2048 // And cleanup the calls we don't use anymore.
2049 for (auto V : ToRemoves) {
2050 V->eraseFromParent();
2051 }
2052
2053 // And remove the function we don't need either too.
2054 F->eraseFromParent();
2055 }
2056 }
2057
2058 return Changed;
2059}
2060
David Neto22f144c2017-06-12 14:26:21 -04002061bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
2062 bool Changed = false;
2063
2064 const std::vector<const char *> Map = {"_Z11vstore_halffjPU3AS1Dh",
2065 "_Z15vstore_half_rtefjPU3AS1Dh",
2066 "_Z15vstore_half_rtzfjPU3AS1Dh"};
2067
2068 for (auto Name : Map) {
2069 // If we find a function with the matching name.
2070 if (auto F = M.getFunction(Name)) {
2071 SmallVector<Instruction *, 4> ToRemoves;
2072
2073 // Walk the users of the function.
2074 for (auto &U : F->uses()) {
2075 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2076 // The value to store.
2077 auto Arg0 = CI->getOperand(0);
2078
2079 // The index argument from vstore_half.
2080 auto Arg1 = CI->getOperand(1);
2081
2082 // The pointer argument from vstore_half.
2083 auto Arg2 = CI->getOperand(2);
2084
David Neto22f144c2017-06-12 14:26:21 -04002085 auto IntTy = Type::getInt32Ty(M.getContext());
2086 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002087 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto17852de2017-05-29 17:29:31 -04002088 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002089
2090 // Our intrinsic to pack a float2 to an int.
2091 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2092
2093 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2094
2095 // Insert our value into a float2 so that we can pack it.
David Neto17852de2017-05-29 17:29:31 -04002096 auto TempVec =
2097 InsertElementInst::Create(UndefValue::get(Float2Ty), Arg0,
2098 ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002099
2100 // Pack the float2 -> half2 (in an int).
2101 auto X = CallInst::Create(NewF, TempVec, "", CI);
2102
David Neto482550a2018-03-24 05:21:07 -07002103 if (clspv::Option::F16BitStorage()) {
David Neto17852de2017-05-29 17:29:31 -04002104 auto ShortTy = Type::getInt16Ty(M.getContext());
2105 auto ShortPointerTy = PointerType::get(
2106 ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002107
David Neto17852de2017-05-29 17:29:31 -04002108 // Truncate our i32 to an i16.
2109 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002110
David Neto17852de2017-05-29 17:29:31 -04002111 // Cast the half* pointer to short*.
2112 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002113
David Neto17852de2017-05-29 17:29:31 -04002114 // Index into the correct address of the casted pointer.
2115 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002116
David Neto17852de2017-05-29 17:29:31 -04002117 // Store to the int* we casted to.
2118 auto Store = new StoreInst(Trunc, Index, CI);
2119
2120 CI->replaceAllUsesWith(Store);
2121 } else {
2122 // We can only write to 32-bit aligned words.
2123 //
2124 // Assuming base is aligned to 32-bits, replace the equivalent of
2125 // vstore_half(value, index, base)
2126 // with:
2127 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2128 // uint32_t write_to_upper_half = index & 1u;
2129 // uint32_t shift = write_to_upper_half << 4;
2130 //
2131 // // Pack the float value as a half number in bottom 16 bits
2132 // // of an i32.
2133 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2134 //
2135 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2136 // ^ ((packed & 0xffff) << shift)
2137 // // We only need relaxed consistency, but OpenCL 1.2 only has
2138 // // sequentially consistent atomics.
2139 // // TODO(dneto): Use relaxed consistency.
2140 // atomic_xor(target_ptr, xor_value)
2141 auto IntPointerTy = PointerType::get(
2142 IntTy, Arg2->getType()->getPointerAddressSpace());
2143
2144 auto Four = ConstantInt::get(IntTy, 4);
2145 auto FFFF = ConstantInt::get(IntTy, 0xffff);
2146
2147 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
2148 // Compute index / 2
2149 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2150 auto BaseI32Ptr = CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2151 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32, "base_i32_ptr", CI);
2152 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2153 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
2154 auto MaskBitsToWrite = BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2155 auto MaskedCurrent = BinaryOperator::CreateAnd(MaskBitsToWrite, CurrentValue, "masked_current", CI);
2156
2157 auto XLowerBits = BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2158 auto NewBitsToWrite = BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2159 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite, "value_to_xor", CI);
2160
2161 // Generate the call to atomi_xor.
2162 SmallVector<Type *, 5> ParamTypes;
2163 // The pointer type.
2164 ParamTypes.push_back(IntPointerTy);
2165 // The Types for memory scope, semantics, and value.
2166 ParamTypes.push_back(IntTy);
2167 ParamTypes.push_back(IntTy);
2168 ParamTypes.push_back(IntTy);
2169 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2170 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
2171
2172 const auto ConstantScopeDevice =
2173 ConstantInt::get(IntTy, spv::ScopeDevice);
2174 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2175 // (SPIR-V Workgroup).
2176 const auto AddrSpaceSemanticsBits =
2177 IntPointerTy->getPointerAddressSpace() == 1
2178 ? spv::MemorySemanticsUniformMemoryMask
2179 : spv::MemorySemanticsWorkgroupMemoryMask;
2180
2181 // We're using relaxed consistency here.
2182 const auto ConstantMemorySemantics =
2183 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2184 AddrSpaceSemanticsBits);
2185
2186 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2187 ConstantMemorySemantics, ValueToXor};
2188 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2189 }
David Neto22f144c2017-06-12 14:26:21 -04002190
2191 // Lastly, remember to remove the user.
2192 ToRemoves.push_back(CI);
2193 }
2194 }
2195
2196 Changed = !ToRemoves.empty();
2197
2198 // And cleanup the calls we don't use anymore.
2199 for (auto V : ToRemoves) {
2200 V->eraseFromParent();
2201 }
2202
2203 // And remove the function we don't need either too.
2204 F->eraseFromParent();
2205 }
2206 }
2207
2208 return Changed;
2209}
2210
2211bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
2212 bool Changed = false;
2213
David Netoe2871522018-06-08 11:09:54 -07002214 const std::vector<const char *> Map = {
2215 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2216 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2217 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2218 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2219 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2220 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2221 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2222 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2223 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2224 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2225 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2226 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2227 };
David Neto22f144c2017-06-12 14:26:21 -04002228
2229 for (auto Name : Map) {
2230 // If we find a function with the matching name.
2231 if (auto F = M.getFunction(Name)) {
2232 SmallVector<Instruction *, 4> ToRemoves;
2233
2234 // Walk the users of the function.
2235 for (auto &U : F->uses()) {
2236 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2237 // The value to store.
2238 auto Arg0 = CI->getOperand(0);
2239
2240 // The index argument from vstore_half.
2241 auto Arg1 = CI->getOperand(1);
2242
2243 // The pointer argument from vstore_half.
2244 auto Arg2 = CI->getOperand(2);
2245
2246 auto IntTy = Type::getInt32Ty(M.getContext());
2247 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2248 auto NewPointerTy = PointerType::get(
2249 IntTy, Arg2->getType()->getPointerAddressSpace());
2250 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2251
2252 // Our intrinsic to pack a float2 to an int.
2253 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2254
2255 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2256
2257 // Turn the packed x & y into the final packing.
2258 auto X = CallInst::Create(NewF, Arg0, "", CI);
2259
2260 // Cast the half* pointer to int*.
2261 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2262
2263 // Index into the correct address of the casted pointer.
2264 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
2265
2266 // Store to the int* we casted to.
2267 auto Store = new StoreInst(X, Index, CI);
2268
2269 CI->replaceAllUsesWith(Store);
2270
2271 // Lastly, remember to remove the user.
2272 ToRemoves.push_back(CI);
2273 }
2274 }
2275
2276 Changed = !ToRemoves.empty();
2277
2278 // And cleanup the calls we don't use anymore.
2279 for (auto V : ToRemoves) {
2280 V->eraseFromParent();
2281 }
2282
2283 // And remove the function we don't need either too.
2284 F->eraseFromParent();
2285 }
2286 }
2287
2288 return Changed;
2289}
2290
2291bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
2292 bool Changed = false;
2293
David Netoe2871522018-06-08 11:09:54 -07002294 const std::vector<const char *> Map = {
2295 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2296 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2297 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2298 "_Z13vstorea_half4Dv4_fjPDh", // private
2299 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2300 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2301 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2302 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2303 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2304 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2305 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2306 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2307 };
David Neto22f144c2017-06-12 14:26:21 -04002308
2309 for (auto Name : Map) {
2310 // If we find a function with the matching name.
2311 if (auto F = M.getFunction(Name)) {
2312 SmallVector<Instruction *, 4> ToRemoves;
2313
2314 // Walk the users of the function.
2315 for (auto &U : F->uses()) {
2316 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2317 // The value to store.
2318 auto Arg0 = CI->getOperand(0);
2319
2320 // The index argument from vstore_half.
2321 auto Arg1 = CI->getOperand(1);
2322
2323 // The pointer argument from vstore_half.
2324 auto Arg2 = CI->getOperand(2);
2325
2326 auto IntTy = Type::getInt32Ty(M.getContext());
2327 auto Int2Ty = VectorType::get(IntTy, 2);
2328 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2329 auto NewPointerTy = PointerType::get(
2330 Int2Ty, Arg2->getType()->getPointerAddressSpace());
2331 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2332
2333 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2334 ConstantInt::get(IntTy, 1)};
2335
2336 // Extract out the x & y components of our to store value.
2337 auto Lo =
2338 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2339 ConstantVector::get(LoShuffleMask), "", CI);
2340
2341 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2342 ConstantInt::get(IntTy, 3)};
2343
2344 // Extract out the z & w components of our to store value.
2345 auto Hi =
2346 new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2347 ConstantVector::get(HiShuffleMask), "", CI);
2348
2349 // Our intrinsic to pack a float2 to an int.
2350 auto SPIRVIntrinsic = "spirv.pack.v2f16";
2351
2352 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2353
2354 // Turn the packed x & y into the final component of our int2.
2355 auto X = CallInst::Create(NewF, Lo, "", CI);
2356
2357 // Turn the packed z & w into the final component of our int2.
2358 auto Y = CallInst::Create(NewF, Hi, "", CI);
2359
2360 auto Combine = InsertElementInst::Create(
2361 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
2362 Combine = InsertElementInst::Create(
2363 Combine, Y, ConstantInt::get(IntTy, 1), "", CI);
2364
2365 // Cast the half* pointer to int2*.
2366 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
2367
2368 // Index into the correct address of the casted pointer.
2369 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
2370
2371 // Store to the int2* we casted to.
2372 auto Store = new StoreInst(Combine, Index, CI);
2373
2374 CI->replaceAllUsesWith(Store);
2375
2376 // Lastly, remember to remove the user.
2377 ToRemoves.push_back(CI);
2378 }
2379 }
2380
2381 Changed = !ToRemoves.empty();
2382
2383 // And cleanup the calls we don't use anymore.
2384 for (auto V : ToRemoves) {
2385 V->eraseFromParent();
2386 }
2387
2388 // And remove the function we don't need either too.
2389 F->eraseFromParent();
2390 }
2391 }
2392
2393 return Changed;
2394}
2395
2396bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2397 bool Changed = false;
2398
2399 const std::map<const char *, const char*> Map = {
2400 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f" },
2401 { "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i", "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f" }
2402 };
2403
2404 for (auto Pair : Map) {
2405 // If we find a function with the matching name.
2406 if (auto F = M.getFunction(Pair.first)) {
2407 SmallVector<Instruction *, 4> ToRemoves;
2408
2409 // Walk the users of the function.
2410 for (auto &U : F->uses()) {
2411 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2412 // The image.
2413 auto Arg0 = CI->getOperand(0);
2414
2415 // The sampler.
2416 auto Arg1 = CI->getOperand(1);
2417
2418 // The coordinate (integer type that we can't handle).
2419 auto Arg2 = CI->getOperand(2);
2420
2421 auto FloatVecTy = VectorType::get(Type::getFloatTy(M.getContext()), Arg2->getType()->getVectorNumElements());
2422
2423 auto NewFType = FunctionType::get(CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy}, false);
2424
2425 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2426
2427 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
2428
2429 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2430
2431 CI->replaceAllUsesWith(NewCI);
2432
2433 // Lastly, remember to remove the user.
2434 ToRemoves.push_back(CI);
2435 }
2436 }
2437
2438 Changed = !ToRemoves.empty();
2439
2440 // And cleanup the calls we don't use anymore.
2441 for (auto V : ToRemoves) {
2442 V->eraseFromParent();
2443 }
2444
2445 // And remove the function we don't need either too.
2446 F->eraseFromParent();
2447 }
2448 }
2449
2450 return Changed;
2451}
2452
2453bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2454 bool Changed = false;
2455
2456 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002457 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
2458 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
2459 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
2460 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
2461 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
2462 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002463 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
2464 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
2465 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
2466 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
2467 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Neil Henning39672102017-09-29 14:33:13 +01002468 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002469
2470 for (auto Pair : Map) {
2471 // If we find a function with the matching name.
2472 if (auto F = M.getFunction(Pair.first)) {
2473 SmallVector<Instruction *, 4> ToRemoves;
2474
2475 // Walk the users of the function.
2476 for (auto &U : F->uses()) {
2477 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2478 auto FType = F->getFunctionType();
2479 SmallVector<Type *, 5> ParamTypes;
2480
2481 // The pointer type.
2482 ParamTypes.push_back(FType->getParamType(0));
2483
2484 auto IntTy = Type::getInt32Ty(M.getContext());
2485
2486 // The memory scope type.
2487 ParamTypes.push_back(IntTy);
2488
2489 // The memory semantics type.
2490 ParamTypes.push_back(IntTy);
2491
2492 if (2 < CI->getNumArgOperands()) {
2493 // The unequal memory semantics type.
2494 ParamTypes.push_back(IntTy);
2495
2496 // The value type.
2497 ParamTypes.push_back(FType->getParamType(2));
2498
2499 // The comparator type.
2500 ParamTypes.push_back(FType->getParamType(1));
2501 } else if (1 < CI->getNumArgOperands()) {
2502 // The value type.
2503 ParamTypes.push_back(FType->getParamType(1));
2504 }
2505
2506 auto NewFType =
2507 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2508 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2509
2510 // We need to map the OpenCL constants to the SPIR-V equivalents.
2511 const auto ConstantScopeDevice =
2512 ConstantInt::get(IntTy, spv::ScopeDevice);
2513 const auto ConstantMemorySemantics = ConstantInt::get(
2514 IntTy, spv::MemorySemanticsUniformMemoryMask |
2515 spv::MemorySemanticsSequentiallyConsistentMask);
2516
2517 SmallVector<Value *, 5> Params;
2518
2519 // The pointer.
2520 Params.push_back(CI->getArgOperand(0));
2521
2522 // The memory scope.
2523 Params.push_back(ConstantScopeDevice);
2524
2525 // The memory semantics.
2526 Params.push_back(ConstantMemorySemantics);
2527
2528 if (2 < CI->getNumArgOperands()) {
2529 // The unequal memory semantics.
2530 Params.push_back(ConstantMemorySemantics);
2531
2532 // The value.
2533 Params.push_back(CI->getArgOperand(2));
2534
2535 // The comparator.
2536 Params.push_back(CI->getArgOperand(1));
2537 } else if (1 < CI->getNumArgOperands()) {
2538 // The value.
2539 Params.push_back(CI->getArgOperand(1));
2540 }
2541
2542 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2543
2544 CI->replaceAllUsesWith(NewCI);
2545
2546 // Lastly, remember to remove the user.
2547 ToRemoves.push_back(CI);
2548 }
2549 }
2550
2551 Changed = !ToRemoves.empty();
2552
2553 // And cleanup the calls we don't use anymore.
2554 for (auto V : ToRemoves) {
2555 V->eraseFromParent();
2556 }
2557
2558 // And remove the function we don't need either too.
2559 F->eraseFromParent();
2560 }
2561 }
2562
Neil Henning39672102017-09-29 14:33:13 +01002563 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002564 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2565 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2566 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2567 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2568 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2569 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2570 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2571 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2572 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2573 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2574 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2575 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2576 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2577 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2578 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2579 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002580 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
2581 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
2582 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
2583 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
2584 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
2585 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
2586 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
2587 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
2588 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
2589 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
2590 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
2591 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
2592 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
2593 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
2594 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
2595 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor}};
2596
2597 for (auto Pair : Map2) {
2598 // If we find a function with the matching name.
2599 if (auto F = M.getFunction(Pair.first)) {
2600 SmallVector<Instruction *, 4> ToRemoves;
2601
2602 // Walk the users of the function.
2603 for (auto &U : F->uses()) {
2604 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2605 auto AtomicOp = new AtomicRMWInst(
2606 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2607 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2608
2609 CI->replaceAllUsesWith(AtomicOp);
2610
2611 // Lastly, remember to remove the user.
2612 ToRemoves.push_back(CI);
2613 }
2614 }
2615
2616 Changed = !ToRemoves.empty();
2617
2618 // And cleanup the calls we don't use anymore.
2619 for (auto V : ToRemoves) {
2620 V->eraseFromParent();
2621 }
2622
2623 // And remove the function we don't need either too.
2624 F->eraseFromParent();
2625 }
2626 }
2627
David Neto22f144c2017-06-12 14:26:21 -04002628 return Changed;
2629}
2630
2631bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
2632 bool Changed = false;
2633
2634 // If we find a function with the matching name.
2635 if (auto F = M.getFunction("_Z5crossDv4_fS_")) {
2636 SmallVector<Instruction *, 4> ToRemoves;
2637
2638 auto IntTy = Type::getInt32Ty(M.getContext());
2639 auto FloatTy = Type::getFloatTy(M.getContext());
2640
2641 Constant *DownShuffleMask[3] = {
2642 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2643 ConstantInt::get(IntTy, 2)};
2644
2645 Constant *UpShuffleMask[4] = {
2646 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2647 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2648
2649 Constant *FloatVec[3] = {
2650 ConstantFP::get(FloatTy, 0.0f), UndefValue::get(FloatTy), UndefValue::get(FloatTy)
2651 };
2652
2653 // Walk the users of the function.
2654 for (auto &U : F->uses()) {
2655 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2656 auto Vec4Ty = CI->getArgOperand(0)->getType();
2657 auto Arg0 = new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2658 auto Arg1 = new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty), ConstantVector::get(DownShuffleMask), "", CI);
2659 auto Vec3Ty = Arg0->getType();
2660
2661 auto NewFType =
2662 FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
2663
2664 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
2665
2666 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
2667
2668 auto Result = new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec), ConstantVector::get(UpShuffleMask), "", CI);
2669
2670 CI->replaceAllUsesWith(Result);
2671
2672 // Lastly, remember to remove the user.
2673 ToRemoves.push_back(CI);
2674 }
2675 }
2676
2677 Changed = !ToRemoves.empty();
2678
2679 // And cleanup the calls we don't use anymore.
2680 for (auto V : ToRemoves) {
2681 V->eraseFromParent();
2682 }
2683
2684 // And remove the function we don't need either too.
2685 F->eraseFromParent();
2686 }
2687
2688 return Changed;
2689}
David Neto62653202017-10-16 19:05:18 -04002690
2691bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2692 bool Changed = false;
2693
2694 // OpenCL's float result = fract(float x, float* ptr)
2695 //
2696 // In the LLVM domain:
2697 //
2698 // %floor_result = call spir_func float @floor(float %x)
2699 // store float %floor_result, float * %ptr
2700 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2701 // %result = call spir_func float
2702 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2703 //
2704 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2705 // and clspv.fract occur in the SPIR-V generator pass:
2706 //
2707 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2708 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2709 // ...
2710 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2711 // OpStore %ptr %floor_result
2712 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2713 // %fract_result = OpExtInst %float
2714 // %glsl_ext Fmin %fract_intermediate %just_under_1
2715
2716
2717 using std::string;
2718
2719 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2720 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
2721 using QuadType = std::tuple<const char *, const char *, const char *, const char *>;
2722 auto make_quad = [](const char *a, const char *b, const char *c,
2723 const char *d) {
2724 return std::tuple<const char *, const char *, const char *, const char *>(
2725 a, b, c, d);
2726 };
2727 const std::vector<QuadType> Functions = {
2728 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
2729 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff", "clspv.fract.v2f"),
2730 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff", "clspv.fract.v3f"),
2731 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff", "clspv.fract.v4f"),
2732 };
2733
2734 for (auto& quad : Functions) {
2735 const StringRef fract_name(std::get<0>(quad));
2736
2737 // If we find a function with the matching name.
2738 if (auto F = M.getFunction(fract_name)) {
2739 if (F->use_begin() == F->use_end())
2740 continue;
2741
2742 // We have some uses.
2743 Changed = true;
2744
2745 auto& Context = M.getContext();
2746
2747 const StringRef floor_name(std::get<1>(quad));
2748 const StringRef fmin_name(std::get<2>(quad));
2749 const StringRef clspv_fract_name(std::get<3>(quad));
2750
2751 // This is either float or a float vector. All the float-like
2752 // types are this type.
2753 auto result_ty = F->getReturnType();
2754
2755 Function* fmin_fn = M.getFunction(fmin_name);
2756 if (!fmin_fn) {
2757 // Make the fmin function.
2758 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty, result_ty}, false);
2759 fmin_fn = cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002760 fmin_fn->addFnAttr(Attribute::ReadNone);
2761 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2762 }
2763
2764 Function* floor_fn = M.getFunction(floor_name);
2765 if (!floor_fn) {
2766 // Make the floor function.
2767 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2768 floor_fn = cast<Function>(M.getOrInsertFunction(floor_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002769 floor_fn->addFnAttr(Attribute::ReadNone);
2770 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2771 }
2772
2773 Function* clspv_fract_fn = M.getFunction(clspv_fract_name);
2774 if (!clspv_fract_fn) {
2775 // Make the clspv_fract function.
2776 FunctionType* fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2777 clspv_fract_fn = cast<Function>(M.getOrInsertFunction(clspv_fract_name, fn_ty));
David Neto62653202017-10-16 19:05:18 -04002778 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2779 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2780 }
2781
2782 // Number of significant significand bits, whether represented or not.
2783 unsigned num_significand_bits;
2784 switch (result_ty->getScalarType()->getTypeID()) {
2785 case Type::HalfTyID:
2786 num_significand_bits = 11;
2787 break;
2788 case Type::FloatTyID:
2789 num_significand_bits = 24;
2790 break;
2791 case Type::DoubleTyID:
2792 num_significand_bits = 53;
2793 break;
2794 default:
2795 assert(false && "Unhandled float type when processing fract builtin");
2796 break;
2797 }
2798 // Beware that the disassembler displays this value as
2799 // OpConstant %float 1
2800 // which is not quite right.
2801 const double kJustUnderOneScalar =
2802 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2803
2804 Constant *just_under_one =
2805 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2806 if (result_ty->isVectorTy()) {
2807 just_under_one = ConstantVector::getSplat(
2808 result_ty->getVectorNumElements(), just_under_one);
2809 }
2810
2811 IRBuilder<> Builder(Context);
2812
2813 SmallVector<Instruction *, 4> ToRemoves;
2814
2815 // Walk the users of the function.
2816 for (auto &U : F->uses()) {
2817 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2818
2819 Builder.SetInsertPoint(CI);
2820 auto arg = CI->getArgOperand(0);
2821 auto ptr = CI->getArgOperand(1);
2822
2823 // Compute floor result and store it.
2824 auto floor = Builder.CreateCall(floor_fn, {arg});
2825 Builder.CreateStore(floor, ptr);
2826
2827 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2828 auto fract_result = Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2829
2830 CI->replaceAllUsesWith(fract_result);
2831
2832 // Lastly, remember to remove the user.
2833 ToRemoves.push_back(CI);
2834 }
2835 }
2836
2837 // And cleanup the calls we don't use anymore.
2838 for (auto V : ToRemoves) {
2839 V->eraseFromParent();
2840 }
2841
2842 // And remove the function we don't need either too.
2843 F->eraseFromParent();
2844 }
2845 }
2846
2847 return Changed;
2848}