blob: d9ceef970b39924c5e136695c3e2922b2ae3d83e [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
David Neto118188e2018-08-24 11:27:54 -040021#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040022#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040023#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000024#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040025#include "llvm/Pass.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040029
David Neto118188e2018-08-24 11:27:54 -040030#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040031
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040032#include "clspv/Option.h"
David Neto482550a2018-03-24 05:21:07 -070033
Diego Novilloa4c44fa2019-04-11 10:56:15 -040034#include "Passes.h"
35#include "SPIRVOp.h"
36
David Neto22f144c2017-06-12 14:26:21 -040037using namespace llvm;
38
39#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
40
41namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000042
// Per-argument type information recovered from a mangled builtin name.
struct ArgTypeInfo {
  // None is used for types without an integer signedness (e.g. float).
  enum class SignedNess { None, Unsigned, Signed };

  // Default to None so that a default-constructed ArgTypeInfo never carries
  // an indeterminate value.
  SignedNess signedness = SignedNess::None;
};
47
48struct FunctionInfo {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000049 StringRef name;
Kévin Petit8a560882019-03-21 15:24:34 +000050 std::vector<ArgTypeInfo> argTypeInfos;
Kévin Petit8a560882019-03-21 15:24:34 +000051
Kévin Petit91bc72e2019-04-08 15:17:46 +010052 bool isArgSigned(size_t arg) const {
53 assert(argTypeInfos.size() > arg);
54 return argTypeInfos[arg].signedness == ArgTypeInfo::SignedNess::Signed;
Kévin Petit8a560882019-03-21 15:24:34 +000055 }
56
Kévin Petit91bc72e2019-04-08 15:17:46 +010057 static FunctionInfo getFromMangledName(StringRef name) {
58 FunctionInfo fi;
59 if (!getFromMangledNameCheck(name, &fi)) {
60 llvm_unreachable("Can't parse mangled function name!");
Kévin Petit8a560882019-03-21 15:24:34 +000061 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010062 return fi;
63 }
Kévin Petit8a560882019-03-21 15:24:34 +000064
Kévin Petit91bc72e2019-04-08 15:17:46 +010065 static bool getFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
66 if (!name.consume_front("_Z")) {
67 return false;
68 }
69 size_t nameLen;
70 if (name.consumeInteger(10, nameLen)) {
Kévin Petit8a560882019-03-21 15:24:34 +000071 return false;
72 }
73
Kévin Petit91bc72e2019-04-08 15:17:46 +010074 finfo->name = name.take_front(nameLen);
75 name = name.drop_front(nameLen);
Kévin Petit8a560882019-03-21 15:24:34 +000076
Kévin Petit91bc72e2019-04-08 15:17:46 +010077 ArgTypeInfo prev_ti;
Kévin Petit8a560882019-03-21 15:24:34 +000078
Kévin Petit91bc72e2019-04-08 15:17:46 +010079 while (name.size() != 0) {
80
81 ArgTypeInfo ti;
82
83 // Try parsing a vector prefix
84 if (name.consume_front("Dv")) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040085 int numElems;
86 if (name.consumeInteger(10, numElems)) {
87 return false;
88 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010089
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040090 if (!name.consume_front("_")) {
91 return false;
92 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010093 }
94
95 // Parse the base type
96 char typeCode = name.front();
97 name = name.drop_front(1);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040098 switch (typeCode) {
Kévin Petit91bc72e2019-04-08 15:17:46 +010099 case 'c': // char
100 case 'a': // signed char
101 case 's': // short
102 case 'i': // int
103 case 'l': // long
104 ti.signedness = ArgTypeInfo::SignedNess::Signed;
105 break;
106 case 'h': // unsigned char
107 case 't': // unsigned short
108 case 'j': // unsigned int
109 case 'm': // unsigned long
110 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
111 break;
112 case 'f':
113 ti.signedness = ArgTypeInfo::SignedNess::None;
114 break;
115 case 'S':
116 ti = prev_ti;
117 if (!name.consume_front("_")) {
118 return false;
119 }
120 break;
121 default:
122 return false;
123 }
124
125 finfo->argTypeInfos.push_back(ti);
126
127 prev_ti = ti;
128 }
129
130 return true;
131 };
Kévin Petit8a560882019-03-21 15:24:34 +0000132};
133
// Returns the index of the most significant set bit of |v| (i.e.
// floor(log2(v))), and 0 when |v| is 0 or 1.
//
// NOTE: despite its name this is not a count of leading zeros. Callers in
// this file only use differences of two results to derive a shift amount
// between two single-bit masks.
uint32_t clz(uint32_t v) {
  uint32_t highestBit = 0;

  // Binary search: test progressively narrower upper halves of the value.
  for (uint32_t width = 16; width != 0; width >>= 1) {
    if ((v >> width) != 0) {
      v >>= width;
      highestBit += width;
    }
  }

  return highestBit;
}
153
154Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
155 if (1 == elements) {
156 return Type::getInt1Ty(C);
157 } else {
158 return VectorType::get(Type::getInt1Ty(C), elements);
159 }
160}
161
162struct ReplaceOpenCLBuiltinPass final : public ModulePass {
163 static char ID;
164 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
165
166 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000167 bool replaceAbs(Module &M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100168 bool replaceAbsDiff(Module &M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100169 bool replaceCopysign(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400170 bool replaceRecip(Module &M);
171 bool replaceDivide(Module &M);
Kévin Petit1329a002019-06-15 05:54:05 +0100172 bool replaceDot(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400173 bool replaceExp10(Module &M);
174 bool replaceLog10(Module &M);
175 bool replaceBarrier(Module &M);
176 bool replaceMemFence(Module &M);
177 bool replaceRelational(Module &M);
178 bool replaceIsInfAndIsNan(Module &M);
179 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000180 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000181 bool replaceRotate(Module &M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000182 bool replaceConvert(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000183 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000184 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000185 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000186 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400187 bool replaceSignbit(Module &M);
188 bool replaceMadandMad24andMul24(Module &M);
189 bool replaceVloadHalf(Module &M);
190 bool replaceVloadHalf2(Module &M);
191 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700192 bool replaceClspvVloadaHalf2(Module &M);
193 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400194 bool replaceVstoreHalf(Module &M);
195 bool replaceVstoreHalf2(Module &M);
196 bool replaceVstoreHalf4(Module &M);
197 bool replaceReadImageF(Module &M);
198 bool replaceAtomics(Module &M);
199 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400200 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700201 bool replaceVload(Module &M);
202 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400203};
Kévin Petit91bc72e2019-04-08 15:17:46 +0100204} // namespace
David Neto22f144c2017-06-12 14:26:21 -0400205
206char ReplaceOpenCLBuiltinPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400207INITIALIZE_PASS(ReplaceOpenCLBuiltinPass, "ReplaceOpenCLBuiltin",
208 "Replace OpenCL Builtins Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -0400209
210namespace clspv {
211ModulePass *createReplaceOpenCLBuiltinPass() {
212 return new ReplaceOpenCLBuiltinPass();
213}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400214} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -0400215
216bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
217 bool Changed = false;
218
Kévin Petit2444e9b2018-11-09 14:14:37 +0000219 Changed |= replaceAbs(M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100220 Changed |= replaceAbsDiff(M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100221 Changed |= replaceCopysign(M);
David Neto22f144c2017-06-12 14:26:21 -0400222 Changed |= replaceRecip(M);
223 Changed |= replaceDivide(M);
Kévin Petit1329a002019-06-15 05:54:05 +0100224 Changed |= replaceDot(M);
David Neto22f144c2017-06-12 14:26:21 -0400225 Changed |= replaceExp10(M);
226 Changed |= replaceLog10(M);
227 Changed |= replaceBarrier(M);
228 Changed |= replaceMemFence(M);
229 Changed |= replaceRelational(M);
230 Changed |= replaceIsInfAndIsNan(M);
231 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000232 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000233 Changed |= replaceRotate(M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000234 Changed |= replaceConvert(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000235 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000236 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000237 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000238 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400239 Changed |= replaceSignbit(M);
240 Changed |= replaceMadandMad24andMul24(M);
241 Changed |= replaceVloadHalf(M);
242 Changed |= replaceVloadHalf2(M);
243 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700244 Changed |= replaceClspvVloadaHalf2(M);
245 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400246 Changed |= replaceVstoreHalf(M);
247 Changed |= replaceVstoreHalf2(M);
248 Changed |= replaceVstoreHalf4(M);
249 Changed |= replaceReadImageF(M);
250 Changed |= replaceAtomics(M);
251 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400252 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700253 Changed |= replaceVload(M);
254 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400255
256 return Changed;
257}
258
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400259bool replaceCallsWithValue(Module &M, std::vector<const char *> Names,
260 std::function<Value *(CallInst *)> Replacer) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000261
Kévin Petite8edce32019-04-10 14:23:32 +0100262 bool Changed = false;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000263
264 for (auto Name : Names) {
265 // If we find a function with the matching name.
266 if (auto F = M.getFunction(Name)) {
267 SmallVector<Instruction *, 4> ToRemoves;
268
269 // Walk the users of the function.
270 for (auto &U : F->uses()) {
271 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000272
Kévin Petite8edce32019-04-10 14:23:32 +0100273 auto NewValue = Replacer(CI);
274
275 if (NewValue != nullptr) {
276 CI->replaceAllUsesWith(NewValue);
277 }
Kévin Petit2444e9b2018-11-09 14:14:37 +0000278
279 // Lastly, remember to remove the user.
280 ToRemoves.push_back(CI);
281 }
282 }
283
284 Changed = !ToRemoves.empty();
285
286 // And cleanup the calls we don't use anymore.
287 for (auto V : ToRemoves) {
288 V->eraseFromParent();
289 }
290
291 // And remove the function we don't need either too.
292 F->eraseFromParent();
293 }
294 }
295
296 return Changed;
297}
298
Kévin Petite8edce32019-04-10 14:23:32 +0100299bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
Kévin Petit91bc72e2019-04-08 15:17:46 +0100300
Kévin Petite8edce32019-04-10 14:23:32 +0100301 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400302 "_Z3absh", "_Z3absDv2_h", "_Z3absDv3_h", "_Z3absDv4_h",
303 "_Z3abst", "_Z3absDv2_t", "_Z3absDv3_t", "_Z3absDv4_t",
304 "_Z3absj", "_Z3absDv2_j", "_Z3absDv3_j", "_Z3absDv4_j",
305 "_Z3absm", "_Z3absDv2_m", "_Z3absDv3_m", "_Z3absDv4_m",
Kévin Petite8edce32019-04-10 14:23:32 +0100306 };
307
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400308 return replaceCallsWithValue(M, Names,
309 [](CallInst *CI) { return CI->getOperand(0); });
Kévin Petite8edce32019-04-10 14:23:32 +0100310}
311
312bool ReplaceOpenCLBuiltinPass::replaceAbsDiff(Module &M) {
313
314 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400315 "_Z8abs_diffcc", "_Z8abs_diffDv2_cS_", "_Z8abs_diffDv3_cS_",
316 "_Z8abs_diffDv4_cS_", "_Z8abs_diffhh", "_Z8abs_diffDv2_hS_",
317 "_Z8abs_diffDv3_hS_", "_Z8abs_diffDv4_hS_", "_Z8abs_diffss",
318 "_Z8abs_diffDv2_sS_", "_Z8abs_diffDv3_sS_", "_Z8abs_diffDv4_sS_",
319 "_Z8abs_difftt", "_Z8abs_diffDv2_tS_", "_Z8abs_diffDv3_tS_",
320 "_Z8abs_diffDv4_tS_", "_Z8abs_diffii", "_Z8abs_diffDv2_iS_",
321 "_Z8abs_diffDv3_iS_", "_Z8abs_diffDv4_iS_", "_Z8abs_diffjj",
322 "_Z8abs_diffDv2_jS_", "_Z8abs_diffDv3_jS_", "_Z8abs_diffDv4_jS_",
323 "_Z8abs_diffll", "_Z8abs_diffDv2_lS_", "_Z8abs_diffDv3_lS_",
324 "_Z8abs_diffDv4_lS_", "_Z8abs_diffmm", "_Z8abs_diffDv2_mS_",
325 "_Z8abs_diffDv3_mS_", "_Z8abs_diffDv4_mS_",
Kévin Petit91bc72e2019-04-08 15:17:46 +0100326 };
327
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400328 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100329 auto XValue = CI->getOperand(0);
330 auto YValue = CI->getOperand(1);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100331
Kévin Petite8edce32019-04-10 14:23:32 +0100332 IRBuilder<> Builder(CI);
333 auto XmY = Builder.CreateSub(XValue, YValue);
334 auto YmX = Builder.CreateSub(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100335
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400336 Value *Cmp;
Kévin Petite8edce32019-04-10 14:23:32 +0100337 auto F = CI->getCalledFunction();
338 auto finfo = FunctionInfo::getFromMangledName(F->getName());
339 if (finfo.isArgSigned(0)) {
340 Cmp = Builder.CreateICmpSGT(YValue, XValue);
341 } else {
342 Cmp = Builder.CreateICmpUGT(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100343 }
Kévin Petit91bc72e2019-04-08 15:17:46 +0100344
Kévin Petite8edce32019-04-10 14:23:32 +0100345 return Builder.CreateSelect(Cmp, YmX, XmY);
346 });
Kévin Petit91bc72e2019-04-08 15:17:46 +0100347}
348
Kévin Petit8c1be282019-04-02 19:34:25 +0100349bool ReplaceOpenCLBuiltinPass::replaceCopysign(Module &M) {
Kévin Petit8c1be282019-04-02 19:34:25 +0100350
Kévin Petite8edce32019-04-10 14:23:32 +0100351 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400352 "_Z8copysignff",
353 "_Z8copysignDv2_fS_",
354 "_Z8copysignDv3_fS_",
355 "_Z8copysignDv4_fS_",
Kévin Petit8c1be282019-04-02 19:34:25 +0100356 };
357
Kévin Petite8edce32019-04-10 14:23:32 +0100358 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
359 auto XValue = CI->getOperand(0);
360 auto YValue = CI->getOperand(1);
Kévin Petit8c1be282019-04-02 19:34:25 +0100361
Kévin Petite8edce32019-04-10 14:23:32 +0100362 auto Ty = XValue->getType();
Kévin Petit8c1be282019-04-02 19:34:25 +0100363
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400364 Type *IntTy = Type::getIntNTy(M.getContext(), Ty->getScalarSizeInBits());
Kévin Petite8edce32019-04-10 14:23:32 +0100365 if (Ty->isVectorTy()) {
366 IntTy = VectorType::get(IntTy, Ty->getVectorNumElements());
Kévin Petit8c1be282019-04-02 19:34:25 +0100367 }
Kévin Petit8c1be282019-04-02 19:34:25 +0100368
Kévin Petite8edce32019-04-10 14:23:32 +0100369 // Return X with the sign of Y
370
371 // Sign bit masks
372 auto SignBit = IntTy->getScalarSizeInBits() - 1;
373 auto SignBitMask = 1 << SignBit;
374 auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
375 auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);
376
377 IRBuilder<> Builder(CI);
378
379 // Extract sign of Y
380 auto YInt = Builder.CreateBitCast(YValue, IntTy);
381 auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);
382
383 // Clear sign bit in X
384 auto XInt = Builder.CreateBitCast(XValue, IntTy);
385 XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);
386
387 // Insert sign bit of Y into X
388 auto NewXInt = Builder.CreateOr(XInt, YSign);
389
390 // And cast back to floating-point
391 return Builder.CreateBitCast(NewXInt, Ty);
392 });
Kévin Petit8c1be282019-04-02 19:34:25 +0100393}
394
David Neto22f144c2017-06-12 14:26:21 -0400395bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400396
Kévin Petite8edce32019-04-10 14:23:32 +0100397 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400398 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
399 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
400 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
401 };
402
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400403 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100404 // Recip has one arg.
405 auto Arg = CI->getOperand(0);
406 auto Cst1 = ConstantFP::get(Arg->getType(), 1.0);
407 return BinaryOperator::Create(Instruction::FDiv, Cst1, Arg, "", CI);
408 });
David Neto22f144c2017-06-12 14:26:21 -0400409}
410
411bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400412
Kévin Petite8edce32019-04-10 14:23:32 +0100413 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400414 "_Z11half_divideff", "_Z13native_divideff",
415 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
416 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
417 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
418 };
419
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400420 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100421 auto Op0 = CI->getOperand(0);
422 auto Op1 = CI->getOperand(1);
423 return BinaryOperator::Create(Instruction::FDiv, Op0, Op1, "", CI);
424 });
David Neto22f144c2017-06-12 14:26:21 -0400425}
426
Kévin Petit1329a002019-06-15 05:54:05 +0100427bool ReplaceOpenCLBuiltinPass::replaceDot(Module &M) {
428
429 std::vector<const char *> Names = {
430 "_Z3dotff",
431 "_Z3dotDv2_fS_",
432 "_Z3dotDv3_fS_",
433 "_Z3dotDv4_fS_",
434 };
435
436 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
437 auto Op0 = CI->getOperand(0);
438 auto Op1 = CI->getOperand(1);
439
440 Value *V;
441 if (Op0->getType()->isVectorTy()) {
442 V = clspv::InsertSPIRVOp(CI, spv::OpDot, {Attribute::ReadNone},
443 CI->getType(), {Op0, Op1});
444 } else {
445 V = BinaryOperator::Create(Instruction::FMul, Op0, Op1, "", CI);
446 }
447
448 return V;
449 });
450}
451
David Neto22f144c2017-06-12 14:26:21 -0400452bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
453 bool Changed = false;
454
455 const std::map<const char *, const char *> Map = {
456 {"_Z5exp10f", "_Z3expf"},
457 {"_Z10half_exp10f", "_Z8half_expf"},
458 {"_Z12native_exp10f", "_Z10native_expf"},
459 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
460 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
461 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
462 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
463 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
464 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
465 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
466 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
467 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
468
469 for (auto Pair : Map) {
470 // If we find a function with the matching name.
471 if (auto F = M.getFunction(Pair.first)) {
472 SmallVector<Instruction *, 4> ToRemoves;
473
474 // Walk the users of the function.
475 for (auto &U : F->uses()) {
476 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
477 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
478
479 auto Arg = CI->getOperand(0);
480
481 // Constant of the natural log of 10 (ln(10)).
482 const double Ln10 =
483 2.302585092994045684017991454684364207601101488628772976033;
484
485 auto Mul = BinaryOperator::Create(
486 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
487 CI);
488
489 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
490
491 CI->replaceAllUsesWith(NewCI);
492
493 // Lastly, remember to remove the user.
494 ToRemoves.push_back(CI);
495 }
496 }
497
498 Changed = !ToRemoves.empty();
499
500 // And cleanup the calls we don't use anymore.
501 for (auto V : ToRemoves) {
502 V->eraseFromParent();
503 }
504
505 // And remove the function we don't need either too.
506 F->eraseFromParent();
507 }
508 }
509
510 return Changed;
511}
512
513bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
514 bool Changed = false;
515
516 const std::map<const char *, const char *> Map = {
517 {"_Z5log10f", "_Z3logf"},
518 {"_Z10half_log10f", "_Z8half_logf"},
519 {"_Z12native_log10f", "_Z10native_logf"},
520 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
521 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
522 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
523 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
524 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
525 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
526 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
527 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
528 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
529
530 for (auto Pair : Map) {
531 // If we find a function with the matching name.
532 if (auto F = M.getFunction(Pair.first)) {
533 SmallVector<Instruction *, 4> ToRemoves;
534
535 // Walk the users of the function.
536 for (auto &U : F->uses()) {
537 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
538 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
539
540 auto Arg = CI->getOperand(0);
541
542 // Constant of the reciprocal of the natural log of 10 (ln(10)).
543 const double Ln10 =
544 0.434294481903251827651128918916605082294397005803666566114;
545
546 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
547
548 auto Mul = BinaryOperator::Create(
549 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
550 "", CI);
551
552 CI->replaceAllUsesWith(Mul);
553
554 // Lastly, remember to remove the user.
555 ToRemoves.push_back(CI);
556 }
557 }
558
559 Changed = !ToRemoves.empty();
560
561 // And cleanup the calls we don't use anymore.
562 for (auto V : ToRemoves) {
563 V->eraseFromParent();
564 }
565
566 // And remove the function we don't need either too.
567 F->eraseFromParent();
568 }
569 }
570
571 return Changed;
572}
573
574bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400575
576 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
577
Kévin Petitc4643922019-06-17 19:32:05 +0100578 const std::vector<const char *> Names = {
579 {"_Z7barrierj"},
580 };
David Neto22f144c2017-06-12 14:26:21 -0400581
Kévin Petitc4643922019-06-17 19:32:05 +0100582 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
583 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400584
Kévin Petitc4643922019-06-17 19:32:05 +0100585 // We need to map the OpenCL constants to the SPIR-V equivalents.
586 const auto LocalMemFence =
587 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
588 const auto GlobalMemFence =
589 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
590 const auto ConstantSequentiallyConsistent = ConstantInt::get(
591 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
592 const auto ConstantScopeDevice =
593 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
594 const auto ConstantScopeWorkgroup =
595 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
David Neto22f144c2017-06-12 14:26:21 -0400596
Kévin Petitc4643922019-06-17 19:32:05 +0100597 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
598 const auto LocalMemFenceMask =
599 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
600 const auto WorkgroupShiftAmount =
601 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
602 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
603 Instruction::Shl, LocalMemFenceMask,
604 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400605
Kévin Petitc4643922019-06-17 19:32:05 +0100606 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
607 const auto GlobalMemFenceMask =
608 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
609 const auto UniformShiftAmount =
610 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
611 const auto MemorySemanticsUniform = BinaryOperator::Create(
612 Instruction::Shl, GlobalMemFenceMask,
613 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400614
Kévin Petitc4643922019-06-17 19:32:05 +0100615 // And combine the above together, also adding in
616 // MemorySemanticsSequentiallyConsistentMask.
617 auto MemorySemantics =
618 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
619 ConstantSequentiallyConsistent, "", CI);
620 MemorySemantics = BinaryOperator::Create(Instruction::Or, MemorySemantics,
621 MemorySemanticsUniform, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400622
Kévin Petitc4643922019-06-17 19:32:05 +0100623 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
624 // Device Scope, otherwise Workgroup Scope.
625 const auto Cmp =
626 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, GlobalMemFenceMask,
627 GlobalMemFence, "", CI);
628 const auto MemoryScope = SelectInst::Create(Cmp, ConstantScopeDevice,
629 ConstantScopeWorkgroup, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400630
Kévin Petitc4643922019-06-17 19:32:05 +0100631 // Lastly, the Execution Scope is always Workgroup Scope.
632 const auto ExecutionScope = ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400633
Kévin Petitc4643922019-06-17 19:32:05 +0100634 return clspv::InsertSPIRVOp(CI, spv::OpControlBarrier,
635 {Attribute::NoDuplicate}, CI->getType(),
636 {ExecutionScope, MemoryScope, MemorySemantics});
637 });
David Neto22f144c2017-06-12 14:26:21 -0400638}
639
640bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
641 bool Changed = false;
642
643 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
644
Kévin Petitc4643922019-06-17 19:32:05 +0100645 using Tuple = std::tuple<spv::Op, unsigned>;
Neil Henning39672102017-09-29 14:33:13 +0100646 const std::map<const char *, Tuple> Map = {
Kévin Petitc4643922019-06-17 19:32:05 +0100647 {"_Z9mem_fencej", Tuple(spv::OpMemoryBarrier,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400648 spv::MemorySemanticsSequentiallyConsistentMask)},
Neil Henning39672102017-09-29 14:33:13 +0100649 {"_Z14read_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100650 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsAcquireMask)},
Neil Henning39672102017-09-29 14:33:13 +0100651 {"_Z15write_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100652 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400653
654 for (auto Pair : Map) {
655 // If we find a function with the matching name.
656 if (auto F = M.getFunction(Pair.first)) {
657 SmallVector<Instruction *, 4> ToRemoves;
658
659 // Walk the users of the function.
660 for (auto &U : F->uses()) {
661 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -0400662
663 auto Arg = CI->getOperand(0);
664
665 // We need to map the OpenCL constants to the SPIR-V equivalents.
666 const auto LocalMemFence =
667 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
668 const auto GlobalMemFence =
669 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
670 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100671 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400672 const auto ConstantScopeDevice =
673 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
674
675 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
676 const auto LocalMemFenceMask = BinaryOperator::Create(
677 Instruction::And, LocalMemFence, Arg, "", CI);
678 const auto WorkgroupShiftAmount =
679 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
680 clz(CLK_LOCAL_MEM_FENCE);
681 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
682 Instruction::Shl, LocalMemFenceMask,
683 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
684
685 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
686 const auto GlobalMemFenceMask = BinaryOperator::Create(
687 Instruction::And, GlobalMemFence, Arg, "", CI);
688 const auto UniformShiftAmount =
689 clz(spv::MemorySemanticsUniformMemoryMask) -
690 clz(CLK_GLOBAL_MEM_FENCE);
691 const auto MemorySemanticsUniform = BinaryOperator::Create(
692 Instruction::Shl, GlobalMemFenceMask,
693 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
694
695 // And combine the above together, also adding in
696 // MemorySemanticsSequentiallyConsistentMask.
697 auto MemorySemantics =
698 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
699 ConstantMemorySemantics, "", CI);
700 MemorySemantics = BinaryOperator::Create(
701 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
702
703 // Memory Scope is always device.
704 const auto MemoryScope = ConstantScopeDevice;
705
Kévin Petitc4643922019-06-17 19:32:05 +0100706 const auto SPIRVOp = std::get<0>(Pair.second);
707 auto NewCI = clspv::InsertSPIRVOp(CI, SPIRVOp, {}, CI->getType(),
708 {MemoryScope, MemorySemantics});
David Neto22f144c2017-06-12 14:26:21 -0400709
710 CI->replaceAllUsesWith(NewCI);
711
712 // Lastly, remember to remove the user.
713 ToRemoves.push_back(CI);
714 }
715 }
716
717 Changed = !ToRemoves.empty();
718
719 // And cleanup the calls we don't use anymore.
720 for (auto V : ToRemoves) {
721 V->eraseFromParent();
722 }
723
724 // And remove the function we don't need either too.
725 F->eraseFromParent();
726 }
727 }
728
729 return Changed;
730}
731
732bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
733 bool Changed = false;
734
735 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
736 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
737 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
738 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
739 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
740 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
741 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
742 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
743 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
744 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
745 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
746 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
747 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
748 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
749 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
750 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
751 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
752 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
753 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
754 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
755 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
756 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
757 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
758 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
759 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
760 };
761
762 for (auto Pair : Map) {
763 // If we find a function with the matching name.
764 if (auto F = M.getFunction(Pair.first)) {
765 SmallVector<Instruction *, 4> ToRemoves;
766
767 // Walk the users of the function.
768 for (auto &U : F->uses()) {
769 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
770 // The predicate to use in the CmpInst.
771 auto Predicate = Pair.second.first;
772
773 // The value to return for true.
774 auto TrueValue =
775 ConstantInt::getSigned(CI->getType(), Pair.second.second);
776
777 // The value to return for false.
778 auto FalseValue = Constant::getNullValue(CI->getType());
779
780 auto Arg1 = CI->getOperand(0);
781 auto Arg2 = CI->getOperand(1);
782
783 const auto Cmp =
784 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
785
786 const auto Select =
787 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
788
789 CI->replaceAllUsesWith(Select);
790
791 // Lastly, remember to remove the user.
792 ToRemoves.push_back(CI);
793 }
794 }
795
796 Changed = !ToRemoves.empty();
797
798 // And cleanup the calls we don't use anymore.
799 for (auto V : ToRemoves) {
800 V->eraseFromParent();
801 }
802
803 // And remove the function we don't need either too.
804 F->eraseFromParent();
805 }
806 }
807
808 return Changed;
809}
810
811bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
812 bool Changed = false;
813
Kévin Petitff03aee2019-06-12 19:39:03 +0100814 const std::map<const char *, std::pair<spv::Op, int32_t>> Map = {
815 {"_Z5isinff", {spv::OpIsInf, 1}},
816 {"_Z5isinfDv2_f", {spv::OpIsInf, -1}},
817 {"_Z5isinfDv3_f", {spv::OpIsInf, -1}},
818 {"_Z5isinfDv4_f", {spv::OpIsInf, -1}},
819 {"_Z5isnanf", {spv::OpIsNan, 1}},
820 {"_Z5isnanDv2_f", {spv::OpIsNan, -1}},
821 {"_Z5isnanDv3_f", {spv::OpIsNan, -1}},
822 {"_Z5isnanDv4_f", {spv::OpIsNan, -1}},
David Neto22f144c2017-06-12 14:26:21 -0400823 };
824
825 for (auto Pair : Map) {
826 // If we find a function with the matching name.
827 if (auto F = M.getFunction(Pair.first)) {
828 SmallVector<Instruction *, 4> ToRemoves;
829
830 // Walk the users of the function.
831 for (auto &U : F->uses()) {
832 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
833 const auto CITy = CI->getType();
834
Kévin Petitff03aee2019-06-12 19:39:03 +0100835 auto SPIRVOp = Pair.second.first;
David Neto22f144c2017-06-12 14:26:21 -0400836
837 // The value to return for true.
838 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
839
840 // The value to return for false.
841 auto FalseValue = Constant::getNullValue(CITy);
842
843 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
844 M.getContext(),
845 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
846
Kévin Petitff03aee2019-06-12 19:39:03 +0100847 auto NewCI =
848 clspv::InsertSPIRVOp(CI, SPIRVOp, {Attribute::ReadNone},
849 CorrespondingBoolTy, {CI->getOperand(0)});
David Neto22f144c2017-06-12 14:26:21 -0400850
851 const auto Select =
852 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
853
854 CI->replaceAllUsesWith(Select);
855
856 // Lastly, remember to remove the user.
857 ToRemoves.push_back(CI);
858 }
859 }
860
861 Changed = !ToRemoves.empty();
862
863 // And cleanup the calls we don't use anymore.
864 for (auto V : ToRemoves) {
865 V->eraseFromParent();
866 }
867
868 // And remove the function we don't need either too.
869 F->eraseFromParent();
870 }
871 }
872
873 return Changed;
874}
875
876bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
877 bool Changed = false;
878
Kévin Petitff03aee2019-06-12 19:39:03 +0100879 const std::map<const char *, spv::Op> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000880 // all
Kévin Petitff03aee2019-06-12 19:39:03 +0100881 {"_Z3allc", spv::OpNop},
882 {"_Z3allDv2_c", spv::OpAll},
883 {"_Z3allDv3_c", spv::OpAll},
884 {"_Z3allDv4_c", spv::OpAll},
885 {"_Z3alls", spv::OpNop},
886 {"_Z3allDv2_s", spv::OpAll},
887 {"_Z3allDv3_s", spv::OpAll},
888 {"_Z3allDv4_s", spv::OpAll},
889 {"_Z3alli", spv::OpNop},
890 {"_Z3allDv2_i", spv::OpAll},
891 {"_Z3allDv3_i", spv::OpAll},
892 {"_Z3allDv4_i", spv::OpAll},
893 {"_Z3alll", spv::OpNop},
894 {"_Z3allDv2_l", spv::OpAll},
895 {"_Z3allDv3_l", spv::OpAll},
896 {"_Z3allDv4_l", spv::OpAll},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000897
898 // any
Kévin Petitff03aee2019-06-12 19:39:03 +0100899 {"_Z3anyc", spv::OpNop},
900 {"_Z3anyDv2_c", spv::OpAny},
901 {"_Z3anyDv3_c", spv::OpAny},
902 {"_Z3anyDv4_c", spv::OpAny},
903 {"_Z3anys", spv::OpNop},
904 {"_Z3anyDv2_s", spv::OpAny},
905 {"_Z3anyDv3_s", spv::OpAny},
906 {"_Z3anyDv4_s", spv::OpAny},
907 {"_Z3anyi", spv::OpNop},
908 {"_Z3anyDv2_i", spv::OpAny},
909 {"_Z3anyDv3_i", spv::OpAny},
910 {"_Z3anyDv4_i", spv::OpAny},
911 {"_Z3anyl", spv::OpNop},
912 {"_Z3anyDv2_l", spv::OpAny},
913 {"_Z3anyDv3_l", spv::OpAny},
914 {"_Z3anyDv4_l", spv::OpAny},
David Neto22f144c2017-06-12 14:26:21 -0400915 };
916
917 for (auto Pair : Map) {
918 // If we find a function with the matching name.
919 if (auto F = M.getFunction(Pair.first)) {
920 SmallVector<Instruction *, 4> ToRemoves;
921
922 // Walk the users of the function.
923 for (auto &U : F->uses()) {
924 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -0400925
926 auto Arg = CI->getOperand(0);
927
928 Value *V;
929
Kévin Petitfd27cca2018-10-31 13:00:17 +0000930 // If the argument is a 32-bit int, just use a shift
931 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
932 V = BinaryOperator::Create(Instruction::LShr, Arg,
933 ConstantInt::get(Arg->getType(), 31), "",
934 CI);
935 } else {
David Neto22f144c2017-06-12 14:26:21 -0400936 // The value for zero to compare against.
937 const auto ZeroValue = Constant::getNullValue(Arg->getType());
938
David Neto22f144c2017-06-12 14:26:21 -0400939 // The value to return for true.
940 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
941
942 // The value to return for false.
943 const auto FalseValue = Constant::getNullValue(CI->getType());
944
Kévin Petitfd27cca2018-10-31 13:00:17 +0000945 const auto Cmp = CmpInst::Create(
946 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
947
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400948 Value *SelectSource;
Kévin Petitfd27cca2018-10-31 13:00:17 +0000949
950 // If we have a function to call, call it!
Kévin Petitff03aee2019-06-12 19:39:03 +0100951 const auto SPIRVOp = Pair.second;
Kévin Petitfd27cca2018-10-31 13:00:17 +0000952
Kévin Petitff03aee2019-06-12 19:39:03 +0100953 if (SPIRVOp != spv::OpNop) {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000954
Kévin Petitff03aee2019-06-12 19:39:03 +0100955 const auto BoolTy = Type::getInt1Ty(M.getContext());
Kévin Petitfd27cca2018-10-31 13:00:17 +0000956
Kévin Petitff03aee2019-06-12 19:39:03 +0100957 const auto NewCI = clspv::InsertSPIRVOp(
958 CI, SPIRVOp, {Attribute::ReadNone}, BoolTy, {Cmp});
Kévin Petitfd27cca2018-10-31 13:00:17 +0000959 SelectSource = NewCI;
960
961 } else {
962 SelectSource = Cmp;
963 }
964
965 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400966 }
967
968 CI->replaceAllUsesWith(V);
969
970 // Lastly, remember to remove the user.
971 ToRemoves.push_back(CI);
972 }
973 }
974
975 Changed = !ToRemoves.empty();
976
977 // And cleanup the calls we don't use anymore.
978 for (auto V : ToRemoves) {
979 V->eraseFromParent();
980 }
981
982 // And remove the function we don't need either too.
983 F->eraseFromParent();
984 }
985 }
986
987 return Changed;
988}
989
Kévin Petitbf0036c2019-03-06 13:57:10 +0000990bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
991 bool Changed = false;
992
993 for (auto const &SymVal : M.getValueSymbolTable()) {
994 // Skip symbols whose name doesn't match
995 if (!SymVal.getKey().startswith("_Z8upsample")) {
996 continue;
997 }
998 // Is there a function going by that name?
999 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1000
1001 SmallVector<Instruction *, 4> ToRemoves;
1002
1003 // Walk the users of the function.
1004 for (auto &U : F->uses()) {
1005 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1006
1007 // Get arguments
1008 auto HiValue = CI->getOperand(0);
1009 auto LoValue = CI->getOperand(1);
1010
1011 // Don't touch overloads that aren't in OpenCL C
1012 auto HiType = HiValue->getType();
1013 auto LoType = LoValue->getType();
1014
1015 if (HiType != LoType) {
1016 continue;
1017 }
1018
1019 if (!HiType->isIntOrIntVectorTy()) {
1020 continue;
1021 }
1022
1023 if (HiType->getScalarSizeInBits() * 2 !=
1024 CI->getType()->getScalarSizeInBits()) {
1025 continue;
1026 }
1027
1028 if ((HiType->getScalarSizeInBits() != 8) &&
1029 (HiType->getScalarSizeInBits() != 16) &&
1030 (HiType->getScalarSizeInBits() != 32)) {
1031 continue;
1032 }
1033
1034 if (HiType->isVectorTy()) {
1035 if ((HiType->getVectorNumElements() != 2) &&
1036 (HiType->getVectorNumElements() != 3) &&
1037 (HiType->getVectorNumElements() != 4) &&
1038 (HiType->getVectorNumElements() != 8) &&
1039 (HiType->getVectorNumElements() != 16)) {
1040 continue;
1041 }
1042 }
1043
1044 // Convert both operands to the result type
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001045 auto HiCast =
1046 CastInst::CreateZExtOrBitCast(HiValue, CI->getType(), "", CI);
1047 auto LoCast =
1048 CastInst::CreateZExtOrBitCast(LoValue, CI->getType(), "", CI);
Kévin Petitbf0036c2019-03-06 13:57:10 +00001049
1050 // Shift high operand
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001051 auto ShiftAmount =
1052 ConstantInt::get(CI->getType(), HiType->getScalarSizeInBits());
Kévin Petitbf0036c2019-03-06 13:57:10 +00001053 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1054 ShiftAmount, "", CI);
1055
1056 // OR both results
1057 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1058 "", CI);
1059
1060 // Replace call with the expression
1061 CI->replaceAllUsesWith(V);
1062
1063 // Lastly, remember to remove the user.
1064 ToRemoves.push_back(CI);
1065 }
1066 }
1067
1068 Changed = !ToRemoves.empty();
1069
1070 // And cleanup the calls we don't use anymore.
1071 for (auto V : ToRemoves) {
1072 V->eraseFromParent();
1073 }
1074
1075 // And remove the function we don't need either too.
1076 F->eraseFromParent();
1077 }
1078 }
1079
1080 return Changed;
1081}
1082
Kévin Petitd44eef52019-03-08 13:22:14 +00001083bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1084 bool Changed = false;
1085
1086 for (auto const &SymVal : M.getValueSymbolTable()) {
1087 // Skip symbols whose name doesn't match
1088 if (!SymVal.getKey().startswith("_Z6rotate")) {
1089 continue;
1090 }
1091 // Is there a function going by that name?
1092 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1093
1094 SmallVector<Instruction *, 4> ToRemoves;
1095
1096 // Walk the users of the function.
1097 for (auto &U : F->uses()) {
1098 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1099
1100 // Get arguments
1101 auto SrcValue = CI->getOperand(0);
1102 auto RotAmount = CI->getOperand(1);
1103
1104 // Don't touch overloads that aren't in OpenCL C
1105 auto SrcType = SrcValue->getType();
1106 auto RotType = RotAmount->getType();
1107
1108 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1109 continue;
1110 }
1111
1112 if (!SrcType->isIntOrIntVectorTy()) {
1113 continue;
1114 }
1115
1116 if ((SrcType->getScalarSizeInBits() != 8) &&
1117 (SrcType->getScalarSizeInBits() != 16) &&
1118 (SrcType->getScalarSizeInBits() != 32) &&
1119 (SrcType->getScalarSizeInBits() != 64)) {
1120 continue;
1121 }
1122
1123 if (SrcType->isVectorTy()) {
1124 if ((SrcType->getVectorNumElements() != 2) &&
1125 (SrcType->getVectorNumElements() != 3) &&
1126 (SrcType->getVectorNumElements() != 4) &&
1127 (SrcType->getVectorNumElements() != 8) &&
1128 (SrcType->getVectorNumElements() != 16)) {
1129 continue;
1130 }
1131 }
1132
1133 // The approach used is to shift the top bits down, the bottom bits up
1134 // and OR the two shifted values.
1135
1136 // The rotation amount is to be treated modulo the element size.
1137 // Since SPIR-V shift ops don't support this, let's apply the
1138 // modulo ahead of shifting. The element size is always a power of
1139 // two so we can just AND with a mask.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001140 auto ModMask =
1141 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits() - 1);
Kévin Petitd44eef52019-03-08 13:22:14 +00001142 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1143 ModMask, "", CI);
1144
1145 // Let's calc the amount by which to shift top bits down
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001146 auto ScalarSize =
1147 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits());
Kévin Petitd44eef52019-03-08 13:22:14 +00001148 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1149 RotAmount, "", CI);
1150
1151 // Now shift the bottom bits up and the top bits down
1152 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1153 RotAmount, "", CI);
1154 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1155 DownAmount, "", CI);
1156
1157 // Finally OR the two shifted values
1158 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1159 HiRotated, "", CI);
1160
1161 // Replace call with the expression
1162 CI->replaceAllUsesWith(V);
1163
1164 // Lastly, remember to remove the user.
1165 ToRemoves.push_back(CI);
1166 }
1167 }
1168
1169 Changed = !ToRemoves.empty();
1170
1171 // And cleanup the calls we don't use anymore.
1172 for (auto V : ToRemoves) {
1173 V->eraseFromParent();
1174 }
1175
1176 // And remove the function we don't need either too.
1177 F->eraseFromParent();
1178 }
1179 }
1180
1181 return Changed;
1182}
1183
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001184bool ReplaceOpenCLBuiltinPass::replaceConvert(Module &M) {
1185 bool Changed = false;
1186
1187 for (auto const &SymVal : M.getValueSymbolTable()) {
1188
1189 // Skip symbols whose name obviously doesn't match
1190 if (!SymVal.getKey().contains("convert_")) {
1191 continue;
1192 }
1193
1194 // Is there a function going by that name?
1195 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1196
1197 // Get info from the mangled name
1198 FunctionInfo finfo;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001199 bool parsed = FunctionInfo::getFromMangledNameCheck(F->getName(), &finfo);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001200
1201 // All functions of interest are handled by our mangled name parser
1202 if (!parsed) {
1203 continue;
1204 }
1205
1206 // Move on if this isn't a call to convert_
1207 if (!finfo.name.startswith("convert_")) {
1208 continue;
1209 }
1210
1211 // Extract the destination type from the function name
1212 StringRef DstTypeName = finfo.name;
1213 DstTypeName.consume_front("convert_");
1214
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001215 auto DstSignedNess =
1216 StringSwitch<ArgTypeInfo::SignedNess>(DstTypeName)
1217 .StartsWith("char", ArgTypeInfo::SignedNess::Signed)
1218 .StartsWith("short", ArgTypeInfo::SignedNess::Signed)
1219 .StartsWith("int", ArgTypeInfo::SignedNess::Signed)
1220 .StartsWith("long", ArgTypeInfo::SignedNess::Signed)
1221 .StartsWith("uchar", ArgTypeInfo::SignedNess::Unsigned)
1222 .StartsWith("ushort", ArgTypeInfo::SignedNess::Unsigned)
1223 .StartsWith("uint", ArgTypeInfo::SignedNess::Unsigned)
1224 .StartsWith("ulong", ArgTypeInfo::SignedNess::Unsigned)
1225 .Default(ArgTypeInfo::SignedNess::None);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001226
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001227 bool DstIsSigned = DstSignedNess == ArgTypeInfo::SignedNess::Signed;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001228 bool SrcIsSigned = finfo.isArgSigned(0);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001229
1230 SmallVector<Instruction *, 4> ToRemoves;
1231
1232 // Walk the users of the function.
1233 for (auto &U : F->uses()) {
1234 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1235
1236 // Get arguments
1237 auto SrcValue = CI->getOperand(0);
1238
1239 // Don't touch overloads that aren't in OpenCL C
1240 auto SrcType = SrcValue->getType();
1241 auto DstType = CI->getType();
1242
1243 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1244 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1245 continue;
1246 }
1247
1248 if (SrcType->isVectorTy()) {
1249
1250 if (SrcType->getVectorNumElements() !=
1251 DstType->getVectorNumElements()) {
1252 continue;
1253 }
1254
1255 if ((SrcType->getVectorNumElements() != 2) &&
1256 (SrcType->getVectorNumElements() != 3) &&
1257 (SrcType->getVectorNumElements() != 4) &&
1258 (SrcType->getVectorNumElements() != 8) &&
1259 (SrcType->getVectorNumElements() != 16)) {
1260 continue;
1261 }
1262 }
1263
1264 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1265 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1266
1267 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1268 bool DstIsInt = DstType->isIntOrIntVectorTy();
1269
1270 Value *V;
1271 if (SrcIsFloat && DstIsFloat) {
1272 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1273 } else if (SrcIsFloat && DstIsInt) {
1274 if (DstIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001275 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "",
1276 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001277 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001278 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "",
1279 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001280 }
1281 } else if (SrcIsInt && DstIsFloat) {
1282 if (SrcIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001283 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "",
1284 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001285 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001286 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "",
1287 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001288 }
1289 } else if (SrcIsInt && DstIsInt) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001290 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "",
1291 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001292 } else {
1293 // Not something we're supposed to handle, just move on
1294 continue;
1295 }
1296
1297 // Replace call with the expression
1298 CI->replaceAllUsesWith(V);
1299
1300 // Lastly, remember to remove the user.
1301 ToRemoves.push_back(CI);
1302 }
1303 }
1304
1305 Changed = !ToRemoves.empty();
1306
1307 // And cleanup the calls we don't use anymore.
1308 for (auto V : ToRemoves) {
1309 V->eraseFromParent();
1310 }
1311
1312 // And remove the function we don't need either too.
1313 F->eraseFromParent();
1314 }
1315 }
1316
1317 return Changed;
1318}
1319
Kévin Petit8a560882019-03-21 15:24:34 +00001320bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1321 bool Changed = false;
1322
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001323 SmallVector<Function *, 4> FnWorklist;
Kévin Petit8a560882019-03-21 15:24:34 +00001324
Kévin Petit617a76d2019-04-04 13:54:16 +01001325 for (auto const &SymVal : M.getValueSymbolTable()) {
Kévin Petit8a560882019-03-21 15:24:34 +00001326 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1327 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1328
1329 // Skip symbols whose name doesn't match
1330 if (!isMad && !isMul) {
1331 continue;
1332 }
1333
1334 // Is there a function going by that name?
1335 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001336 FnWorklist.push_back(F);
Kévin Petit8a560882019-03-21 15:24:34 +00001337 }
1338 }
1339
Kévin Petit617a76d2019-04-04 13:54:16 +01001340 for (auto F : FnWorklist) {
1341 SmallVector<Instruction *, 4> ToRemoves;
1342
1343 bool isMad = F->getName().startswith("_Z6mad_hi");
1344 // Walk the users of the function.
1345 for (auto &U : F->uses()) {
1346 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1347
1348 // Get arguments
1349 auto AValue = CI->getOperand(0);
1350 auto BValue = CI->getOperand(1);
1351 auto CValue = CI->getOperand(2);
1352
1353 // Don't touch overloads that aren't in OpenCL C
1354 auto AType = AValue->getType();
1355 auto BType = BValue->getType();
1356 auto CType = CValue->getType();
1357
1358 if ((AType != BType) || (CI->getType() != AType) ||
1359 (isMad && (AType != CType))) {
1360 continue;
1361 }
1362
1363 if (!AType->isIntOrIntVectorTy()) {
1364 continue;
1365 }
1366
1367 if ((AType->getScalarSizeInBits() != 8) &&
1368 (AType->getScalarSizeInBits() != 16) &&
1369 (AType->getScalarSizeInBits() != 32) &&
1370 (AType->getScalarSizeInBits() != 64)) {
1371 continue;
1372 }
1373
1374 if (AType->isVectorTy()) {
1375 if ((AType->getVectorNumElements() != 2) &&
1376 (AType->getVectorNumElements() != 3) &&
1377 (AType->getVectorNumElements() != 4) &&
1378 (AType->getVectorNumElements() != 8) &&
1379 (AType->getVectorNumElements() != 16)) {
1380 continue;
1381 }
1382 }
1383
1384 // Get infos from the mangled OpenCL built-in function name
Kévin Petit91bc72e2019-04-08 15:17:46 +01001385 auto finfo = FunctionInfo::getFromMangledName(F->getName());
Kévin Petit617a76d2019-04-04 13:54:16 +01001386
1387 // Select the appropriate signed/unsigned SPIR-V op
1388 spv::Op opcode;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001389 if (finfo.isArgSigned(0)) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001390 opcode = spv::OpSMulExtended;
1391 } else {
1392 opcode = spv::OpUMulExtended;
1393 }
1394
1395 // Our SPIR-V op returns a struct, create a type for it
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001396 SmallVector<Type *, 2> TwoValueType = {AType, AType};
Kévin Petit617a76d2019-04-04 13:54:16 +01001397 auto ExMulRetType = StructType::create(TwoValueType);
1398
1399 // Call the SPIR-V op
1400 auto Call = clspv::InsertSPIRVOp(CI, opcode, {Attribute::ReadNone},
1401 ExMulRetType, {AValue, BValue});
1402
1403 // Get the high part of the result
1404 unsigned Idxs[] = {1};
1405 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1406
1407 // If we're handling a mad_hi, add the third argument to the result
1408 if (isMad) {
1409 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1410 }
1411
1412 // Replace call with the expression
1413 CI->replaceAllUsesWith(V);
1414
1415 // Lastly, remember to remove the user.
1416 ToRemoves.push_back(CI);
1417 }
1418 }
1419
1420 Changed = !ToRemoves.empty();
1421
1422 // And cleanup the calls we don't use anymore.
1423 for (auto V : ToRemoves) {
1424 V->eraseFromParent();
1425 }
1426
1427 // And remove the function we don't need either too.
1428 F->eraseFromParent();
1429 }
1430
Kévin Petit8a560882019-03-21 15:24:34 +00001431 return Changed;
1432}
1433
Kévin Petitf5b78a22018-10-25 14:32:17 +00001434bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1435 bool Changed = false;
1436
1437 for (auto const &SymVal : M.getValueSymbolTable()) {
1438 // Skip symbols whose name doesn't match
1439 if (!SymVal.getKey().startswith("_Z6select")) {
1440 continue;
1441 }
1442 // Is there a function going by that name?
1443 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1444
1445 SmallVector<Instruction *, 4> ToRemoves;
1446
1447 // Walk the users of the function.
1448 for (auto &U : F->uses()) {
1449 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1450
1451 // Get arguments
1452 auto FalseValue = CI->getOperand(0);
1453 auto TrueValue = CI->getOperand(1);
1454 auto PredicateValue = CI->getOperand(2);
1455
1456 // Don't touch overloads that aren't in OpenCL C
1457 auto FalseType = FalseValue->getType();
1458 auto TrueType = TrueValue->getType();
1459 auto PredicateType = PredicateValue->getType();
1460
1461 if (FalseType != TrueType) {
1462 continue;
1463 }
1464
1465 if (!PredicateType->isIntOrIntVectorTy()) {
1466 continue;
1467 }
1468
1469 if (!FalseType->isIntOrIntVectorTy() &&
1470 !FalseType->getScalarType()->isFloatingPointTy()) {
1471 continue;
1472 }
1473
1474 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1475 continue;
1476 }
1477
1478 if (FalseType->getScalarSizeInBits() !=
1479 PredicateType->getScalarSizeInBits()) {
1480 continue;
1481 }
1482
1483 if (FalseType->isVectorTy()) {
1484 if (FalseType->getVectorNumElements() !=
1485 PredicateType->getVectorNumElements()) {
1486 continue;
1487 }
1488
1489 if ((FalseType->getVectorNumElements() != 2) &&
1490 (FalseType->getVectorNumElements() != 3) &&
1491 (FalseType->getVectorNumElements() != 4) &&
1492 (FalseType->getVectorNumElements() != 8) &&
1493 (FalseType->getVectorNumElements() != 16)) {
1494 continue;
1495 }
1496 }
1497
1498 // Create constant
1499 const auto ZeroValue = Constant::getNullValue(PredicateType);
1500
1501 // Scalar and vector are to be treated differently
1502 CmpInst::Predicate Pred;
1503 if (PredicateType->isVectorTy()) {
1504 Pred = CmpInst::ICMP_SLT;
1505 } else {
1506 Pred = CmpInst::ICMP_NE;
1507 }
1508
1509 // Create comparison instruction
1510 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1511 ZeroValue, "", CI);
1512
1513 // Create select
1514 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1515
1516 // Replace call with the selection
1517 CI->replaceAllUsesWith(V);
1518
1519 // Lastly, remember to remove the user.
1520 ToRemoves.push_back(CI);
1521 }
1522 }
1523
1524 Changed = !ToRemoves.empty();
1525
1526 // And cleanup the calls we don't use anymore.
1527 for (auto V : ToRemoves) {
1528 V->eraseFromParent();
1529 }
1530
1531 // And remove the function we don't need either too.
1532 F->eraseFromParent();
1533 }
1534 }
1535
1536 return Changed;
1537}
1538
Kévin Petite7d0cce2018-10-31 12:38:56 +00001539bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1540 bool Changed = false;
1541
1542 for (auto const &SymVal : M.getValueSymbolTable()) {
1543 // Skip symbols whose name doesn't match
1544 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1545 continue;
1546 }
1547 // Is there a function going by that name?
1548 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1549
1550 SmallVector<Instruction *, 4> ToRemoves;
1551
1552 // Walk the users of the function.
1553 for (auto &U : F->uses()) {
1554 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1555
1556 if (CI->getNumOperands() != 4) {
1557 continue;
1558 }
1559
1560 // Get arguments
1561 auto FalseValue = CI->getOperand(0);
1562 auto TrueValue = CI->getOperand(1);
1563 auto PredicateValue = CI->getOperand(2);
1564
1565 // Don't touch overloads that aren't in OpenCL C
1566 auto FalseType = FalseValue->getType();
1567 auto TrueType = TrueValue->getType();
1568 auto PredicateType = PredicateValue->getType();
1569
1570 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1571 continue;
1572 }
1573
1574 if (TrueType->isVectorTy()) {
1575 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1576 !TrueType->getScalarType()->isIntegerTy()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001577 continue;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001578 }
1579 if ((TrueType->getVectorNumElements() != 2) &&
1580 (TrueType->getVectorNumElements() != 3) &&
1581 (TrueType->getVectorNumElements() != 4) &&
1582 (TrueType->getVectorNumElements() != 8) &&
1583 (TrueType->getVectorNumElements() != 16)) {
1584 continue;
1585 }
1586 }
1587
1588 // Remember the type of the operands
1589 auto OpType = TrueType;
1590
1591 // The actual bit selection will always be done on an integer type,
1592 // declare it here
1593 Type *BitType;
1594
1595 // If the operands are float, then bitcast them to int
1596 if (OpType->getScalarType()->isFloatingPointTy()) {
1597
1598 // First create the new type
1599 auto ScalarSize = OpType->getScalarType()->getPrimitiveSizeInBits();
1600 BitType = Type::getIntNTy(M.getContext(), ScalarSize);
1601 if (OpType->isVectorTy()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001602 BitType =
1603 VectorType::get(BitType, OpType->getVectorNumElements());
Kévin Petite7d0cce2018-10-31 12:38:56 +00001604 }
1605
1606 // Then bitcast all operands
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001607 PredicateValue =
1608 CastInst::CreateZExtOrBitCast(PredicateValue, BitType, "", CI);
1609 FalseValue =
1610 CastInst::CreateZExtOrBitCast(FalseValue, BitType, "", CI);
1611 TrueValue =
1612 CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001613
1614 } else {
1615 // The operands have an integer type, use it directly
1616 BitType = OpType;
1617 }
1618
1619 // All the operands are now always integers
1620 // implement as (c & b) | (~c & a)
1621
1622 // Create our negated predicate value
1623 auto AllOnes = Constant::getAllOnesValue(BitType);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001624 auto NotPredicateValue = BinaryOperator::Create(
1625 Instruction::Xor, PredicateValue, AllOnes, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001626
1627 // Then put everything together
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001628 auto BitsFalse = BinaryOperator::Create(
1629 Instruction::And, NotPredicateValue, FalseValue, "", CI);
1630 auto BitsTrue = BinaryOperator::Create(
1631 Instruction::And, PredicateValue, TrueValue, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001632
1633 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1634 BitsTrue, "", CI);
1635
1636 // If we were dealing with a floating point type, we must bitcast
1637 // the result back to that
1638 if (OpType->getScalarType()->isFloatingPointTy()) {
1639 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1640 }
1641
1642 // Replace call with our new code
1643 CI->replaceAllUsesWith(V);
1644
1645 // Lastly, remember to remove the user.
1646 ToRemoves.push_back(CI);
1647 }
1648 }
1649
1650 Changed = !ToRemoves.empty();
1651
1652 // And cleanup the calls we don't use anymore.
1653 for (auto V : ToRemoves) {
1654 V->eraseFromParent();
1655 }
1656
1657 // And remove the function we don't need either too.
1658 F->eraseFromParent();
1659 }
1660 }
1661
1662 return Changed;
1663}
1664
Kévin Petit6b0a9532018-10-30 20:00:39 +00001665bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1666 bool Changed = false;
1667
1668 const std::map<const char *, const char *> Map = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001669 {"_Z4stepfDv2_f", "_Z4stepDv2_fS_"},
1670 {"_Z4stepfDv3_f", "_Z4stepDv3_fS_"},
1671 {"_Z4stepfDv4_f", "_Z4stepDv4_fS_"},
1672 {"_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_"},
1673 {"_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_"},
1674 {"_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_"},
Kévin Petit6b0a9532018-10-30 20:00:39 +00001675 };
1676
1677 for (auto Pair : Map) {
1678 // If we find a function with the matching name.
1679 if (auto F = M.getFunction(Pair.first)) {
1680 SmallVector<Instruction *, 4> ToRemoves;
1681
1682 // Walk the users of the function.
1683 for (auto &U : F->uses()) {
1684 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1685
1686 auto ReplacementFn = Pair.second;
1687
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001688 SmallVector<Value *, 2> ArgsToSplat = {CI->getOperand(0)};
Kévin Petit6b0a9532018-10-30 20:00:39 +00001689 Value *VectorArg;
1690
1691 // First figure out which function we're dealing with
1692 if (F->getName().startswith("_Z10smoothstep")) {
1693 ArgsToSplat.push_back(CI->getOperand(1));
1694 VectorArg = CI->getOperand(2);
1695 } else {
1696 VectorArg = CI->getOperand(1);
1697 }
1698
1699 // Splat arguments that need to be
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001700 SmallVector<Value *, 2> SplatArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001701 auto VecType = VectorArg->getType();
1702
1703 for (auto arg : ArgsToSplat) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001704 Value *NewVectorArg = UndefValue::get(VecType);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001705 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001706 auto index =
1707 ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1708 NewVectorArg =
1709 InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001710 }
1711 SplatArgs.push_back(NewVectorArg);
1712 }
1713
1714 // Replace the call with the vector/vector flavour
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001715 SmallVector<Type *, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1716 const auto NewFType =
1717 FunctionType::get(CI->getType(), NewArgTypes, false);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001718
1719 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1720
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001721 SmallVector<Value *, 3> NewArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001722 for (auto arg : SplatArgs) {
1723 NewArgs.push_back(arg);
1724 }
1725 NewArgs.push_back(VectorArg);
1726
1727 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1728
1729 CI->replaceAllUsesWith(NewCI);
1730
1731 // Lastly, remember to remove the user.
1732 ToRemoves.push_back(CI);
1733 }
1734 }
1735
1736 Changed = !ToRemoves.empty();
1737
1738 // And cleanup the calls we don't use anymore.
1739 for (auto V : ToRemoves) {
1740 V->eraseFromParent();
1741 }
1742
1743 // And remove the function we don't need either too.
1744 F->eraseFromParent();
1745 }
1746 }
1747
1748 return Changed;
1749}
1750
David Neto22f144c2017-06-12 14:26:21 -04001751bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1752 bool Changed = false;
1753
1754 const std::map<const char *, Instruction::BinaryOps> Map = {
1755 {"_Z7signbitf", Instruction::LShr},
1756 {"_Z7signbitDv2_f", Instruction::AShr},
1757 {"_Z7signbitDv3_f", Instruction::AShr},
1758 {"_Z7signbitDv4_f", Instruction::AShr},
1759 };
1760
1761 for (auto Pair : Map) {
1762 // If we find a function with the matching name.
1763 if (auto F = M.getFunction(Pair.first)) {
1764 SmallVector<Instruction *, 4> ToRemoves;
1765
1766 // Walk the users of the function.
1767 for (auto &U : F->uses()) {
1768 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1769 auto Arg = CI->getOperand(0);
1770
1771 auto Bitcast =
1772 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1773
1774 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1775 ConstantInt::get(CI->getType(), 31),
1776 "", CI);
1777
1778 CI->replaceAllUsesWith(Shr);
1779
1780 // Lastly, remember to remove the user.
1781 ToRemoves.push_back(CI);
1782 }
1783 }
1784
1785 Changed = !ToRemoves.empty();
1786
1787 // And cleanup the calls we don't use anymore.
1788 for (auto V : ToRemoves) {
1789 V->eraseFromParent();
1790 }
1791
1792 // And remove the function we don't need either too.
1793 F->eraseFromParent();
1794 }
1795 }
1796
1797 return Changed;
1798}
1799
1800bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1801 bool Changed = false;
1802
1803 const std::map<const char *,
1804 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1805 Map = {
1806 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1807 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1808 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1809 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1810 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1811 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1812 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1813 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1814 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1815 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1816 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1817 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1818 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1819 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1820 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1821 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1822 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1823 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1824 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1825 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1826 };
1827
1828 for (auto Pair : Map) {
1829 // If we find a function with the matching name.
1830 if (auto F = M.getFunction(Pair.first)) {
1831 SmallVector<Instruction *, 4> ToRemoves;
1832
1833 // Walk the users of the function.
1834 for (auto &U : F->uses()) {
1835 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1836 // The multiply instruction to use.
1837 auto MulInst = Pair.second.first;
1838
1839 // The add instruction to use.
1840 auto AddInst = Pair.second.second;
1841
1842 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1843
1844 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1845 CI->getArgOperand(1), "", CI);
1846
1847 if (Instruction::BinaryOpsEnd != AddInst) {
1848 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1849 CI);
1850 }
1851
1852 CI->replaceAllUsesWith(I);
1853
1854 // Lastly, remember to remove the user.
1855 ToRemoves.push_back(CI);
1856 }
1857 }
1858
1859 Changed = !ToRemoves.empty();
1860
1861 // And cleanup the calls we don't use anymore.
1862 for (auto V : ToRemoves) {
1863 V->eraseFromParent();
1864 }
1865
1866 // And remove the function we don't need either too.
1867 F->eraseFromParent();
1868 }
1869 }
1870
1871 return Changed;
1872}
1873
Derek Chowcfd368b2017-10-19 20:58:45 -07001874bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1875 bool Changed = false;
1876
alan-bakerf795f392019-06-11 18:24:34 -04001877 for (auto const &SymVal : M.getValueSymbolTable()) {
1878 if (!SymVal.getKey().contains("vstore"))
1879 continue;
1880 if (SymVal.getKey().contains("vstore_"))
1881 continue;
1882 if (SymVal.getKey().contains("vstorea"))
1883 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001884
alan-bakerf795f392019-06-11 18:24:34 -04001885 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001886 SmallVector<Instruction *, 4> ToRemoves;
1887
alan-bakerf795f392019-06-11 18:24:34 -04001888 auto fname = F->getName();
1889 if (!fname.consume_front("_Z"))
1890 continue;
1891 size_t name_len;
1892 if (fname.consumeInteger(10, name_len))
1893 continue;
1894 std::string name = fname.take_front(name_len);
1895
1896 bool ok = StringSwitch<bool>(name)
1897 .Case("vstore2", true)
1898 .Case("vstore3", true)
1899 .Case("vstore4", true)
1900 .Case("vstore8", true)
1901 .Case("vstore16", true)
1902 .Default(false);
1903 if (!ok)
1904 continue;
1905
Derek Chowcfd368b2017-10-19 20:58:45 -07001906 for (auto &U : F->uses()) {
1907 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04001908 auto data = CI->getOperand(0);
Derek Chowcfd368b2017-10-19 20:58:45 -07001909
alan-bakerf795f392019-06-11 18:24:34 -04001910 auto data_type = data->getType();
1911 if (!data_type->isVectorTy())
1912 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001913
alan-bakerf795f392019-06-11 18:24:34 -04001914 auto elems = data_type->getVectorNumElements();
1915 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
1916 elems != 16)
1917 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001918
alan-bakerf795f392019-06-11 18:24:34 -04001919 auto offset = CI->getOperand(1);
1920 auto ptr = CI->getOperand(2);
1921 auto ptr_type = ptr->getType();
1922 auto pointee_type = ptr_type->getPointerElementType();
1923 if (pointee_type != data_type->getVectorElementType())
1924 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001925
alan-bakerf795f392019-06-11 18:24:34 -04001926 // Avoid pointer casts. Instead generate the correct number of stores
1927 // and rely on drivers to coalesce appropriately.
1928 IRBuilder<> builder(CI);
1929 auto elems_const = builder.getInt32(elems);
1930 auto adjust = builder.CreateMul(offset, elems_const);
1931 for (auto i = 0; i < elems; ++i) {
1932 auto idx = builder.getInt32(i);
1933 auto add = builder.CreateAdd(adjust, idx);
1934 auto gep = builder.CreateGEP(ptr, add);
1935 auto extract = builder.CreateExtractElement(data, i);
1936 auto store = builder.CreateStore(extract, gep);
1937 }
Derek Chowcfd368b2017-10-19 20:58:45 -07001938
Derek Chowcfd368b2017-10-19 20:58:45 -07001939 ToRemoves.push_back(CI);
1940 }
1941 }
1942
1943 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07001944 for (auto V : ToRemoves) {
1945 V->eraseFromParent();
1946 }
Derek Chowcfd368b2017-10-19 20:58:45 -07001947 F->eraseFromParent();
1948 }
1949 }
1950
1951 return Changed;
1952}
1953
1954bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
1955 bool Changed = false;
1956
alan-bakerf795f392019-06-11 18:24:34 -04001957 for (auto const &SymVal : M.getValueSymbolTable()) {
1958 if (!SymVal.getKey().contains("vload"))
1959 continue;
1960 if (SymVal.getKey().contains("vload_"))
1961 continue;
1962 if (SymVal.getKey().contains("vloada"))
1963 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001964
alan-bakerf795f392019-06-11 18:24:34 -04001965 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001966 SmallVector<Instruction *, 4> ToRemoves;
1967
alan-bakerf795f392019-06-11 18:24:34 -04001968 auto fname = F->getName();
1969 if (!fname.consume_front("_Z"))
1970 continue;
1971 size_t name_len;
1972 if (fname.consumeInteger(10, name_len))
1973 continue;
1974 std::string name = fname.take_front(name_len);
1975
1976 bool ok = StringSwitch<bool>(name)
1977 .Case("vload2", true)
1978 .Case("vload3", true)
1979 .Case("vload4", true)
1980 .Case("vload8", true)
1981 .Case("vload16", true)
1982 .Default(false);
1983 if (!ok)
1984 continue;
1985
Derek Chowcfd368b2017-10-19 20:58:45 -07001986 for (auto &U : F->uses()) {
1987 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04001988 auto ret_type = F->getReturnType();
1989 if (!ret_type->isVectorTy())
1990 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001991
alan-bakerf795f392019-06-11 18:24:34 -04001992 auto elems = ret_type->getVectorNumElements();
1993 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
1994 elems != 16)
1995 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001996
alan-bakerf795f392019-06-11 18:24:34 -04001997 auto offset = CI->getOperand(0);
1998 auto ptr = CI->getOperand(1);
1999 auto ptr_type = ptr->getType();
2000 auto pointee_type = ptr_type->getPointerElementType();
2001 if (pointee_type != ret_type->getVectorElementType())
2002 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002003
alan-bakerf795f392019-06-11 18:24:34 -04002004 // Avoid pointer casts. Instead generate the correct number of loads
2005 // and rely on drivers to coalesce appropriately.
2006 IRBuilder<> builder(CI);
2007 auto elems_const = builder.getInt32(elems);
2008 Value *insert = UndefValue::get(ret_type);
2009 auto adjust = builder.CreateMul(offset, elems_const);
2010 for (auto i = 0; i < elems; ++i) {
2011 auto idx = builder.getInt32(i);
2012 auto add = builder.CreateAdd(adjust, idx);
2013 auto gep = builder.CreateGEP(ptr, add);
2014 auto load = builder.CreateLoad(gep);
2015 insert = builder.CreateInsertElement(insert, load, i);
2016 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002017
alan-bakerf795f392019-06-11 18:24:34 -04002018 CI->replaceAllUsesWith(insert);
Derek Chowcfd368b2017-10-19 20:58:45 -07002019 ToRemoves.push_back(CI);
2020 }
2021 }
2022
2023 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07002024 for (auto V : ToRemoves) {
2025 V->eraseFromParent();
2026 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002027 F->eraseFromParent();
Derek Chowcfd368b2017-10-19 20:58:45 -07002028 }
2029 }
2030
2031 return Changed;
2032}
2033
David Neto22f144c2017-06-12 14:26:21 -04002034bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
2035 bool Changed = false;
2036
2037 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
2038 "_Z10vload_halfjPU3AS2KDh"};
2039
2040 for (auto Name : Map) {
2041 // If we find a function with the matching name.
2042 if (auto F = M.getFunction(Name)) {
2043 SmallVector<Instruction *, 4> ToRemoves;
2044
2045 // Walk the users of the function.
2046 for (auto &U : F->uses()) {
2047 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2048 // The index argument from vload_half.
2049 auto Arg0 = CI->getOperand(0);
2050
2051 // The pointer argument from vload_half.
2052 auto Arg1 = CI->getOperand(1);
2053
David Neto22f144c2017-06-12 14:26:21 -04002054 auto IntTy = Type::getInt32Ty(M.getContext());
2055 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002056 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2057
David Neto22f144c2017-06-12 14:26:21 -04002058 // Our intrinsic to unpack a float2 from an int.
2059 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2060
2061 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2062
David Neto482550a2018-03-24 05:21:07 -07002063 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04002064 auto ShortTy = Type::getInt16Ty(M.getContext());
2065 auto ShortPointerTy = PointerType::get(
2066 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002067
David Netoac825b82017-05-30 12:49:01 -04002068 // Cast the half* pointer to short*.
2069 auto Cast =
2070 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002071
David Netoac825b82017-05-30 12:49:01 -04002072 // Index into the correct address of the casted pointer.
2073 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
2074
2075 // Load from the short* we casted to.
2076 auto Load = new LoadInst(Index, "", CI);
2077
2078 // ZExt the short -> int.
2079 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
2080
2081 // Get our float2.
2082 auto Call = CallInst::Create(NewF, ZExt, "", CI);
2083
2084 // Extract out the bottom element which is our float result.
2085 auto Extract = ExtractElementInst::Create(
2086 Call, ConstantInt::get(IntTy, 0), "", CI);
2087
2088 CI->replaceAllUsesWith(Extract);
2089 } else {
2090 // Assume the pointer argument points to storage aligned to 32bits
2091 // or more.
2092 // TODO(dneto): Do more analysis to make sure this is true?
2093 //
2094 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
2095 // with:
2096 //
2097 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
2098 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
2099 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
2100 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
2101 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
2102 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
2103 // x float> %converted, %index_is_odd32
2104
2105 auto IntPointerTy = PointerType::get(
2106 IntTy, Arg1->getType()->getPointerAddressSpace());
2107
David Neto973e6a82017-05-30 13:48:18 -04002108 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04002109 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04002110 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04002111 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
2112
2113 auto One = ConstantInt::get(IntTy, 1);
2114 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
2115 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
2116
2117 // Index into the correct address of the casted pointer.
2118 auto Ptr =
2119 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
2120
2121 // Load from the int* we casted to.
2122 auto Load = new LoadInst(Ptr, "", CI);
2123
2124 // Get our float2.
2125 auto Call = CallInst::Create(NewF, Load, "", CI);
2126
2127 // Extract out the float result, where the element number is
2128 // determined by whether the original index was even or odd.
2129 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
2130
2131 CI->replaceAllUsesWith(Extract);
2132 }
David Neto22f144c2017-06-12 14:26:21 -04002133
2134 // Lastly, remember to remove the user.
2135 ToRemoves.push_back(CI);
2136 }
2137 }
2138
2139 Changed = !ToRemoves.empty();
2140
2141 // And cleanup the calls we don't use anymore.
2142 for (auto V : ToRemoves) {
2143 V->eraseFromParent();
2144 }
2145
2146 // And remove the function we don't need either too.
2147 F->eraseFromParent();
2148 }
2149 }
2150
2151 return Changed;
2152}
2153
2154bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002155
Kévin Petite8edce32019-04-10 14:23:32 +01002156 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002157 "_Z11vload_half2jPU3AS1KDh",
2158 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
2159 "_Z11vload_half2jPU3AS2KDh",
2160 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
2161 };
David Neto22f144c2017-06-12 14:26:21 -04002162
Kévin Petite8edce32019-04-10 14:23:32 +01002163 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2164 // The index argument from vload_half.
2165 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002166
Kévin Petite8edce32019-04-10 14:23:32 +01002167 // The pointer argument from vload_half.
2168 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002169
Kévin Petite8edce32019-04-10 14:23:32 +01002170 auto IntTy = Type::getInt32Ty(M.getContext());
2171 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002172 auto NewPointerTy =
2173 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002174 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002175
Kévin Petite8edce32019-04-10 14:23:32 +01002176 // Cast the half* pointer to int*.
2177 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002178
Kévin Petite8edce32019-04-10 14:23:32 +01002179 // Index into the correct address of the casted pointer.
2180 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002181
Kévin Petite8edce32019-04-10 14:23:32 +01002182 // Load from the int* we casted to.
2183 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002184
Kévin Petite8edce32019-04-10 14:23:32 +01002185 // Our intrinsic to unpack a float2 from an int.
2186 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002187
Kévin Petite8edce32019-04-10 14:23:32 +01002188 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002189
Kévin Petite8edce32019-04-10 14:23:32 +01002190 // Get our float2.
2191 return CallInst::Create(NewF, Load, "", CI);
2192 });
David Neto22f144c2017-06-12 14:26:21 -04002193}
2194
2195bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002196
Kévin Petite8edce32019-04-10 14:23:32 +01002197 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002198 "_Z11vload_half4jPU3AS1KDh",
2199 "_Z12vloada_half4jPU3AS1KDh",
2200 "_Z11vload_half4jPU3AS2KDh",
2201 "_Z12vloada_half4jPU3AS2KDh",
2202 };
David Neto22f144c2017-06-12 14:26:21 -04002203
Kévin Petite8edce32019-04-10 14:23:32 +01002204 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2205 // The index argument from vload_half.
2206 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002207
Kévin Petite8edce32019-04-10 14:23:32 +01002208 // The pointer argument from vload_half.
2209 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002210
Kévin Petite8edce32019-04-10 14:23:32 +01002211 auto IntTy = Type::getInt32Ty(M.getContext());
2212 auto Int2Ty = VectorType::get(IntTy, 2);
2213 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002214 auto NewPointerTy =
2215 PointerType::get(Int2Ty, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002216 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002217
Kévin Petite8edce32019-04-10 14:23:32 +01002218 // Cast the half* pointer to int2*.
2219 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002220
Kévin Petite8edce32019-04-10 14:23:32 +01002221 // Index into the correct address of the casted pointer.
2222 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002223
Kévin Petite8edce32019-04-10 14:23:32 +01002224 // Load from the int2* we casted to.
2225 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002226
Kévin Petite8edce32019-04-10 14:23:32 +01002227 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002228 auto X =
2229 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2230 auto Y =
2231 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002232
Kévin Petite8edce32019-04-10 14:23:32 +01002233 // Our intrinsic to unpack a float2 from an int.
2234 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002235
Kévin Petite8edce32019-04-10 14:23:32 +01002236 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002237
Kévin Petite8edce32019-04-10 14:23:32 +01002238 // Get the lower (x & y) components of our final float4.
2239 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002240
Kévin Petite8edce32019-04-10 14:23:32 +01002241 // Get the higher (z & w) components of our final float4.
2242 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002243
Kévin Petite8edce32019-04-10 14:23:32 +01002244 Constant *ShuffleMask[4] = {
2245 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2246 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002247
Kévin Petite8edce32019-04-10 14:23:32 +01002248 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002249 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2250 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002251 });
David Neto22f144c2017-06-12 14:26:21 -04002252}
2253
David Neto6ad93232018-06-07 15:42:58 -07002254bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002255
2256 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2257 //
2258 // %u = load i32 %ptr
2259 // %fxy = call <2 x float> Unpack2xHalf(u)
2260 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002261 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002262 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2263 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2264 "_Z20__clspv_vloada_half2jPKj", // private
2265 };
2266
Kévin Petite8edce32019-04-10 14:23:32 +01002267 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2268 auto Index = CI->getOperand(0);
2269 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002270
Kévin Petite8edce32019-04-10 14:23:32 +01002271 auto IntTy = Type::getInt32Ty(M.getContext());
2272 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2273 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002274
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002275 auto IndexedPtr = GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002276 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002277
Kévin Petite8edce32019-04-10 14:23:32 +01002278 // Our intrinsic to unpack a float2 from an int.
2279 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002280
Kévin Petite8edce32019-04-10 14:23:32 +01002281 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002282
Kévin Petite8edce32019-04-10 14:23:32 +01002283 // Get our final float2.
2284 return CallInst::Create(NewF, Load, "", CI);
2285 });
David Neto6ad93232018-06-07 15:42:58 -07002286}
2287
2288bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002289
2290 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2291 //
2292 // %u2 = load <2 x i32> %ptr
2293 // %u2xy = extractelement %u2, 0
2294 // %u2zw = extractelement %u2, 1
2295 // %fxy = call <2 x float> Unpack2xHalf(uint)
2296 // %fzw = call <2 x float> Unpack2xHalf(uint)
2297 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002298 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002299 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2300 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2301 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2302 };
2303
Kévin Petite8edce32019-04-10 14:23:32 +01002304 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2305 auto Index = CI->getOperand(0);
2306 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002307
Kévin Petite8edce32019-04-10 14:23:32 +01002308 auto IntTy = Type::getInt32Ty(M.getContext());
2309 auto Int2Ty = VectorType::get(IntTy, 2);
2310 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2311 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002312
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002313 auto IndexedPtr = GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002314 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002315
Kévin Petite8edce32019-04-10 14:23:32 +01002316 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002317 auto X =
2318 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2319 auto Y =
2320 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002321
Kévin Petite8edce32019-04-10 14:23:32 +01002322 // Our intrinsic to unpack a float2 from an int.
2323 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002324
Kévin Petite8edce32019-04-10 14:23:32 +01002325 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002326
Kévin Petite8edce32019-04-10 14:23:32 +01002327 // Get the lower (x & y) components of our final float4.
2328 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002329
Kévin Petite8edce32019-04-10 14:23:32 +01002330 // Get the higher (z & w) components of our final float4.
2331 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002332
Kévin Petite8edce32019-04-10 14:23:32 +01002333 Constant *ShuffleMask[4] = {
2334 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2335 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto6ad93232018-06-07 15:42:58 -07002336
Kévin Petite8edce32019-04-10 14:23:32 +01002337 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002338 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2339 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002340 });
David Neto6ad93232018-06-07 15:42:58 -07002341}
2342
David Neto22f144c2017-06-12 14:26:21 -04002343bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002344
Kévin Petite8edce32019-04-10 14:23:32 +01002345 const std::vector<const char *> Names = {"_Z11vstore_halffjPU3AS1Dh",
2346 "_Z15vstore_half_rtefjPU3AS1Dh",
2347 "_Z15vstore_half_rtzfjPU3AS1Dh"};
David Neto22f144c2017-06-12 14:26:21 -04002348
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002349 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002350 // The value to store.
2351 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002352
Kévin Petite8edce32019-04-10 14:23:32 +01002353 // The index argument from vstore_half.
2354 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002355
Kévin Petite8edce32019-04-10 14:23:32 +01002356 // The pointer argument from vstore_half.
2357 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002358
Kévin Petite8edce32019-04-10 14:23:32 +01002359 auto IntTy = Type::getInt32Ty(M.getContext());
2360 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2361 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2362 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002363
Kévin Petite8edce32019-04-10 14:23:32 +01002364 // Our intrinsic to pack a float2 to an int.
2365 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002366
Kévin Petite8edce32019-04-10 14:23:32 +01002367 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002368
Kévin Petite8edce32019-04-10 14:23:32 +01002369 // Insert our value into a float2 so that we can pack it.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002370 auto TempVec = InsertElementInst::Create(
2371 UndefValue::get(Float2Ty), Arg0, ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002372
Kévin Petite8edce32019-04-10 14:23:32 +01002373 // Pack the float2 -> half2 (in an int).
2374 auto X = CallInst::Create(NewF, TempVec, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002375
Kévin Petite8edce32019-04-10 14:23:32 +01002376 Value *Ret;
2377 if (clspv::Option::F16BitStorage()) {
2378 auto ShortTy = Type::getInt16Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002379 auto ShortPointerTy =
2380 PointerType::get(ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002381
Kévin Petite8edce32019-04-10 14:23:32 +01002382 // Truncate our i32 to an i16.
2383 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002384
Kévin Petite8edce32019-04-10 14:23:32 +01002385 // Cast the half* pointer to short*.
2386 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002387
Kévin Petite8edce32019-04-10 14:23:32 +01002388 // Index into the correct address of the casted pointer.
2389 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002390
Kévin Petite8edce32019-04-10 14:23:32 +01002391 // Store to the short* we casted to.
2392 Ret = new StoreInst(Trunc, Index, CI);
2393 } else {
2394 // We can only write to 32-bit aligned words.
2395 //
2396 // Assuming base is aligned to 32-bits, replace the equivalent of
2397 // vstore_half(value, index, base)
2398 // with:
2399 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2400 // uint32_t write_to_upper_half = index & 1u;
2401 // uint32_t shift = write_to_upper_half << 4;
2402 //
2403 // // Pack the float value as a half number in bottom 16 bits
2404 // // of an i32.
2405 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2406 //
2407 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2408 // ^ ((packed & 0xffff) << shift)
2409 // // We only need relaxed consistency, but OpenCL 1.2 only has
2410 // // sequentially consistent atomics.
2411 // // TODO(dneto): Use relaxed consistency.
2412 // atomic_xor(target_ptr, xor_value)
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002413 auto IntPointerTy =
2414 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002415
Kévin Petite8edce32019-04-10 14:23:32 +01002416 auto Four = ConstantInt::get(IntTy, 4);
2417 auto FFFF = ConstantInt::get(IntTy, 0xffff);
David Neto17852de2017-05-29 17:29:31 -04002418
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002419 auto IndexIsOdd =
2420 BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002421 // Compute index / 2
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002422 auto IndexIntoI32 =
2423 BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2424 auto BaseI32Ptr =
2425 CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2426 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32,
2427 "base_i32_ptr", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002428 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2429 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002430 auto MaskBitsToWrite =
2431 BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2432 auto MaskedCurrent = BinaryOperator::CreateAnd(
2433 MaskBitsToWrite, CurrentValue, "masked_current", CI);
David Neto17852de2017-05-29 17:29:31 -04002434
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002435 auto XLowerBits =
2436 BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2437 auto NewBitsToWrite =
2438 BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2439 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite,
2440 "value_to_xor", CI);
David Neto17852de2017-05-29 17:29:31 -04002441
Kévin Petite8edce32019-04-10 14:23:32 +01002442 // Generate the call to atomi_xor.
2443 SmallVector<Type *, 5> ParamTypes;
2444 // The pointer type.
2445 ParamTypes.push_back(IntPointerTy);
2446 // The Types for memory scope, semantics, and value.
2447 ParamTypes.push_back(IntTy);
2448 ParamTypes.push_back(IntTy);
2449 ParamTypes.push_back(IntTy);
2450 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2451 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
David Neto17852de2017-05-29 17:29:31 -04002452
Kévin Petite8edce32019-04-10 14:23:32 +01002453 const auto ConstantScopeDevice =
2454 ConstantInt::get(IntTy, spv::ScopeDevice);
2455 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2456 // (SPIR-V Workgroup).
2457 const auto AddrSpaceSemanticsBits =
2458 IntPointerTy->getPointerAddressSpace() == 1
2459 ? spv::MemorySemanticsUniformMemoryMask
2460 : spv::MemorySemanticsWorkgroupMemoryMask;
David Neto17852de2017-05-29 17:29:31 -04002461
Kévin Petite8edce32019-04-10 14:23:32 +01002462 // We're using relaxed consistency here.
2463 const auto ConstantMemorySemantics =
2464 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2465 AddrSpaceSemanticsBits);
David Neto17852de2017-05-29 17:29:31 -04002466
Kévin Petite8edce32019-04-10 14:23:32 +01002467 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2468 ConstantMemorySemantics, ValueToXor};
2469 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2470 Ret = nullptr;
David Neto22f144c2017-06-12 14:26:21 -04002471 }
David Neto22f144c2017-06-12 14:26:21 -04002472
Kévin Petite8edce32019-04-10 14:23:32 +01002473 return Ret;
2474 });
David Neto22f144c2017-06-12 14:26:21 -04002475}
2476
2477bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002478
Kévin Petite8edce32019-04-10 14:23:32 +01002479 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002480 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2481 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2482 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2483 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2484 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2485 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2486 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2487 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2488 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2489 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2490 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2491 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2492 };
David Neto22f144c2017-06-12 14:26:21 -04002493
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002494 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002495 // The value to store.
2496 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002497
Kévin Petite8edce32019-04-10 14:23:32 +01002498 // The index argument from vstore_half.
2499 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002500
Kévin Petite8edce32019-04-10 14:23:32 +01002501 // The pointer argument from vstore_half.
2502 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002503
Kévin Petite8edce32019-04-10 14:23:32 +01002504 auto IntTy = Type::getInt32Ty(M.getContext());
2505 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002506 auto NewPointerTy =
2507 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002508 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002509
Kévin Petite8edce32019-04-10 14:23:32 +01002510 // Our intrinsic to pack a float2 to an int.
2511 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002512
Kévin Petite8edce32019-04-10 14:23:32 +01002513 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002514
Kévin Petite8edce32019-04-10 14:23:32 +01002515 // Turn the packed x & y into the final packing.
2516 auto X = CallInst::Create(NewF, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002517
Kévin Petite8edce32019-04-10 14:23:32 +01002518 // Cast the half* pointer to int*.
2519 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002520
Kévin Petite8edce32019-04-10 14:23:32 +01002521 // Index into the correct address of the casted pointer.
2522 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002523
Kévin Petite8edce32019-04-10 14:23:32 +01002524 // Store to the int* we casted to.
2525 return new StoreInst(X, Index, CI);
2526 });
David Neto22f144c2017-06-12 14:26:21 -04002527}
2528
2529bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002530
Kévin Petite8edce32019-04-10 14:23:32 +01002531 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002532 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2533 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2534 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2535 "_Z13vstorea_half4Dv4_fjPDh", // private
2536 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2537 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2538 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2539 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2540 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2541 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2542 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2543 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2544 };
David Neto22f144c2017-06-12 14:26:21 -04002545
Kévin Petite8edce32019-04-10 14:23:32 +01002546 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2547 // The value to store.
2548 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002549
Kévin Petite8edce32019-04-10 14:23:32 +01002550 // The index argument from vstore_half.
2551 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002552
Kévin Petite8edce32019-04-10 14:23:32 +01002553 // The pointer argument from vstore_half.
2554 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002555
Kévin Petite8edce32019-04-10 14:23:32 +01002556 auto IntTy = Type::getInt32Ty(M.getContext());
2557 auto Int2Ty = VectorType::get(IntTy, 2);
2558 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002559 auto NewPointerTy =
2560 PointerType::get(Int2Ty, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002561 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002562
Kévin Petite8edce32019-04-10 14:23:32 +01002563 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2564 ConstantInt::get(IntTy, 1)};
David Neto22f144c2017-06-12 14:26:21 -04002565
Kévin Petite8edce32019-04-10 14:23:32 +01002566 // Extract out the x & y components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002567 auto Lo = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2568 ConstantVector::get(LoShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002569
Kévin Petite8edce32019-04-10 14:23:32 +01002570 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2571 ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002572
Kévin Petite8edce32019-04-10 14:23:32 +01002573 // Extract out the z & w components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002574 auto Hi = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2575 ConstantVector::get(HiShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002576
Kévin Petite8edce32019-04-10 14:23:32 +01002577 // Our intrinsic to pack a float2 to an int.
2578 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002579
Kévin Petite8edce32019-04-10 14:23:32 +01002580 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002581
Kévin Petite8edce32019-04-10 14:23:32 +01002582 // Turn the packed x & y into the final component of our int2.
2583 auto X = CallInst::Create(NewF, Lo, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002584
Kévin Petite8edce32019-04-10 14:23:32 +01002585 // Turn the packed z & w into the final component of our int2.
2586 auto Y = CallInst::Create(NewF, Hi, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002587
Kévin Petite8edce32019-04-10 14:23:32 +01002588 auto Combine = InsertElementInst::Create(
2589 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002590 Combine = InsertElementInst::Create(Combine, Y, ConstantInt::get(IntTy, 1),
2591 "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002592
Kévin Petite8edce32019-04-10 14:23:32 +01002593 // Cast the half* pointer to int2*.
2594 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002595
Kévin Petite8edce32019-04-10 14:23:32 +01002596 // Index into the correct address of the casted pointer.
2597 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002598
Kévin Petite8edce32019-04-10 14:23:32 +01002599 // Store to the int2* we casted to.
2600 return new StoreInst(Combine, Index, CI);
2601 });
David Neto22f144c2017-06-12 14:26:21 -04002602}
2603
2604bool ReplaceOpenCLBuiltinPass::replaceReadImageF(Module &M) {
2605 bool Changed = false;
2606
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002607 const std::map<const char *, const char *> Map = {
2608 {"_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i",
2609 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f"},
2610 {"_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_i",
2611 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv4_f"}};
David Neto22f144c2017-06-12 14:26:21 -04002612
2613 for (auto Pair : Map) {
2614 // If we find a function with the matching name.
2615 if (auto F = M.getFunction(Pair.first)) {
2616 SmallVector<Instruction *, 4> ToRemoves;
2617
2618 // Walk the users of the function.
2619 for (auto &U : F->uses()) {
2620 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2621 // The image.
2622 auto Arg0 = CI->getOperand(0);
2623
2624 // The sampler.
2625 auto Arg1 = CI->getOperand(1);
2626
2627 // The coordinate (integer type that we can't handle).
2628 auto Arg2 = CI->getOperand(2);
2629
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002630 auto FloatVecTy =
2631 VectorType::get(Type::getFloatTy(M.getContext()),
2632 Arg2->getType()->getVectorNumElements());
David Neto22f144c2017-06-12 14:26:21 -04002633
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002634 auto NewFType = FunctionType::get(
2635 CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy},
2636 false);
David Neto22f144c2017-06-12 14:26:21 -04002637
2638 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2639
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002640 auto Cast =
2641 CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002642
2643 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2644
2645 CI->replaceAllUsesWith(NewCI);
2646
2647 // Lastly, remember to remove the user.
2648 ToRemoves.push_back(CI);
2649 }
2650 }
2651
2652 Changed = !ToRemoves.empty();
2653
2654 // And cleanup the calls we don't use anymore.
2655 for (auto V : ToRemoves) {
2656 V->eraseFromParent();
2657 }
2658
2659 // And remove the function we don't need either too.
2660 F->eraseFromParent();
2661 }
2662 }
2663
2664 return Changed;
2665}
2666
2667bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2668 bool Changed = false;
2669
2670 const std::map<const char *, const char *> Map = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002671 {"_Z8atom_incPU3AS1Vi", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002672 {"_Z8atom_incPU3AS3Vi", "spirv.atomic_inc"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002673 {"_Z8atom_incPU3AS1Vj", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002674 {"_Z8atom_incPU3AS3Vj", "spirv.atomic_inc"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002675 {"_Z8atom_decPU3AS1Vi", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002676 {"_Z8atom_decPU3AS3Vi", "spirv.atomic_dec"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002677 {"_Z8atom_decPU3AS1Vj", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002678 {"_Z8atom_decPU3AS3Vj", "spirv.atomic_dec"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002679 {"_Z12atom_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002680 {"_Z12atom_cmpxchgPU3AS3Viii", "spirv.atomic_compare_exchange"},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002681 {"_Z12atom_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002682 {"_Z12atom_cmpxchgPU3AS3Vjjj", "spirv.atomic_compare_exchange"},
David Neto22f144c2017-06-12 14:26:21 -04002683 {"_Z10atomic_incPU3AS1Vi", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002684 {"_Z10atomic_incPU3AS3Vi", "spirv.atomic_inc"},
David Neto22f144c2017-06-12 14:26:21 -04002685 {"_Z10atomic_incPU3AS1Vj", "spirv.atomic_inc"},
Kévin Petita303dc62019-03-26 21:40:35 +00002686 {"_Z10atomic_incPU3AS3Vj", "spirv.atomic_inc"},
David Neto22f144c2017-06-12 14:26:21 -04002687 {"_Z10atomic_decPU3AS1Vi", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002688 {"_Z10atomic_decPU3AS3Vi", "spirv.atomic_dec"},
David Neto22f144c2017-06-12 14:26:21 -04002689 {"_Z10atomic_decPU3AS1Vj", "spirv.atomic_dec"},
Kévin Petita303dc62019-03-26 21:40:35 +00002690 {"_Z10atomic_decPU3AS3Vj", "spirv.atomic_dec"},
David Neto22f144c2017-06-12 14:26:21 -04002691 {"_Z14atomic_cmpxchgPU3AS1Viii", "spirv.atomic_compare_exchange"},
Kévin Petita303dc62019-03-26 21:40:35 +00002692 {"_Z14atomic_cmpxchgPU3AS3Viii", "spirv.atomic_compare_exchange"},
2693 {"_Z14atomic_cmpxchgPU3AS1Vjjj", "spirv.atomic_compare_exchange"},
2694 {"_Z14atomic_cmpxchgPU3AS3Vjjj", "spirv.atomic_compare_exchange"}};
David Neto22f144c2017-06-12 14:26:21 -04002695
2696 for (auto Pair : Map) {
2697 // If we find a function with the matching name.
2698 if (auto F = M.getFunction(Pair.first)) {
2699 SmallVector<Instruction *, 4> ToRemoves;
2700
2701 // Walk the users of the function.
2702 for (auto &U : F->uses()) {
2703 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2704 auto FType = F->getFunctionType();
2705 SmallVector<Type *, 5> ParamTypes;
2706
2707 // The pointer type.
2708 ParamTypes.push_back(FType->getParamType(0));
2709
2710 auto IntTy = Type::getInt32Ty(M.getContext());
2711
2712 // The memory scope type.
2713 ParamTypes.push_back(IntTy);
2714
2715 // The memory semantics type.
2716 ParamTypes.push_back(IntTy);
2717
2718 if (2 < CI->getNumArgOperands()) {
2719 // The unequal memory semantics type.
2720 ParamTypes.push_back(IntTy);
2721
2722 // The value type.
2723 ParamTypes.push_back(FType->getParamType(2));
2724
2725 // The comparator type.
2726 ParamTypes.push_back(FType->getParamType(1));
2727 } else if (1 < CI->getNumArgOperands()) {
2728 // The value type.
2729 ParamTypes.push_back(FType->getParamType(1));
2730 }
2731
2732 auto NewFType =
2733 FunctionType::get(FType->getReturnType(), ParamTypes, false);
2734 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2735
2736 // We need to map the OpenCL constants to the SPIR-V equivalents.
2737 const auto ConstantScopeDevice =
2738 ConstantInt::get(IntTy, spv::ScopeDevice);
2739 const auto ConstantMemorySemantics = ConstantInt::get(
2740 IntTy, spv::MemorySemanticsUniformMemoryMask |
2741 spv::MemorySemanticsSequentiallyConsistentMask);
2742
2743 SmallVector<Value *, 5> Params;
2744
2745 // The pointer.
2746 Params.push_back(CI->getArgOperand(0));
2747
2748 // The memory scope.
2749 Params.push_back(ConstantScopeDevice);
2750
2751 // The memory semantics.
2752 Params.push_back(ConstantMemorySemantics);
2753
2754 if (2 < CI->getNumArgOperands()) {
2755 // The unequal memory semantics.
2756 Params.push_back(ConstantMemorySemantics);
2757
2758 // The value.
2759 Params.push_back(CI->getArgOperand(2));
2760
2761 // The comparator.
2762 Params.push_back(CI->getArgOperand(1));
2763 } else if (1 < CI->getNumArgOperands()) {
2764 // The value.
2765 Params.push_back(CI->getArgOperand(1));
2766 }
2767
2768 auto NewCI = CallInst::Create(NewF, Params, "", CI);
2769
2770 CI->replaceAllUsesWith(NewCI);
2771
2772 // Lastly, remember to remove the user.
2773 ToRemoves.push_back(CI);
2774 }
2775 }
2776
2777 Changed = !ToRemoves.empty();
2778
2779 // And cleanup the calls we don't use anymore.
2780 for (auto V : ToRemoves) {
2781 V->eraseFromParent();
2782 }
2783
2784 // And remove the function we don't need either too.
2785 F->eraseFromParent();
2786 }
2787 }
2788
Neil Henning39672102017-09-29 14:33:13 +01002789 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002790 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002791 {"_Z8atom_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002792 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002793 {"_Z8atom_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002794 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002795 {"_Z8atom_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002796 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002797 {"_Z8atom_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002798 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002799 {"_Z9atom_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002800 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002801 {"_Z9atom_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002802 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002803 {"_Z8atom_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002804 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002805 {"_Z8atom_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002806 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002807 {"_Z8atom_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002808 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002809 {"_Z8atom_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002810 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002811 {"_Z8atom_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002812 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002813 {"_Z8atom_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002814 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002815 {"_Z7atom_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002816 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002817 {"_Z7atom_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002818 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002819 {"_Z8atom_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002820 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002821 {"_Z8atom_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002822 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002823 {"_Z10atomic_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002824 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002825 {"_Z10atomic_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002826 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002827 {"_Z10atomic_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002828 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002829 {"_Z10atomic_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002830 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002831 {"_Z11atomic_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002832 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002833 {"_Z11atomic_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002834 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002835 {"_Z10atomic_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Neil Henning39672102017-09-29 14:33:13 +01002836 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002837 {"_Z10atomic_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Neil Henning39672102017-09-29 14:33:13 +01002838 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002839 {"_Z10atomic_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Neil Henning39672102017-09-29 14:33:13 +01002840 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002841 {"_Z10atomic_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Neil Henning39672102017-09-29 14:33:13 +01002842 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002843 {"_Z10atomic_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002844 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002845 {"_Z10atomic_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002846 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002847 {"_Z9atomic_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002848 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002849 {"_Z9atomic_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002850 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002851 {"_Z10atomic_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
2852 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
2853 {"_Z10atomic_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor}};
Neil Henning39672102017-09-29 14:33:13 +01002854
2855 for (auto Pair : Map2) {
2856 // If we find a function with the matching name.
2857 if (auto F = M.getFunction(Pair.first)) {
2858 SmallVector<Instruction *, 4> ToRemoves;
2859
2860 // Walk the users of the function.
2861 for (auto &U : F->uses()) {
2862 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2863 auto AtomicOp = new AtomicRMWInst(
2864 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2865 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2866
2867 CI->replaceAllUsesWith(AtomicOp);
2868
2869 // Lastly, remember to remove the user.
2870 ToRemoves.push_back(CI);
2871 }
2872 }
2873
2874 Changed = !ToRemoves.empty();
2875
2876 // And cleanup the calls we don't use anymore.
2877 for (auto V : ToRemoves) {
2878 V->eraseFromParent();
2879 }
2880
2881 // And remove the function we don't need either too.
2882 F->eraseFromParent();
2883 }
2884 }
2885
David Neto22f144c2017-06-12 14:26:21 -04002886 return Changed;
2887}
2888
2889bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002890
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002891 std::vector<const char *> Names = {
2892 "_Z5crossDv4_fS_",
Kévin Petite8edce32019-04-10 14:23:32 +01002893 };
2894
2895 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
David Neto22f144c2017-06-12 14:26:21 -04002896 auto IntTy = Type::getInt32Ty(M.getContext());
2897 auto FloatTy = Type::getFloatTy(M.getContext());
2898
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002899 Constant *DownShuffleMask[3] = {ConstantInt::get(IntTy, 0),
2900 ConstantInt::get(IntTy, 1),
2901 ConstantInt::get(IntTy, 2)};
David Neto22f144c2017-06-12 14:26:21 -04002902
2903 Constant *UpShuffleMask[4] = {
2904 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2905 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2906
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002907 Constant *FloatVec[3] = {ConstantFP::get(FloatTy, 0.0f),
2908 UndefValue::get(FloatTy),
2909 UndefValue::get(FloatTy)};
David Neto22f144c2017-06-12 14:26:21 -04002910
Kévin Petite8edce32019-04-10 14:23:32 +01002911 auto Vec4Ty = CI->getArgOperand(0)->getType();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002912 auto Arg0 =
2913 new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty),
2914 ConstantVector::get(DownShuffleMask), "", CI);
2915 auto Arg1 =
2916 new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty),
2917 ConstantVector::get(DownShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002918 auto Vec3Ty = Arg0->getType();
David Neto22f144c2017-06-12 14:26:21 -04002919
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002920 auto NewFType = FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
David Neto22f144c2017-06-12 14:26:21 -04002921
Kévin Petite8edce32019-04-10 14:23:32 +01002922 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002923
Kévin Petite8edce32019-04-10 14:23:32 +01002924 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002925
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002926 return new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec),
2927 ConstantVector::get(UpShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002928 });
David Neto22f144c2017-06-12 14:26:21 -04002929}
David Neto62653202017-10-16 19:05:18 -04002930
2931bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2932 bool Changed = false;
2933
2934 // OpenCL's float result = fract(float x, float* ptr)
2935 //
2936 // In the LLVM domain:
2937 //
2938 // %floor_result = call spir_func float @floor(float %x)
2939 // store float %floor_result, float * %ptr
2940 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2941 // %result = call spir_func float
2942 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2943 //
2944 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2945 // and clspv.fract occur in the SPIR-V generator pass:
2946 //
2947 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2948 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2949 // ...
2950 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2951 // OpStore %ptr %floor_result
2952 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2953 // %fract_result = OpExtInst %float
2954 // %glsl_ext Fmin %fract_intermediate %just_under_1
2955
David Neto62653202017-10-16 19:05:18 -04002956 using std::string;
2957
2958 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2959 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002960 using QuadType =
2961 std::tuple<const char *, const char *, const char *, const char *>;
David Neto62653202017-10-16 19:05:18 -04002962 auto make_quad = [](const char *a, const char *b, const char *c,
2963 const char *d) {
2964 return std::tuple<const char *, const char *, const char *, const char *>(
2965 a, b, c, d);
2966 };
2967 const std::vector<QuadType> Functions = {
2968 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002969 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff",
2970 "clspv.fract.v2f"),
2971 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff",
2972 "clspv.fract.v3f"),
2973 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff",
2974 "clspv.fract.v4f"),
David Neto62653202017-10-16 19:05:18 -04002975 };
2976
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002977 for (auto &quad : Functions) {
David Neto62653202017-10-16 19:05:18 -04002978 const StringRef fract_name(std::get<0>(quad));
2979
2980 // If we find a function with the matching name.
2981 if (auto F = M.getFunction(fract_name)) {
2982 if (F->use_begin() == F->use_end())
2983 continue;
2984
2985 // We have some uses.
2986 Changed = true;
2987
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002988 auto &Context = M.getContext();
David Neto62653202017-10-16 19:05:18 -04002989
2990 const StringRef floor_name(std::get<1>(quad));
2991 const StringRef fmin_name(std::get<2>(quad));
2992 const StringRef clspv_fract_name(std::get<3>(quad));
2993
2994 // This is either float or a float vector. All the float-like
2995 // types are this type.
2996 auto result_ty = F->getReturnType();
2997
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002998 Function *fmin_fn = M.getFunction(fmin_name);
David Neto62653202017-10-16 19:05:18 -04002999 if (!fmin_fn) {
3000 // Make the fmin function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003001 FunctionType *fn_ty =
3002 FunctionType::get(result_ty, {result_ty, result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003003 fmin_fn =
3004 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003005 fmin_fn->addFnAttr(Attribute::ReadNone);
3006 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
3007 }
3008
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003009 Function *floor_fn = M.getFunction(floor_name);
David Neto62653202017-10-16 19:05:18 -04003010 if (!floor_fn) {
3011 // Make the floor function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003012 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003013 floor_fn = cast<Function>(
3014 M.getOrInsertFunction(floor_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003015 floor_fn->addFnAttr(Attribute::ReadNone);
3016 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
3017 }
3018
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003019 Function *clspv_fract_fn = M.getFunction(clspv_fract_name);
David Neto62653202017-10-16 19:05:18 -04003020 if (!clspv_fract_fn) {
3021 // Make the clspv_fract function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003022 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003023 clspv_fract_fn = cast<Function>(
3024 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003025 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
3026 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
3027 }
3028
3029 // Number of significand bits, counting the implicit leading bit that is not stored.
3030 unsigned num_significand_bits;
3031 switch (result_ty->getScalarType()->getTypeID()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003032 case Type::HalfTyID:
3033 num_significand_bits = 11;
3034 break;
3035 case Type::FloatTyID:
3036 num_significand_bits = 24;
3037 break;
3038 case Type::DoubleTyID:
3039 num_significand_bits = 53;
3040 break;
3041 default:
3042 assert(false && "Unhandled float type when processing fract builtin");
3043 break;
David Neto62653202017-10-16 19:05:18 -04003044 }
3045 // Beware that the disassembler displays this value as
3046 // OpConstant %float 1
3047 // which is not quite right.
3048 const double kJustUnderOneScalar =
3049 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3050
3051 Constant *just_under_one =
3052 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3053 if (result_ty->isVectorTy()) {
3054 just_under_one = ConstantVector::getSplat(
3055 result_ty->getVectorNumElements(), just_under_one);
3056 }
3057
3058 IRBuilder<> Builder(Context);
3059
3060 SmallVector<Instruction *, 4> ToRemoves;
3061
3062 // Walk the users of the function.
3063 for (auto &U : F->uses()) {
3064 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3065
3066 Builder.SetInsertPoint(CI);
3067 auto arg = CI->getArgOperand(0);
3068 auto ptr = CI->getArgOperand(1);
3069
3070 // Compute floor result and store it.
3071 auto floor = Builder.CreateCall(floor_fn, {arg});
3072 Builder.CreateStore(floor, ptr);
3073
3074 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003075 auto fract_result =
3076 Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
David Neto62653202017-10-16 19:05:18 -04003077
3078 CI->replaceAllUsesWith(fract_result);
3079
3080 // Lastly, remember to remove the user.
3081 ToRemoves.push_back(CI);
3082 }
3083 }
3084
3085 // And cleanup the calls we don't use anymore.
3086 for (auto V : ToRemoves) {
3087 V->eraseFromParent();
3088 }
3089
3090 // Finally, remove the original fract function, which is no longer needed.
3091 F->eraseFromParent();
3092 }
3093 }
3094
3095 return Changed;
3096}