blob: 3411176836822c120f01ac61f017a10539731f7b [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
David Neto118188e2018-08-24 11:27:54 -040021#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040022#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040023#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000024#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040025#include "llvm/Pass.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040029
David Neto118188e2018-08-24 11:27:54 -040030#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040031
alan-baker931d18a2019-12-12 08:21:32 -050032#include "clspv/AddressSpace.h"
James Pricec05f6052020-01-14 13:37:20 -050033#include "clspv/DescriptorMap.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040034#include "clspv/Option.h"
David Neto482550a2018-03-24 05:21:07 -070035
alan-baker931d18a2019-12-12 08:21:32 -050036#include "Constants.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040037#include "Passes.h"
38#include "SPIRVOp.h"
alan-bakerf906d2b2019-12-10 11:26:23 -050039#include "Types.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040040
David Neto22f144c2017-06-12 14:26:21 -040041using namespace llvm;
42
43#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
44
45namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000046
// Signedness of one argument parsed from a mangled builtin name.
struct ArgTypeInfo {
  enum class SignedNess { None, Unsigned, Signed };
  // Defaults to None so that default-constructed instances (such as the
  // "previous argument" seed used when resolving "S_" substitutions) are never
  // read while uninitialized.
  SignedNess signedness = SignedNess::None;
};
51
52struct FunctionInfo {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000053 StringRef name;
Kévin Petit8a560882019-03-21 15:24:34 +000054 std::vector<ArgTypeInfo> argTypeInfos;
Kévin Petit8a560882019-03-21 15:24:34 +000055
Kévin Petit91bc72e2019-04-08 15:17:46 +010056 bool isArgSigned(size_t arg) const {
57 assert(argTypeInfos.size() > arg);
58 return argTypeInfos[arg].signedness == ArgTypeInfo::SignedNess::Signed;
Kévin Petit8a560882019-03-21 15:24:34 +000059 }
60
Kévin Petit91bc72e2019-04-08 15:17:46 +010061 static FunctionInfo getFromMangledName(StringRef name) {
62 FunctionInfo fi;
63 if (!getFromMangledNameCheck(name, &fi)) {
64 llvm_unreachable("Can't parse mangled function name!");
Kévin Petit8a560882019-03-21 15:24:34 +000065 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010066 return fi;
67 }
Kévin Petit8a560882019-03-21 15:24:34 +000068
Kévin Petit91bc72e2019-04-08 15:17:46 +010069 static bool getFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
70 if (!name.consume_front("_Z")) {
71 return false;
72 }
73 size_t nameLen;
74 if (name.consumeInteger(10, nameLen)) {
Kévin Petit8a560882019-03-21 15:24:34 +000075 return false;
76 }
77
Kévin Petit91bc72e2019-04-08 15:17:46 +010078 finfo->name = name.take_front(nameLen);
79 name = name.drop_front(nameLen);
Kévin Petit8a560882019-03-21 15:24:34 +000080
Kévin Petit91bc72e2019-04-08 15:17:46 +010081 ArgTypeInfo prev_ti;
Kévin Petit8a560882019-03-21 15:24:34 +000082
Kévin Petit91bc72e2019-04-08 15:17:46 +010083 while (name.size() != 0) {
84
85 ArgTypeInfo ti;
86
87 // Try parsing a vector prefix
88 if (name.consume_front("Dv")) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040089 int numElems;
90 if (name.consumeInteger(10, numElems)) {
91 return false;
92 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010093
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040094 if (!name.consume_front("_")) {
95 return false;
96 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010097 }
98
99 // Parse the base type
100 char typeCode = name.front();
101 name = name.drop_front(1);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400102 switch (typeCode) {
Kévin Petit91bc72e2019-04-08 15:17:46 +0100103 case 'c': // char
104 case 'a': // signed char
105 case 's': // short
106 case 'i': // int
107 case 'l': // long
108 ti.signedness = ArgTypeInfo::SignedNess::Signed;
109 break;
110 case 'h': // unsigned char
111 case 't': // unsigned short
112 case 'j': // unsigned int
113 case 'm': // unsigned long
114 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
115 break;
116 case 'f':
117 ti.signedness = ArgTypeInfo::SignedNess::None;
118 break;
119 case 'S':
120 ti = prev_ti;
121 if (!name.consume_front("_")) {
122 return false;
123 }
124 break;
125 default:
126 return false;
127 }
128
129 finfo->argTypeInfos.push_back(ti);
130
131 prev_ti = ti;
132 }
133
134 return true;
135 };
Kévin Petit8a560882019-03-21 15:24:34 +0000136};
137
// Returns the index of the most significant set bit of |v|, i.e. floor(log2 v)
// (0 for both v == 0 and v == 1). Despite the name this is NOT a leading-zero
// count; callers subtract two results to get the distance between two
// single-bit masks.
uint32_t clz(uint32_t v) {
  uint32_t highBit = 0;
  // Binary search over halving widths: 16, 8, 4, 2, 1.
  for (uint32_t width = 16; width != 0; width >>= 1) {
    if ((v >> width) != 0) {
      v >>= width;
      highBit |= width;
    }
  }
  return highBit;
}
157
158Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
159 if (1 == elements) {
160 return Type::getInt1Ty(C);
161 } else {
162 return VectorType::get(Type::getInt1Ty(C), elements);
163 }
164}
165
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100166Type *getIntOrIntVectorTyForCast(LLVMContext &C, Type *Ty) {
167 Type *IntTy = Type::getIntNTy(C, Ty->getScalarSizeInBits());
168 if (Ty->isVectorTy()) {
169 IntTy = VectorType::get(IntTy, Ty->getVectorNumElements());
170 }
171 return IntTy;
172}
173
David Neto22f144c2017-06-12 14:26:21 -0400174struct ReplaceOpenCLBuiltinPass final : public ModulePass {
175 static char ID;
176 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
177
178 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000179 bool replaceAbs(Module &M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100180 bool replaceAbsDiff(Module &M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100181 bool replaceCopysign(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400182 bool replaceRecip(Module &M);
183 bool replaceDivide(Module &M);
Kévin Petit1329a002019-06-15 05:54:05 +0100184 bool replaceDot(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400185 bool replaceExp10(Module &M);
Kévin Petit0644a9c2019-06-20 21:08:46 +0100186 bool replaceFmod(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400187 bool replaceLog10(Module &M);
188 bool replaceBarrier(Module &M);
189 bool replaceMemFence(Module &M);
190 bool replaceRelational(Module &M);
191 bool replaceIsInfAndIsNan(Module &M);
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100192 bool replaceIsFinite(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400193 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000194 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000195 bool replaceRotate(Module &M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000196 bool replaceConvert(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000197 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000198 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000199 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000200 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400201 bool replaceSignbit(Module &M);
202 bool replaceMadandMad24andMul24(Module &M);
203 bool replaceVloadHalf(Module &M);
204 bool replaceVloadHalf2(Module &M);
205 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700206 bool replaceClspvVloadaHalf2(Module &M);
207 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400208 bool replaceVstoreHalf(Module &M);
209 bool replaceVstoreHalf2(Module &M);
210 bool replaceVstoreHalf4(Module &M);
alan-bakerf7e17cb2020-01-02 07:29:59 -0500211 bool replaceHalfReadImage(Module &M);
212 bool replaceHalfWriteImage(Module &M);
alan-baker931d18a2019-12-12 08:21:32 -0500213 bool replaceUnsampledReadImage(Module &M);
Kévin Petit06517a12019-12-09 19:40:31 +0000214 bool replaceSampledReadImageWithIntCoords(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400215 bool replaceAtomics(Module &M);
216 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400217 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700218 bool replaceVload(Module &M);
219 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400220};
Kévin Petit91bc72e2019-04-08 15:17:46 +0100221} // namespace
David Neto22f144c2017-06-12 14:26:21 -0400222
223char ReplaceOpenCLBuiltinPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400224INITIALIZE_PASS(ReplaceOpenCLBuiltinPass, "ReplaceOpenCLBuiltin",
225 "Replace OpenCL Builtins Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -0400226
227namespace clspv {
228ModulePass *createReplaceOpenCLBuiltinPass() {
229 return new ReplaceOpenCLBuiltinPass();
230}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400231} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -0400232
233bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
234 bool Changed = false;
235
Kévin Petit2444e9b2018-11-09 14:14:37 +0000236 Changed |= replaceAbs(M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100237 Changed |= replaceAbsDiff(M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100238 Changed |= replaceCopysign(M);
David Neto22f144c2017-06-12 14:26:21 -0400239 Changed |= replaceRecip(M);
240 Changed |= replaceDivide(M);
Kévin Petit1329a002019-06-15 05:54:05 +0100241 Changed |= replaceDot(M);
David Neto22f144c2017-06-12 14:26:21 -0400242 Changed |= replaceExp10(M);
Kévin Petit0644a9c2019-06-20 21:08:46 +0100243 Changed |= replaceFmod(M);
David Neto22f144c2017-06-12 14:26:21 -0400244 Changed |= replaceLog10(M);
245 Changed |= replaceBarrier(M);
246 Changed |= replaceMemFence(M);
247 Changed |= replaceRelational(M);
248 Changed |= replaceIsInfAndIsNan(M);
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100249 Changed |= replaceIsFinite(M);
David Neto22f144c2017-06-12 14:26:21 -0400250 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000251 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000252 Changed |= replaceRotate(M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000253 Changed |= replaceConvert(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000254 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000255 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000256 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000257 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400258 Changed |= replaceSignbit(M);
259 Changed |= replaceMadandMad24andMul24(M);
260 Changed |= replaceVloadHalf(M);
261 Changed |= replaceVloadHalf2(M);
262 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700263 Changed |= replaceClspvVloadaHalf2(M);
264 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400265 Changed |= replaceVstoreHalf(M);
266 Changed |= replaceVstoreHalf2(M);
267 Changed |= replaceVstoreHalf4(M);
alan-bakerf7e17cb2020-01-02 07:29:59 -0500268 // Replace the half image builtins before handling other image builtins.
269 Changed |= replaceHalfReadImage(M);
270 Changed |= replaceHalfWriteImage(M);
alan-baker931d18a2019-12-12 08:21:32 -0500271 // Replace unsampled reads before converting sampled read coordinates.
272 Changed |= replaceUnsampledReadImage(M);
Kévin Petit06517a12019-12-09 19:40:31 +0000273 Changed |= replaceSampledReadImageWithIntCoords(M);
David Neto22f144c2017-06-12 14:26:21 -0400274 Changed |= replaceAtomics(M);
275 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400276 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700277 Changed |= replaceVload(M);
278 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400279
280 return Changed;
281}
282
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400283bool replaceCallsWithValue(Module &M, std::vector<const char *> Names,
284 std::function<Value *(CallInst *)> Replacer) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000285
Kévin Petite8edce32019-04-10 14:23:32 +0100286 bool Changed = false;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000287
288 for (auto Name : Names) {
289 // If we find a function with the matching name.
290 if (auto F = M.getFunction(Name)) {
291 SmallVector<Instruction *, 4> ToRemoves;
292
293 // Walk the users of the function.
294 for (auto &U : F->uses()) {
295 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000296
Kévin Petite8edce32019-04-10 14:23:32 +0100297 auto NewValue = Replacer(CI);
298
299 if (NewValue != nullptr) {
300 CI->replaceAllUsesWith(NewValue);
301 }
Kévin Petit2444e9b2018-11-09 14:14:37 +0000302
303 // Lastly, remember to remove the user.
304 ToRemoves.push_back(CI);
305 }
306 }
307
308 Changed = !ToRemoves.empty();
309
310 // And cleanup the calls we don't use anymore.
311 for (auto V : ToRemoves) {
312 V->eraseFromParent();
313 }
314
315 // And remove the function we don't need either too.
316 F->eraseFromParent();
317 }
318 }
319
320 return Changed;
321}
322
Kévin Petite8edce32019-04-10 14:23:32 +0100323bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
Kévin Petit91bc72e2019-04-08 15:17:46 +0100324
Kévin Petite8edce32019-04-10 14:23:32 +0100325 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400326 "_Z3absh", "_Z3absDv2_h", "_Z3absDv3_h", "_Z3absDv4_h",
327 "_Z3abst", "_Z3absDv2_t", "_Z3absDv3_t", "_Z3absDv4_t",
328 "_Z3absj", "_Z3absDv2_j", "_Z3absDv3_j", "_Z3absDv4_j",
329 "_Z3absm", "_Z3absDv2_m", "_Z3absDv3_m", "_Z3absDv4_m",
Kévin Petite8edce32019-04-10 14:23:32 +0100330 };
331
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400332 return replaceCallsWithValue(M, Names,
333 [](CallInst *CI) { return CI->getOperand(0); });
Kévin Petite8edce32019-04-10 14:23:32 +0100334}
335
336bool ReplaceOpenCLBuiltinPass::replaceAbsDiff(Module &M) {
337
338 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400339 "_Z8abs_diffcc", "_Z8abs_diffDv2_cS_", "_Z8abs_diffDv3_cS_",
340 "_Z8abs_diffDv4_cS_", "_Z8abs_diffhh", "_Z8abs_diffDv2_hS_",
341 "_Z8abs_diffDv3_hS_", "_Z8abs_diffDv4_hS_", "_Z8abs_diffss",
342 "_Z8abs_diffDv2_sS_", "_Z8abs_diffDv3_sS_", "_Z8abs_diffDv4_sS_",
343 "_Z8abs_difftt", "_Z8abs_diffDv2_tS_", "_Z8abs_diffDv3_tS_",
344 "_Z8abs_diffDv4_tS_", "_Z8abs_diffii", "_Z8abs_diffDv2_iS_",
345 "_Z8abs_diffDv3_iS_", "_Z8abs_diffDv4_iS_", "_Z8abs_diffjj",
346 "_Z8abs_diffDv2_jS_", "_Z8abs_diffDv3_jS_", "_Z8abs_diffDv4_jS_",
347 "_Z8abs_diffll", "_Z8abs_diffDv2_lS_", "_Z8abs_diffDv3_lS_",
348 "_Z8abs_diffDv4_lS_", "_Z8abs_diffmm", "_Z8abs_diffDv2_mS_",
349 "_Z8abs_diffDv3_mS_", "_Z8abs_diffDv4_mS_",
Kévin Petit91bc72e2019-04-08 15:17:46 +0100350 };
351
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400352 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100353 auto XValue = CI->getOperand(0);
354 auto YValue = CI->getOperand(1);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100355
Kévin Petite8edce32019-04-10 14:23:32 +0100356 IRBuilder<> Builder(CI);
357 auto XmY = Builder.CreateSub(XValue, YValue);
358 auto YmX = Builder.CreateSub(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100359
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400360 Value *Cmp;
Kévin Petite8edce32019-04-10 14:23:32 +0100361 auto F = CI->getCalledFunction();
362 auto finfo = FunctionInfo::getFromMangledName(F->getName());
363 if (finfo.isArgSigned(0)) {
364 Cmp = Builder.CreateICmpSGT(YValue, XValue);
365 } else {
366 Cmp = Builder.CreateICmpUGT(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100367 }
Kévin Petit91bc72e2019-04-08 15:17:46 +0100368
Kévin Petite8edce32019-04-10 14:23:32 +0100369 return Builder.CreateSelect(Cmp, YmX, XmY);
370 });
Kévin Petit91bc72e2019-04-08 15:17:46 +0100371}
372
Kévin Petit8c1be282019-04-02 19:34:25 +0100373bool ReplaceOpenCLBuiltinPass::replaceCopysign(Module &M) {
Kévin Petit8c1be282019-04-02 19:34:25 +0100374
Kévin Petite8edce32019-04-10 14:23:32 +0100375 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400376 "_Z8copysignff",
377 "_Z8copysignDv2_fS_",
378 "_Z8copysignDv3_fS_",
379 "_Z8copysignDv4_fS_",
Kévin Petit8c1be282019-04-02 19:34:25 +0100380 };
381
Kévin Petite8edce32019-04-10 14:23:32 +0100382 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
383 auto XValue = CI->getOperand(0);
384 auto YValue = CI->getOperand(1);
Kévin Petit8c1be282019-04-02 19:34:25 +0100385
Kévin Petite8edce32019-04-10 14:23:32 +0100386 auto Ty = XValue->getType();
Kévin Petit8c1be282019-04-02 19:34:25 +0100387
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400388 Type *IntTy = Type::getIntNTy(M.getContext(), Ty->getScalarSizeInBits());
Kévin Petite8edce32019-04-10 14:23:32 +0100389 if (Ty->isVectorTy()) {
390 IntTy = VectorType::get(IntTy, Ty->getVectorNumElements());
Kévin Petit8c1be282019-04-02 19:34:25 +0100391 }
Kévin Petit8c1be282019-04-02 19:34:25 +0100392
Kévin Petite8edce32019-04-10 14:23:32 +0100393 // Return X with the sign of Y
394
395 // Sign bit masks
396 auto SignBit = IntTy->getScalarSizeInBits() - 1;
397 auto SignBitMask = 1 << SignBit;
398 auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
399 auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);
400
401 IRBuilder<> Builder(CI);
402
403 // Extract sign of Y
404 auto YInt = Builder.CreateBitCast(YValue, IntTy);
405 auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);
406
407 // Clear sign bit in X
408 auto XInt = Builder.CreateBitCast(XValue, IntTy);
409 XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);
410
411 // Insert sign bit of Y into X
412 auto NewXInt = Builder.CreateOr(XInt, YSign);
413
414 // And cast back to floating-point
415 return Builder.CreateBitCast(NewXInt, Ty);
416 });
Kévin Petit8c1be282019-04-02 19:34:25 +0100417}
418
David Neto22f144c2017-06-12 14:26:21 -0400419bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400420
Kévin Petite8edce32019-04-10 14:23:32 +0100421 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400422 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
423 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
424 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
425 };
426
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400427 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100428 // Recip has one arg.
429 auto Arg = CI->getOperand(0);
430 auto Cst1 = ConstantFP::get(Arg->getType(), 1.0);
431 return BinaryOperator::Create(Instruction::FDiv, Cst1, Arg, "", CI);
432 });
David Neto22f144c2017-06-12 14:26:21 -0400433}
434
435bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400436
Kévin Petite8edce32019-04-10 14:23:32 +0100437 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400438 "_Z11half_divideff", "_Z13native_divideff",
439 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
440 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
441 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
442 };
443
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400444 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100445 auto Op0 = CI->getOperand(0);
446 auto Op1 = CI->getOperand(1);
447 return BinaryOperator::Create(Instruction::FDiv, Op0, Op1, "", CI);
448 });
David Neto22f144c2017-06-12 14:26:21 -0400449}
450
Kévin Petit1329a002019-06-15 05:54:05 +0100451bool ReplaceOpenCLBuiltinPass::replaceDot(Module &M) {
452
453 std::vector<const char *> Names = {
454 "_Z3dotff",
455 "_Z3dotDv2_fS_",
456 "_Z3dotDv3_fS_",
457 "_Z3dotDv4_fS_",
458 };
459
460 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
461 auto Op0 = CI->getOperand(0);
462 auto Op1 = CI->getOperand(1);
463
464 Value *V;
465 if (Op0->getType()->isVectorTy()) {
466 V = clspv::InsertSPIRVOp(CI, spv::OpDot, {Attribute::ReadNone},
467 CI->getType(), {Op0, Op1});
468 } else {
469 V = BinaryOperator::Create(Instruction::FMul, Op0, Op1, "", CI);
470 }
471
472 return V;
473 });
474}
475
David Neto22f144c2017-06-12 14:26:21 -0400476bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
477 bool Changed = false;
478
479 const std::map<const char *, const char *> Map = {
480 {"_Z5exp10f", "_Z3expf"},
481 {"_Z10half_exp10f", "_Z8half_expf"},
482 {"_Z12native_exp10f", "_Z10native_expf"},
483 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
484 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
485 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
486 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
487 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
488 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
489 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
490 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
491 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
492
493 for (auto Pair : Map) {
494 // If we find a function with the matching name.
495 if (auto F = M.getFunction(Pair.first)) {
496 SmallVector<Instruction *, 4> ToRemoves;
497
498 // Walk the users of the function.
499 for (auto &U : F->uses()) {
500 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
501 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
502
503 auto Arg = CI->getOperand(0);
504
505 // Constant of the natural log of 10 (ln(10)).
506 const double Ln10 =
507 2.302585092994045684017991454684364207601101488628772976033;
508
509 auto Mul = BinaryOperator::Create(
510 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
511 CI);
512
513 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
514
515 CI->replaceAllUsesWith(NewCI);
516
517 // Lastly, remember to remove the user.
518 ToRemoves.push_back(CI);
519 }
520 }
521
522 Changed = !ToRemoves.empty();
523
524 // And cleanup the calls we don't use anymore.
525 for (auto V : ToRemoves) {
526 V->eraseFromParent();
527 }
528
529 // And remove the function we don't need either too.
530 F->eraseFromParent();
531 }
532 }
533
534 return Changed;
535}
536
Kévin Petit0644a9c2019-06-20 21:08:46 +0100537bool ReplaceOpenCLBuiltinPass::replaceFmod(Module &M) {
538
539 std::vector<const char *> Names = {
540 "_Z4fmodff",
541 "_Z4fmodDv2_fS_",
542 "_Z4fmodDv3_fS_",
543 "_Z4fmodDv4_fS_",
544 };
545
546 // OpenCL fmod(x,y) is x - y * trunc(x/y)
547 // The sign for a non-zero result is taken from x.
548 // (Try an example.)
549 // So translate to FRem
550 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
551 auto Op0 = CI->getOperand(0);
552 auto Op1 = CI->getOperand(1);
553 return BinaryOperator::Create(Instruction::FRem, Op0, Op1, "", CI);
554 });
555}
556
David Neto22f144c2017-06-12 14:26:21 -0400557bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
558 bool Changed = false;
559
560 const std::map<const char *, const char *> Map = {
561 {"_Z5log10f", "_Z3logf"},
562 {"_Z10half_log10f", "_Z8half_logf"},
563 {"_Z12native_log10f", "_Z10native_logf"},
564 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
565 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
566 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
567 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
568 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
569 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
570 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
571 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
572 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
573
574 for (auto Pair : Map) {
575 // If we find a function with the matching name.
576 if (auto F = M.getFunction(Pair.first)) {
577 SmallVector<Instruction *, 4> ToRemoves;
578
579 // Walk the users of the function.
580 for (auto &U : F->uses()) {
581 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
582 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
583
584 auto Arg = CI->getOperand(0);
585
586 // Constant of the reciprocal of the natural log of 10 (ln(10)).
587 const double Ln10 =
588 0.434294481903251827651128918916605082294397005803666566114;
589
590 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
591
592 auto Mul = BinaryOperator::Create(
593 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
594 "", CI);
595
596 CI->replaceAllUsesWith(Mul);
597
598 // Lastly, remember to remove the user.
599 ToRemoves.push_back(CI);
600 }
601 }
602
603 Changed = !ToRemoves.empty();
604
605 // And cleanup the calls we don't use anymore.
606 for (auto V : ToRemoves) {
607 V->eraseFromParent();
608 }
609
610 // And remove the function we don't need either too.
611 F->eraseFromParent();
612 }
613 }
614
615 return Changed;
616}
617
618bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400619
620 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
621
alan-bakerb60b1fc2019-12-13 19:09:38 -0500622 const std::vector<const char *> Names = {"_Z7barrierj",
623 // OpenCL 2.0 alias for barrier.
624 "_Z18work_group_barrierj"};
David Neto22f144c2017-06-12 14:26:21 -0400625
Kévin Petitc4643922019-06-17 19:32:05 +0100626 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
627 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400628
Kévin Petitc4643922019-06-17 19:32:05 +0100629 // We need to map the OpenCL constants to the SPIR-V equivalents.
630 const auto LocalMemFence =
631 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
632 const auto GlobalMemFence =
633 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
634 const auto ConstantSequentiallyConsistent = ConstantInt::get(
635 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
636 const auto ConstantScopeDevice =
637 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
638 const auto ConstantScopeWorkgroup =
639 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
David Neto22f144c2017-06-12 14:26:21 -0400640
Kévin Petitc4643922019-06-17 19:32:05 +0100641 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
642 const auto LocalMemFenceMask =
643 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
644 const auto WorkgroupShiftAmount =
645 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
646 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
647 Instruction::Shl, LocalMemFenceMask,
648 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400649
Kévin Petitc4643922019-06-17 19:32:05 +0100650 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
651 const auto GlobalMemFenceMask =
652 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
653 const auto UniformShiftAmount =
654 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
655 const auto MemorySemanticsUniform = BinaryOperator::Create(
656 Instruction::Shl, GlobalMemFenceMask,
657 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400658
Kévin Petitc4643922019-06-17 19:32:05 +0100659 // And combine the above together, also adding in
660 // MemorySemanticsSequentiallyConsistentMask.
661 auto MemorySemantics =
662 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
663 ConstantSequentiallyConsistent, "", CI);
664 MemorySemantics = BinaryOperator::Create(Instruction::Or, MemorySemantics,
665 MemorySemanticsUniform, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400666
Kévin Petitc4643922019-06-17 19:32:05 +0100667 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
668 // Device Scope, otherwise Workgroup Scope.
669 const auto Cmp =
670 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, GlobalMemFenceMask,
671 GlobalMemFence, "", CI);
672 const auto MemoryScope = SelectInst::Create(Cmp, ConstantScopeDevice,
673 ConstantScopeWorkgroup, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400674
Kévin Petitc4643922019-06-17 19:32:05 +0100675 // Lastly, the Execution Scope is always Workgroup Scope.
676 const auto ExecutionScope = ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400677
Kévin Petitc4643922019-06-17 19:32:05 +0100678 return clspv::InsertSPIRVOp(CI, spv::OpControlBarrier,
679 {Attribute::NoDuplicate}, CI->getType(),
680 {ExecutionScope, MemoryScope, MemorySemantics});
681 });
David Neto22f144c2017-06-12 14:26:21 -0400682}
683
684bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
685 bool Changed = false;
686
687 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
688
Kévin Petitc4643922019-06-17 19:32:05 +0100689 using Tuple = std::tuple<spv::Op, unsigned>;
Neil Henning39672102017-09-29 14:33:13 +0100690 const std::map<const char *, Tuple> Map = {
Kévin Petitc4643922019-06-17 19:32:05 +0100691 {"_Z9mem_fencej", Tuple(spv::OpMemoryBarrier,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400692 spv::MemorySemanticsSequentiallyConsistentMask)},
Neil Henning39672102017-09-29 14:33:13 +0100693 {"_Z14read_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100694 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsAcquireMask)},
Neil Henning39672102017-09-29 14:33:13 +0100695 {"_Z15write_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100696 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400697
698 for (auto Pair : Map) {
699 // If we find a function with the matching name.
700 if (auto F = M.getFunction(Pair.first)) {
701 SmallVector<Instruction *, 4> ToRemoves;
702
703 // Walk the users of the function.
704 for (auto &U : F->uses()) {
705 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -0400706
707 auto Arg = CI->getOperand(0);
708
709 // We need to map the OpenCL constants to the SPIR-V equivalents.
710 const auto LocalMemFence =
711 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
712 const auto GlobalMemFence =
713 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
714 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100715 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400716 const auto ConstantScopeDevice =
717 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
718
719 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
720 const auto LocalMemFenceMask = BinaryOperator::Create(
721 Instruction::And, LocalMemFence, Arg, "", CI);
722 const auto WorkgroupShiftAmount =
723 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
724 clz(CLK_LOCAL_MEM_FENCE);
725 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
726 Instruction::Shl, LocalMemFenceMask,
727 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
728
729 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
730 const auto GlobalMemFenceMask = BinaryOperator::Create(
731 Instruction::And, GlobalMemFence, Arg, "", CI);
732 const auto UniformShiftAmount =
733 clz(spv::MemorySemanticsUniformMemoryMask) -
734 clz(CLK_GLOBAL_MEM_FENCE);
735 const auto MemorySemanticsUniform = BinaryOperator::Create(
736 Instruction::Shl, GlobalMemFenceMask,
737 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
738
739 // And combine the above together, also adding in
740 // MemorySemanticsSequentiallyConsistentMask.
741 auto MemorySemantics =
742 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
743 ConstantMemorySemantics, "", CI);
744 MemorySemantics = BinaryOperator::Create(
745 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
746
747 // Memory Scope is always device.
748 const auto MemoryScope = ConstantScopeDevice;
749
Kévin Petitc4643922019-06-17 19:32:05 +0100750 const auto SPIRVOp = std::get<0>(Pair.second);
751 auto NewCI = clspv::InsertSPIRVOp(CI, SPIRVOp, {}, CI->getType(),
752 {MemoryScope, MemorySemantics});
David Neto22f144c2017-06-12 14:26:21 -0400753
754 CI->replaceAllUsesWith(NewCI);
755
756 // Lastly, remember to remove the user.
757 ToRemoves.push_back(CI);
758 }
759 }
760
761 Changed = !ToRemoves.empty();
762
763 // And cleanup the calls we don't use anymore.
764 for (auto V : ToRemoves) {
765 V->eraseFromParent();
766 }
767
768 // And remove the function we don't need either too.
769 F->eraseFromParent();
770 }
771 }
772
773 return Changed;
774}
775
776bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
777 bool Changed = false;
778
779 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
780 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
781 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
782 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
783 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
784 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
785 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
786 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
787 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
788 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
789 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
790 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
791 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
792 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
793 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
794 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
795 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
796 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
797 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
798 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
799 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
800 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
801 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
802 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
803 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
804 };
805
806 for (auto Pair : Map) {
807 // If we find a function with the matching name.
808 if (auto F = M.getFunction(Pair.first)) {
809 SmallVector<Instruction *, 4> ToRemoves;
810
811 // Walk the users of the function.
812 for (auto &U : F->uses()) {
813 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
814 // The predicate to use in the CmpInst.
815 auto Predicate = Pair.second.first;
816
817 // The value to return for true.
818 auto TrueValue =
819 ConstantInt::getSigned(CI->getType(), Pair.second.second);
820
821 // The value to return for false.
822 auto FalseValue = Constant::getNullValue(CI->getType());
823
824 auto Arg1 = CI->getOperand(0);
825 auto Arg2 = CI->getOperand(1);
826
827 const auto Cmp =
828 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
829
830 const auto Select =
831 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
832
833 CI->replaceAllUsesWith(Select);
834
835 // Lastly, remember to remove the user.
836 ToRemoves.push_back(CI);
837 }
838 }
839
840 Changed = !ToRemoves.empty();
841
842 // And cleanup the calls we don't use anymore.
843 for (auto V : ToRemoves) {
844 V->eraseFromParent();
845 }
846
847 // And remove the function we don't need either too.
848 F->eraseFromParent();
849 }
850 }
851
852 return Changed;
853}
854
855bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
856 bool Changed = false;
857
Kévin Petitff03aee2019-06-12 19:39:03 +0100858 const std::map<const char *, std::pair<spv::Op, int32_t>> Map = {
859 {"_Z5isinff", {spv::OpIsInf, 1}},
860 {"_Z5isinfDv2_f", {spv::OpIsInf, -1}},
861 {"_Z5isinfDv3_f", {spv::OpIsInf, -1}},
862 {"_Z5isinfDv4_f", {spv::OpIsInf, -1}},
863 {"_Z5isnanf", {spv::OpIsNan, 1}},
864 {"_Z5isnanDv2_f", {spv::OpIsNan, -1}},
865 {"_Z5isnanDv3_f", {spv::OpIsNan, -1}},
866 {"_Z5isnanDv4_f", {spv::OpIsNan, -1}},
David Neto22f144c2017-06-12 14:26:21 -0400867 };
868
869 for (auto Pair : Map) {
870 // If we find a function with the matching name.
871 if (auto F = M.getFunction(Pair.first)) {
872 SmallVector<Instruction *, 4> ToRemoves;
873
874 // Walk the users of the function.
875 for (auto &U : F->uses()) {
876 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
877 const auto CITy = CI->getType();
878
Kévin Petitff03aee2019-06-12 19:39:03 +0100879 auto SPIRVOp = Pair.second.first;
David Neto22f144c2017-06-12 14:26:21 -0400880
881 // The value to return for true.
882 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
883
884 // The value to return for false.
885 auto FalseValue = Constant::getNullValue(CITy);
886
887 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
888 M.getContext(),
889 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
890
Kévin Petitff03aee2019-06-12 19:39:03 +0100891 auto NewCI =
892 clspv::InsertSPIRVOp(CI, SPIRVOp, {Attribute::ReadNone},
893 CorrespondingBoolTy, {CI->getOperand(0)});
David Neto22f144c2017-06-12 14:26:21 -0400894
895 const auto Select =
896 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
897
898 CI->replaceAllUsesWith(Select);
899
900 // Lastly, remember to remove the user.
901 ToRemoves.push_back(CI);
902 }
903 }
904
905 Changed = !ToRemoves.empty();
906
907 // And cleanup the calls we don't use anymore.
908 for (auto V : ToRemoves) {
909 V->eraseFromParent();
910 }
911
912 // And remove the function we don't need either too.
913 F->eraseFromParent();
914 }
915 }
916
917 return Changed;
918}
919
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100920bool ReplaceOpenCLBuiltinPass::replaceIsFinite(Module &M) {
921 std::vector<const char *> Names = {
922 "_Z8isfiniteh", "_Z8isfiniteDv2_h", "_Z8isfiniteDv3_h",
923 "_Z8isfiniteDv4_h", "_Z8isfinitef", "_Z8isfiniteDv2_f",
924 "_Z8isfiniteDv3_f", "_Z8isfiniteDv4_f", "_Z8isfinited",
925 "_Z8isfiniteDv2_d", "_Z8isfiniteDv3_d", "_Z8isfiniteDv4_d",
926 };
927
928 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
929 auto &C = M.getContext();
930 auto Val = CI->getOperand(0);
931 auto ValTy = Val->getType();
932 auto RetTy = CI->getType();
933
934 // Get a suitable integer type to represent the number
935 auto IntTy = getIntOrIntVectorTyForCast(C, ValTy);
936
937 // Create Mask
938 auto ScalarSize = ValTy->getScalarSizeInBits();
939 Value *InfMask;
940 switch (ScalarSize) {
941 case 16:
942 InfMask = ConstantInt::get(IntTy, 0x7C00U);
943 break;
944 case 32:
945 InfMask = ConstantInt::get(IntTy, 0x7F800000U);
946 break;
947 case 64:
948 InfMask = ConstantInt::get(IntTy, 0x7FF0000000000000ULL);
949 break;
950 default:
951 llvm_unreachable("Unsupported floating-point type");
952 }
953
954 IRBuilder<> Builder(CI);
955
956 // Bitcast to int
957 auto ValInt = Builder.CreateBitCast(Val, IntTy);
958
959 // Mask and compare
960 auto InfBits = Builder.CreateAnd(InfMask, ValInt);
961 auto Cmp = Builder.CreateICmp(CmpInst::ICMP_EQ, InfBits, InfMask);
962
963 auto RetFalse = ConstantInt::get(RetTy, 0);
964 Value *RetTrue;
965 if (ValTy->isVectorTy()) {
966 RetTrue = ConstantInt::getSigned(RetTy, -1);
967 } else {
968 RetTrue = ConstantInt::get(RetTy, 1);
969 }
970 return Builder.CreateSelect(Cmp, RetFalse, RetTrue);
971 });
972}
973
David Neto22f144c2017-06-12 14:26:21 -0400974bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
975 bool Changed = false;
976
Kévin Petitff03aee2019-06-12 19:39:03 +0100977 const std::map<const char *, spv::Op> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000978 // all
Kévin Petitff03aee2019-06-12 19:39:03 +0100979 {"_Z3allc", spv::OpNop},
980 {"_Z3allDv2_c", spv::OpAll},
981 {"_Z3allDv3_c", spv::OpAll},
982 {"_Z3allDv4_c", spv::OpAll},
983 {"_Z3alls", spv::OpNop},
984 {"_Z3allDv2_s", spv::OpAll},
985 {"_Z3allDv3_s", spv::OpAll},
986 {"_Z3allDv4_s", spv::OpAll},
987 {"_Z3alli", spv::OpNop},
988 {"_Z3allDv2_i", spv::OpAll},
989 {"_Z3allDv3_i", spv::OpAll},
990 {"_Z3allDv4_i", spv::OpAll},
991 {"_Z3alll", spv::OpNop},
992 {"_Z3allDv2_l", spv::OpAll},
993 {"_Z3allDv3_l", spv::OpAll},
994 {"_Z3allDv4_l", spv::OpAll},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000995
996 // any
Kévin Petitff03aee2019-06-12 19:39:03 +0100997 {"_Z3anyc", spv::OpNop},
998 {"_Z3anyDv2_c", spv::OpAny},
999 {"_Z3anyDv3_c", spv::OpAny},
1000 {"_Z3anyDv4_c", spv::OpAny},
1001 {"_Z3anys", spv::OpNop},
1002 {"_Z3anyDv2_s", spv::OpAny},
1003 {"_Z3anyDv3_s", spv::OpAny},
1004 {"_Z3anyDv4_s", spv::OpAny},
1005 {"_Z3anyi", spv::OpNop},
1006 {"_Z3anyDv2_i", spv::OpAny},
1007 {"_Z3anyDv3_i", spv::OpAny},
1008 {"_Z3anyDv4_i", spv::OpAny},
1009 {"_Z3anyl", spv::OpNop},
1010 {"_Z3anyDv2_l", spv::OpAny},
1011 {"_Z3anyDv3_l", spv::OpAny},
1012 {"_Z3anyDv4_l", spv::OpAny},
David Neto22f144c2017-06-12 14:26:21 -04001013 };
1014
1015 for (auto Pair : Map) {
1016 // If we find a function with the matching name.
1017 if (auto F = M.getFunction(Pair.first)) {
1018 SmallVector<Instruction *, 4> ToRemoves;
1019
1020 // Walk the users of the function.
1021 for (auto &U : F->uses()) {
1022 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -04001023
1024 auto Arg = CI->getOperand(0);
1025
1026 Value *V;
1027
Kévin Petitfd27cca2018-10-31 13:00:17 +00001028 // If the argument is a 32-bit int, just use a shift
1029 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
1030 V = BinaryOperator::Create(Instruction::LShr, Arg,
1031 ConstantInt::get(Arg->getType(), 31), "",
1032 CI);
1033 } else {
David Neto22f144c2017-06-12 14:26:21 -04001034 // The value for zero to compare against.
1035 const auto ZeroValue = Constant::getNullValue(Arg->getType());
1036
David Neto22f144c2017-06-12 14:26:21 -04001037 // The value to return for true.
1038 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
1039
1040 // The value to return for false.
1041 const auto FalseValue = Constant::getNullValue(CI->getType());
1042
Kévin Petitfd27cca2018-10-31 13:00:17 +00001043 const auto Cmp = CmpInst::Create(
1044 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
1045
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001046 Value *SelectSource;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001047
1048 // If we have a function to call, call it!
Kévin Petitff03aee2019-06-12 19:39:03 +01001049 const auto SPIRVOp = Pair.second;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001050
Kévin Petitff03aee2019-06-12 19:39:03 +01001051 if (SPIRVOp != spv::OpNop) {
Kévin Petitfd27cca2018-10-31 13:00:17 +00001052
Kévin Petitff03aee2019-06-12 19:39:03 +01001053 const auto BoolTy = Type::getInt1Ty(M.getContext());
Kévin Petitfd27cca2018-10-31 13:00:17 +00001054
Kévin Petitff03aee2019-06-12 19:39:03 +01001055 const auto NewCI = clspv::InsertSPIRVOp(
1056 CI, SPIRVOp, {Attribute::ReadNone}, BoolTy, {Cmp});
Kévin Petitfd27cca2018-10-31 13:00:17 +00001057 SelectSource = NewCI;
1058
1059 } else {
1060 SelectSource = Cmp;
1061 }
1062
1063 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001064 }
1065
1066 CI->replaceAllUsesWith(V);
1067
1068 // Lastly, remember to remove the user.
1069 ToRemoves.push_back(CI);
1070 }
1071 }
1072
1073 Changed = !ToRemoves.empty();
1074
1075 // And cleanup the calls we don't use anymore.
1076 for (auto V : ToRemoves) {
1077 V->eraseFromParent();
1078 }
1079
1080 // And remove the function we don't need either too.
1081 F->eraseFromParent();
1082 }
1083 }
1084
1085 return Changed;
1086}
1087
Kévin Petitbf0036c2019-03-06 13:57:10 +00001088bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
1089 bool Changed = false;
1090
1091 for (auto const &SymVal : M.getValueSymbolTable()) {
1092 // Skip symbols whose name doesn't match
1093 if (!SymVal.getKey().startswith("_Z8upsample")) {
1094 continue;
1095 }
1096 // Is there a function going by that name?
1097 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1098
1099 SmallVector<Instruction *, 4> ToRemoves;
1100
1101 // Walk the users of the function.
1102 for (auto &U : F->uses()) {
1103 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1104
1105 // Get arguments
1106 auto HiValue = CI->getOperand(0);
1107 auto LoValue = CI->getOperand(1);
1108
1109 // Don't touch overloads that aren't in OpenCL C
1110 auto HiType = HiValue->getType();
1111 auto LoType = LoValue->getType();
1112
1113 if (HiType != LoType) {
1114 continue;
1115 }
1116
1117 if (!HiType->isIntOrIntVectorTy()) {
1118 continue;
1119 }
1120
1121 if (HiType->getScalarSizeInBits() * 2 !=
1122 CI->getType()->getScalarSizeInBits()) {
1123 continue;
1124 }
1125
1126 if ((HiType->getScalarSizeInBits() != 8) &&
1127 (HiType->getScalarSizeInBits() != 16) &&
1128 (HiType->getScalarSizeInBits() != 32)) {
1129 continue;
1130 }
1131
1132 if (HiType->isVectorTy()) {
1133 if ((HiType->getVectorNumElements() != 2) &&
1134 (HiType->getVectorNumElements() != 3) &&
1135 (HiType->getVectorNumElements() != 4) &&
1136 (HiType->getVectorNumElements() != 8) &&
1137 (HiType->getVectorNumElements() != 16)) {
1138 continue;
1139 }
1140 }
1141
1142 // Convert both operands to the result type
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001143 auto HiCast =
1144 CastInst::CreateZExtOrBitCast(HiValue, CI->getType(), "", CI);
1145 auto LoCast =
1146 CastInst::CreateZExtOrBitCast(LoValue, CI->getType(), "", CI);
Kévin Petitbf0036c2019-03-06 13:57:10 +00001147
1148 // Shift high operand
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001149 auto ShiftAmount =
1150 ConstantInt::get(CI->getType(), HiType->getScalarSizeInBits());
Kévin Petitbf0036c2019-03-06 13:57:10 +00001151 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1152 ShiftAmount, "", CI);
1153
1154 // OR both results
1155 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1156 "", CI);
1157
1158 // Replace call with the expression
1159 CI->replaceAllUsesWith(V);
1160
1161 // Lastly, remember to remove the user.
1162 ToRemoves.push_back(CI);
1163 }
1164 }
1165
1166 Changed = !ToRemoves.empty();
1167
1168 // And cleanup the calls we don't use anymore.
1169 for (auto V : ToRemoves) {
1170 V->eraseFromParent();
1171 }
1172
1173 // And remove the function we don't need either too.
1174 F->eraseFromParent();
1175 }
1176 }
1177
1178 return Changed;
1179}
1180
Kévin Petitd44eef52019-03-08 13:22:14 +00001181bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1182 bool Changed = false;
1183
1184 for (auto const &SymVal : M.getValueSymbolTable()) {
1185 // Skip symbols whose name doesn't match
1186 if (!SymVal.getKey().startswith("_Z6rotate")) {
1187 continue;
1188 }
1189 // Is there a function going by that name?
1190 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1191
1192 SmallVector<Instruction *, 4> ToRemoves;
1193
1194 // Walk the users of the function.
1195 for (auto &U : F->uses()) {
1196 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1197
1198 // Get arguments
1199 auto SrcValue = CI->getOperand(0);
1200 auto RotAmount = CI->getOperand(1);
1201
1202 // Don't touch overloads that aren't in OpenCL C
1203 auto SrcType = SrcValue->getType();
1204 auto RotType = RotAmount->getType();
1205
1206 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1207 continue;
1208 }
1209
1210 if (!SrcType->isIntOrIntVectorTy()) {
1211 continue;
1212 }
1213
1214 if ((SrcType->getScalarSizeInBits() != 8) &&
1215 (SrcType->getScalarSizeInBits() != 16) &&
1216 (SrcType->getScalarSizeInBits() != 32) &&
1217 (SrcType->getScalarSizeInBits() != 64)) {
1218 continue;
1219 }
1220
1221 if (SrcType->isVectorTy()) {
1222 if ((SrcType->getVectorNumElements() != 2) &&
1223 (SrcType->getVectorNumElements() != 3) &&
1224 (SrcType->getVectorNumElements() != 4) &&
1225 (SrcType->getVectorNumElements() != 8) &&
1226 (SrcType->getVectorNumElements() != 16)) {
1227 continue;
1228 }
1229 }
1230
1231 // The approach used is to shift the top bits down, the bottom bits up
1232 // and OR the two shifted values.
1233
1234 // The rotation amount is to be treated modulo the element size.
1235 // Since SPIR-V shift ops don't support this, let's apply the
1236 // modulo ahead of shifting. The element size is always a power of
1237 // two so we can just AND with a mask.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001238 auto ModMask =
1239 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits() - 1);
Kévin Petitd44eef52019-03-08 13:22:14 +00001240 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1241 ModMask, "", CI);
1242
1243 // Let's calc the amount by which to shift top bits down
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001244 auto ScalarSize =
1245 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits());
Kévin Petitd44eef52019-03-08 13:22:14 +00001246 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1247 RotAmount, "", CI);
1248
1249 // Now shift the bottom bits up and the top bits down
1250 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1251 RotAmount, "", CI);
1252 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1253 DownAmount, "", CI);
1254
1255 // Finally OR the two shifted values
1256 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1257 HiRotated, "", CI);
1258
1259 // Replace call with the expression
1260 CI->replaceAllUsesWith(V);
1261
1262 // Lastly, remember to remove the user.
1263 ToRemoves.push_back(CI);
1264 }
1265 }
1266
1267 Changed = !ToRemoves.empty();
1268
1269 // And cleanup the calls we don't use anymore.
1270 for (auto V : ToRemoves) {
1271 V->eraseFromParent();
1272 }
1273
1274 // And remove the function we don't need either too.
1275 F->eraseFromParent();
1276 }
1277 }
1278
1279 return Changed;
1280}
1281
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001282bool ReplaceOpenCLBuiltinPass::replaceConvert(Module &M) {
1283 bool Changed = false;
1284
1285 for (auto const &SymVal : M.getValueSymbolTable()) {
1286
1287 // Skip symbols whose name obviously doesn't match
1288 if (!SymVal.getKey().contains("convert_")) {
1289 continue;
1290 }
1291
1292 // Is there a function going by that name?
1293 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1294
1295 // Get info from the mangled name
1296 FunctionInfo finfo;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001297 bool parsed = FunctionInfo::getFromMangledNameCheck(F->getName(), &finfo);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001298
1299 // All functions of interest are handled by our mangled name parser
1300 if (!parsed) {
1301 continue;
1302 }
1303
1304 // Move on if this isn't a call to convert_
1305 if (!finfo.name.startswith("convert_")) {
1306 continue;
1307 }
1308
1309 // Extract the destination type from the function name
1310 StringRef DstTypeName = finfo.name;
1311 DstTypeName.consume_front("convert_");
1312
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001313 auto DstSignedNess =
1314 StringSwitch<ArgTypeInfo::SignedNess>(DstTypeName)
1315 .StartsWith("char", ArgTypeInfo::SignedNess::Signed)
1316 .StartsWith("short", ArgTypeInfo::SignedNess::Signed)
1317 .StartsWith("int", ArgTypeInfo::SignedNess::Signed)
1318 .StartsWith("long", ArgTypeInfo::SignedNess::Signed)
1319 .StartsWith("uchar", ArgTypeInfo::SignedNess::Unsigned)
1320 .StartsWith("ushort", ArgTypeInfo::SignedNess::Unsigned)
1321 .StartsWith("uint", ArgTypeInfo::SignedNess::Unsigned)
1322 .StartsWith("ulong", ArgTypeInfo::SignedNess::Unsigned)
1323 .Default(ArgTypeInfo::SignedNess::None);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001324
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001325 bool DstIsSigned = DstSignedNess == ArgTypeInfo::SignedNess::Signed;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001326 bool SrcIsSigned = finfo.isArgSigned(0);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001327
1328 SmallVector<Instruction *, 4> ToRemoves;
1329
1330 // Walk the users of the function.
1331 for (auto &U : F->uses()) {
1332 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1333
1334 // Get arguments
1335 auto SrcValue = CI->getOperand(0);
1336
1337 // Don't touch overloads that aren't in OpenCL C
1338 auto SrcType = SrcValue->getType();
1339 auto DstType = CI->getType();
1340
1341 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1342 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1343 continue;
1344 }
1345
1346 if (SrcType->isVectorTy()) {
1347
1348 if (SrcType->getVectorNumElements() !=
1349 DstType->getVectorNumElements()) {
1350 continue;
1351 }
1352
1353 if ((SrcType->getVectorNumElements() != 2) &&
1354 (SrcType->getVectorNumElements() != 3) &&
1355 (SrcType->getVectorNumElements() != 4) &&
1356 (SrcType->getVectorNumElements() != 8) &&
1357 (SrcType->getVectorNumElements() != 16)) {
1358 continue;
1359 }
1360 }
1361
1362 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1363 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1364
1365 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1366 bool DstIsInt = DstType->isIntOrIntVectorTy();
1367
1368 Value *V;
1369 if (SrcIsFloat && DstIsFloat) {
1370 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1371 } else if (SrcIsFloat && DstIsInt) {
1372 if (DstIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001373 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "",
1374 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001375 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001376 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "",
1377 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001378 }
1379 } else if (SrcIsInt && DstIsFloat) {
1380 if (SrcIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001381 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "",
1382 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001383 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001384 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "",
1385 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001386 }
1387 } else if (SrcIsInt && DstIsInt) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001388 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "",
1389 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001390 } else {
1391 // Not something we're supposed to handle, just move on
1392 continue;
1393 }
1394
1395 // Replace call with the expression
1396 CI->replaceAllUsesWith(V);
1397
1398 // Lastly, remember to remove the user.
1399 ToRemoves.push_back(CI);
1400 }
1401 }
1402
1403 Changed = !ToRemoves.empty();
1404
1405 // And cleanup the calls we don't use anymore.
1406 for (auto V : ToRemoves) {
1407 V->eraseFromParent();
1408 }
1409
1410 // And remove the function we don't need either too.
1411 F->eraseFromParent();
1412 }
1413 }
1414
1415 return Changed;
1416}
1417
Kévin Petit8a560882019-03-21 15:24:34 +00001418bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1419 bool Changed = false;
1420
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001421 SmallVector<Function *, 4> FnWorklist;
Kévin Petit8a560882019-03-21 15:24:34 +00001422
Kévin Petit617a76d2019-04-04 13:54:16 +01001423 for (auto const &SymVal : M.getValueSymbolTable()) {
Kévin Petit8a560882019-03-21 15:24:34 +00001424 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1425 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1426
1427 // Skip symbols whose name doesn't match
1428 if (!isMad && !isMul) {
1429 continue;
1430 }
1431
1432 // Is there a function going by that name?
1433 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001434 FnWorklist.push_back(F);
Kévin Petit8a560882019-03-21 15:24:34 +00001435 }
1436 }
1437
Kévin Petit617a76d2019-04-04 13:54:16 +01001438 for (auto F : FnWorklist) {
1439 SmallVector<Instruction *, 4> ToRemoves;
1440
1441 bool isMad = F->getName().startswith("_Z6mad_hi");
1442 // Walk the users of the function.
1443 for (auto &U : F->uses()) {
1444 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1445
1446 // Get arguments
1447 auto AValue = CI->getOperand(0);
1448 auto BValue = CI->getOperand(1);
1449 auto CValue = CI->getOperand(2);
1450
1451 // Don't touch overloads that aren't in OpenCL C
1452 auto AType = AValue->getType();
1453 auto BType = BValue->getType();
1454 auto CType = CValue->getType();
1455
1456 if ((AType != BType) || (CI->getType() != AType) ||
1457 (isMad && (AType != CType))) {
1458 continue;
1459 }
1460
1461 if (!AType->isIntOrIntVectorTy()) {
1462 continue;
1463 }
1464
1465 if ((AType->getScalarSizeInBits() != 8) &&
1466 (AType->getScalarSizeInBits() != 16) &&
1467 (AType->getScalarSizeInBits() != 32) &&
1468 (AType->getScalarSizeInBits() != 64)) {
1469 continue;
1470 }
1471
1472 if (AType->isVectorTy()) {
1473 if ((AType->getVectorNumElements() != 2) &&
1474 (AType->getVectorNumElements() != 3) &&
1475 (AType->getVectorNumElements() != 4) &&
1476 (AType->getVectorNumElements() != 8) &&
1477 (AType->getVectorNumElements() != 16)) {
1478 continue;
1479 }
1480 }
1481
1482 // Get infos from the mangled OpenCL built-in function name
Kévin Petit91bc72e2019-04-08 15:17:46 +01001483 auto finfo = FunctionInfo::getFromMangledName(F->getName());
Kévin Petit617a76d2019-04-04 13:54:16 +01001484
1485 // Select the appropriate signed/unsigned SPIR-V op
1486 spv::Op opcode;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001487 if (finfo.isArgSigned(0)) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001488 opcode = spv::OpSMulExtended;
1489 } else {
1490 opcode = spv::OpUMulExtended;
1491 }
1492
1493 // Our SPIR-V op returns a struct, create a type for it
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001494 SmallVector<Type *, 2> TwoValueType = {AType, AType};
Kévin Petit617a76d2019-04-04 13:54:16 +01001495 auto ExMulRetType = StructType::create(TwoValueType);
1496
1497 // Call the SPIR-V op
1498 auto Call = clspv::InsertSPIRVOp(CI, opcode, {Attribute::ReadNone},
1499 ExMulRetType, {AValue, BValue});
1500
1501 // Get the high part of the result
1502 unsigned Idxs[] = {1};
1503 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1504
1505 // If we're handling a mad_hi, add the third argument to the result
1506 if (isMad) {
1507 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1508 }
1509
1510 // Replace call with the expression
1511 CI->replaceAllUsesWith(V);
1512
1513 // Lastly, remember to remove the user.
1514 ToRemoves.push_back(CI);
1515 }
1516 }
1517
1518 Changed = !ToRemoves.empty();
1519
1520 // And cleanup the calls we don't use anymore.
1521 for (auto V : ToRemoves) {
1522 V->eraseFromParent();
1523 }
1524
1525 // And remove the function we don't need either too.
1526 F->eraseFromParent();
1527 }
1528
Kévin Petit8a560882019-03-21 15:24:34 +00001529 return Changed;
1530}
1531
Kévin Petitf5b78a22018-10-25 14:32:17 +00001532bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1533 bool Changed = false;
1534
1535 for (auto const &SymVal : M.getValueSymbolTable()) {
1536 // Skip symbols whose name doesn't match
1537 if (!SymVal.getKey().startswith("_Z6select")) {
1538 continue;
1539 }
1540 // Is there a function going by that name?
1541 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1542
1543 SmallVector<Instruction *, 4> ToRemoves;
1544
1545 // Walk the users of the function.
1546 for (auto &U : F->uses()) {
1547 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1548
1549 // Get arguments
1550 auto FalseValue = CI->getOperand(0);
1551 auto TrueValue = CI->getOperand(1);
1552 auto PredicateValue = CI->getOperand(2);
1553
1554 // Don't touch overloads that aren't in OpenCL C
1555 auto FalseType = FalseValue->getType();
1556 auto TrueType = TrueValue->getType();
1557 auto PredicateType = PredicateValue->getType();
1558
1559 if (FalseType != TrueType) {
1560 continue;
1561 }
1562
1563 if (!PredicateType->isIntOrIntVectorTy()) {
1564 continue;
1565 }
1566
1567 if (!FalseType->isIntOrIntVectorTy() &&
1568 !FalseType->getScalarType()->isFloatingPointTy()) {
1569 continue;
1570 }
1571
1572 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1573 continue;
1574 }
1575
1576 if (FalseType->getScalarSizeInBits() !=
1577 PredicateType->getScalarSizeInBits()) {
1578 continue;
1579 }
1580
1581 if (FalseType->isVectorTy()) {
1582 if (FalseType->getVectorNumElements() !=
1583 PredicateType->getVectorNumElements()) {
1584 continue;
1585 }
1586
1587 if ((FalseType->getVectorNumElements() != 2) &&
1588 (FalseType->getVectorNumElements() != 3) &&
1589 (FalseType->getVectorNumElements() != 4) &&
1590 (FalseType->getVectorNumElements() != 8) &&
1591 (FalseType->getVectorNumElements() != 16)) {
1592 continue;
1593 }
1594 }
1595
1596 // Create constant
1597 const auto ZeroValue = Constant::getNullValue(PredicateType);
1598
1599 // Scalar and vector are to be treated differently
1600 CmpInst::Predicate Pred;
1601 if (PredicateType->isVectorTy()) {
1602 Pred = CmpInst::ICMP_SLT;
1603 } else {
1604 Pred = CmpInst::ICMP_NE;
1605 }
1606
1607 // Create comparison instruction
1608 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1609 ZeroValue, "", CI);
1610
1611 // Create select
1612 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1613
1614 // Replace call with the selection
1615 CI->replaceAllUsesWith(V);
1616
1617 // Lastly, remember to remove the user.
1618 ToRemoves.push_back(CI);
1619 }
1620 }
1621
1622 Changed = !ToRemoves.empty();
1623
1624 // And cleanup the calls we don't use anymore.
1625 for (auto V : ToRemoves) {
1626 V->eraseFromParent();
1627 }
1628
1629 // And remove the function we don't need either too.
1630 F->eraseFromParent();
1631 }
1632 }
1633
1634 return Changed;
1635}
1636
Kévin Petite7d0cce2018-10-31 12:38:56 +00001637bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1638 bool Changed = false;
1639
1640 for (auto const &SymVal : M.getValueSymbolTable()) {
1641 // Skip symbols whose name doesn't match
1642 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1643 continue;
1644 }
1645 // Is there a function going by that name?
1646 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1647
1648 SmallVector<Instruction *, 4> ToRemoves;
1649
1650 // Walk the users of the function.
1651 for (auto &U : F->uses()) {
1652 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1653
1654 if (CI->getNumOperands() != 4) {
1655 continue;
1656 }
1657
1658 // Get arguments
1659 auto FalseValue = CI->getOperand(0);
1660 auto TrueValue = CI->getOperand(1);
1661 auto PredicateValue = CI->getOperand(2);
1662
1663 // Don't touch overloads that aren't in OpenCL C
1664 auto FalseType = FalseValue->getType();
1665 auto TrueType = TrueValue->getType();
1666 auto PredicateType = PredicateValue->getType();
1667
1668 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1669 continue;
1670 }
1671
1672 if (TrueType->isVectorTy()) {
1673 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1674 !TrueType->getScalarType()->isIntegerTy()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001675 continue;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001676 }
1677 if ((TrueType->getVectorNumElements() != 2) &&
1678 (TrueType->getVectorNumElements() != 3) &&
1679 (TrueType->getVectorNumElements() != 4) &&
1680 (TrueType->getVectorNumElements() != 8) &&
1681 (TrueType->getVectorNumElements() != 16)) {
1682 continue;
1683 }
1684 }
1685
1686 // Remember the type of the operands
1687 auto OpType = TrueType;
1688
1689 // The actual bit selection will always be done on an integer type,
1690 // declare it here
1691 Type *BitType;
1692
1693 // If the operands are float, then bitcast them to int
1694 if (OpType->getScalarType()->isFloatingPointTy()) {
1695
1696 // First create the new type
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001697 BitType = getIntOrIntVectorTyForCast(M.getContext(), OpType);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001698
1699 // Then bitcast all operands
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001700 PredicateValue =
1701 CastInst::CreateZExtOrBitCast(PredicateValue, BitType, "", CI);
1702 FalseValue =
1703 CastInst::CreateZExtOrBitCast(FalseValue, BitType, "", CI);
1704 TrueValue =
1705 CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001706
1707 } else {
1708 // The operands have an integer type, use it directly
1709 BitType = OpType;
1710 }
1711
1712 // All the operands are now always integers
1713 // implement as (c & b) | (~c & a)
1714
1715 // Create our negated predicate value
1716 auto AllOnes = Constant::getAllOnesValue(BitType);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001717 auto NotPredicateValue = BinaryOperator::Create(
1718 Instruction::Xor, PredicateValue, AllOnes, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001719
1720 // Then put everything together
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001721 auto BitsFalse = BinaryOperator::Create(
1722 Instruction::And, NotPredicateValue, FalseValue, "", CI);
1723 auto BitsTrue = BinaryOperator::Create(
1724 Instruction::And, PredicateValue, TrueValue, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001725
1726 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1727 BitsTrue, "", CI);
1728
1729 // If we were dealing with a floating point type, we must bitcast
1730 // the result back to that
1731 if (OpType->getScalarType()->isFloatingPointTy()) {
1732 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1733 }
1734
1735 // Replace call with our new code
1736 CI->replaceAllUsesWith(V);
1737
1738 // Lastly, remember to remove the user.
1739 ToRemoves.push_back(CI);
1740 }
1741 }
1742
1743 Changed = !ToRemoves.empty();
1744
1745 // And cleanup the calls we don't use anymore.
1746 for (auto V : ToRemoves) {
1747 V->eraseFromParent();
1748 }
1749
1750 // And remove the function we don't need either too.
1751 F->eraseFromParent();
1752 }
1753 }
1754
1755 return Changed;
1756}
1757
Kévin Petit6b0a9532018-10-30 20:00:39 +00001758bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1759 bool Changed = false;
1760
1761 const std::map<const char *, const char *> Map = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001762 {"_Z4stepfDv2_f", "_Z4stepDv2_fS_"},
1763 {"_Z4stepfDv3_f", "_Z4stepDv3_fS_"},
1764 {"_Z4stepfDv4_f", "_Z4stepDv4_fS_"},
1765 {"_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_"},
1766 {"_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_"},
1767 {"_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_"},
Kévin Petit6b0a9532018-10-30 20:00:39 +00001768 };
1769
1770 for (auto Pair : Map) {
1771 // If we find a function with the matching name.
1772 if (auto F = M.getFunction(Pair.first)) {
1773 SmallVector<Instruction *, 4> ToRemoves;
1774
1775 // Walk the users of the function.
1776 for (auto &U : F->uses()) {
1777 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1778
1779 auto ReplacementFn = Pair.second;
1780
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001781 SmallVector<Value *, 2> ArgsToSplat = {CI->getOperand(0)};
Kévin Petit6b0a9532018-10-30 20:00:39 +00001782 Value *VectorArg;
1783
1784 // First figure out which function we're dealing with
1785 if (F->getName().startswith("_Z10smoothstep")) {
1786 ArgsToSplat.push_back(CI->getOperand(1));
1787 VectorArg = CI->getOperand(2);
1788 } else {
1789 VectorArg = CI->getOperand(1);
1790 }
1791
1792 // Splat arguments that need to be
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001793 SmallVector<Value *, 2> SplatArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001794 auto VecType = VectorArg->getType();
1795
1796 for (auto arg : ArgsToSplat) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001797 Value *NewVectorArg = UndefValue::get(VecType);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001798 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001799 auto index =
1800 ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1801 NewVectorArg =
1802 InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001803 }
1804 SplatArgs.push_back(NewVectorArg);
1805 }
1806
1807 // Replace the call with the vector/vector flavour
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001808 SmallVector<Type *, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1809 const auto NewFType =
1810 FunctionType::get(CI->getType(), NewArgTypes, false);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001811
1812 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1813
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001814 SmallVector<Value *, 3> NewArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001815 for (auto arg : SplatArgs) {
1816 NewArgs.push_back(arg);
1817 }
1818 NewArgs.push_back(VectorArg);
1819
1820 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1821
1822 CI->replaceAllUsesWith(NewCI);
1823
1824 // Lastly, remember to remove the user.
1825 ToRemoves.push_back(CI);
1826 }
1827 }
1828
1829 Changed = !ToRemoves.empty();
1830
1831 // And cleanup the calls we don't use anymore.
1832 for (auto V : ToRemoves) {
1833 V->eraseFromParent();
1834 }
1835
1836 // And remove the function we don't need either too.
1837 F->eraseFromParent();
1838 }
1839 }
1840
1841 return Changed;
1842}
1843
David Neto22f144c2017-06-12 14:26:21 -04001844bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1845 bool Changed = false;
1846
1847 const std::map<const char *, Instruction::BinaryOps> Map = {
1848 {"_Z7signbitf", Instruction::LShr},
1849 {"_Z7signbitDv2_f", Instruction::AShr},
1850 {"_Z7signbitDv3_f", Instruction::AShr},
1851 {"_Z7signbitDv4_f", Instruction::AShr},
1852 };
1853
1854 for (auto Pair : Map) {
1855 // If we find a function with the matching name.
1856 if (auto F = M.getFunction(Pair.first)) {
1857 SmallVector<Instruction *, 4> ToRemoves;
1858
1859 // Walk the users of the function.
1860 for (auto &U : F->uses()) {
1861 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1862 auto Arg = CI->getOperand(0);
1863
1864 auto Bitcast =
1865 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1866
1867 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1868 ConstantInt::get(CI->getType(), 31),
1869 "", CI);
1870
1871 CI->replaceAllUsesWith(Shr);
1872
1873 // Lastly, remember to remove the user.
1874 ToRemoves.push_back(CI);
1875 }
1876 }
1877
1878 Changed = !ToRemoves.empty();
1879
1880 // And cleanup the calls we don't use anymore.
1881 for (auto V : ToRemoves) {
1882 V->eraseFromParent();
1883 }
1884
1885 // And remove the function we don't need either too.
1886 F->eraseFromParent();
1887 }
1888 }
1889
1890 return Changed;
1891}
1892
1893bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1894 bool Changed = false;
1895
1896 const std::map<const char *,
1897 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1898 Map = {
1899 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1900 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1901 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1902 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
alan-bakerc21a65e2020-01-15 14:19:39 -05001903 {"_Z3madDhDhDh", {Instruction::FMul, Instruction::FAdd}},
1904 {"_Z3madDv2_DhS_S_", {Instruction::FMul, Instruction::FAdd}},
1905 {"_Z3madDv3_DhS_S_", {Instruction::FMul, Instruction::FAdd}},
1906 {"_Z3madDv4_DhS_S_", {Instruction::FMul, Instruction::FAdd}},
David Neto22f144c2017-06-12 14:26:21 -04001907 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1908 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1909 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1910 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1911 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1912 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1913 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1914 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1915 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1916 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1917 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1918 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1919 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1920 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1921 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1922 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1923 };
1924
1925 for (auto Pair : Map) {
1926 // If we find a function with the matching name.
1927 if (auto F = M.getFunction(Pair.first)) {
1928 SmallVector<Instruction *, 4> ToRemoves;
1929
1930 // Walk the users of the function.
1931 for (auto &U : F->uses()) {
1932 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1933 // The multiply instruction to use.
1934 auto MulInst = Pair.second.first;
1935
1936 // The add instruction to use.
1937 auto AddInst = Pair.second.second;
1938
1939 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1940
1941 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1942 CI->getArgOperand(1), "", CI);
1943
1944 if (Instruction::BinaryOpsEnd != AddInst) {
1945 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1946 CI);
1947 }
1948
1949 CI->replaceAllUsesWith(I);
1950
1951 // Lastly, remember to remove the user.
1952 ToRemoves.push_back(CI);
1953 }
1954 }
1955
1956 Changed = !ToRemoves.empty();
1957
1958 // And cleanup the calls we don't use anymore.
1959 for (auto V : ToRemoves) {
1960 V->eraseFromParent();
1961 }
1962
1963 // And remove the function we don't need either too.
1964 F->eraseFromParent();
1965 }
1966 }
1967
1968 return Changed;
1969}
1970
Derek Chowcfd368b2017-10-19 20:58:45 -07001971bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1972 bool Changed = false;
1973
alan-bakerf795f392019-06-11 18:24:34 -04001974 for (auto const &SymVal : M.getValueSymbolTable()) {
1975 if (!SymVal.getKey().contains("vstore"))
1976 continue;
1977 if (SymVal.getKey().contains("vstore_"))
1978 continue;
1979 if (SymVal.getKey().contains("vstorea"))
1980 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001981
alan-bakerf795f392019-06-11 18:24:34 -04001982 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001983 SmallVector<Instruction *, 4> ToRemoves;
1984
alan-bakerf795f392019-06-11 18:24:34 -04001985 auto fname = F->getName();
1986 if (!fname.consume_front("_Z"))
1987 continue;
1988 size_t name_len;
1989 if (fname.consumeInteger(10, name_len))
1990 continue;
1991 std::string name = fname.take_front(name_len);
1992
1993 bool ok = StringSwitch<bool>(name)
1994 .Case("vstore2", true)
1995 .Case("vstore3", true)
1996 .Case("vstore4", true)
1997 .Case("vstore8", true)
1998 .Case("vstore16", true)
1999 .Default(false);
2000 if (!ok)
2001 continue;
2002
Derek Chowcfd368b2017-10-19 20:58:45 -07002003 for (auto &U : F->uses()) {
2004 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04002005 auto data = CI->getOperand(0);
Derek Chowcfd368b2017-10-19 20:58:45 -07002006
alan-bakerf795f392019-06-11 18:24:34 -04002007 auto data_type = data->getType();
2008 if (!data_type->isVectorTy())
2009 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002010
alan-bakerf795f392019-06-11 18:24:34 -04002011 auto elems = data_type->getVectorNumElements();
2012 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
2013 elems != 16)
2014 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002015
alan-bakerf795f392019-06-11 18:24:34 -04002016 auto offset = CI->getOperand(1);
2017 auto ptr = CI->getOperand(2);
2018 auto ptr_type = ptr->getType();
2019 auto pointee_type = ptr_type->getPointerElementType();
2020 if (pointee_type != data_type->getVectorElementType())
2021 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002022
alan-bakerf795f392019-06-11 18:24:34 -04002023 // Avoid pointer casts. Instead generate the correct number of stores
2024 // and rely on drivers to coalesce appropriately.
2025 IRBuilder<> builder(CI);
2026 auto elems_const = builder.getInt32(elems);
2027 auto adjust = builder.CreateMul(offset, elems_const);
2028 for (auto i = 0; i < elems; ++i) {
2029 auto idx = builder.getInt32(i);
2030 auto add = builder.CreateAdd(adjust, idx);
2031 auto gep = builder.CreateGEP(ptr, add);
2032 auto extract = builder.CreateExtractElement(data, i);
2033 auto store = builder.CreateStore(extract, gep);
2034 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002035
Derek Chowcfd368b2017-10-19 20:58:45 -07002036 ToRemoves.push_back(CI);
2037 }
2038 }
2039
2040 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07002041 for (auto V : ToRemoves) {
2042 V->eraseFromParent();
2043 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002044 F->eraseFromParent();
2045 }
2046 }
2047
2048 return Changed;
2049}
2050
2051bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
2052 bool Changed = false;
2053
alan-bakerf795f392019-06-11 18:24:34 -04002054 for (auto const &SymVal : M.getValueSymbolTable()) {
2055 if (!SymVal.getKey().contains("vload"))
2056 continue;
2057 if (SymVal.getKey().contains("vload_"))
2058 continue;
2059 if (SymVal.getKey().contains("vloada"))
2060 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002061
alan-bakerf795f392019-06-11 18:24:34 -04002062 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07002063 SmallVector<Instruction *, 4> ToRemoves;
2064
alan-bakerf795f392019-06-11 18:24:34 -04002065 auto fname = F->getName();
2066 if (!fname.consume_front("_Z"))
2067 continue;
2068 size_t name_len;
2069 if (fname.consumeInteger(10, name_len))
2070 continue;
2071 std::string name = fname.take_front(name_len);
2072
2073 bool ok = StringSwitch<bool>(name)
2074 .Case("vload2", true)
2075 .Case("vload3", true)
2076 .Case("vload4", true)
2077 .Case("vload8", true)
2078 .Case("vload16", true)
2079 .Default(false);
2080 if (!ok)
2081 continue;
2082
Derek Chowcfd368b2017-10-19 20:58:45 -07002083 for (auto &U : F->uses()) {
2084 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04002085 auto ret_type = F->getReturnType();
2086 if (!ret_type->isVectorTy())
2087 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002088
alan-bakerf795f392019-06-11 18:24:34 -04002089 auto elems = ret_type->getVectorNumElements();
2090 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
2091 elems != 16)
2092 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002093
alan-bakerf795f392019-06-11 18:24:34 -04002094 auto offset = CI->getOperand(0);
2095 auto ptr = CI->getOperand(1);
2096 auto ptr_type = ptr->getType();
2097 auto pointee_type = ptr_type->getPointerElementType();
2098 if (pointee_type != ret_type->getVectorElementType())
2099 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002100
alan-bakerf795f392019-06-11 18:24:34 -04002101 // Avoid pointer casts. Instead generate the correct number of loads
2102 // and rely on drivers to coalesce appropriately.
2103 IRBuilder<> builder(CI);
2104 auto elems_const = builder.getInt32(elems);
2105 Value *insert = UndefValue::get(ret_type);
2106 auto adjust = builder.CreateMul(offset, elems_const);
2107 for (auto i = 0; i < elems; ++i) {
2108 auto idx = builder.getInt32(i);
2109 auto add = builder.CreateAdd(adjust, idx);
2110 auto gep = builder.CreateGEP(ptr, add);
2111 auto load = builder.CreateLoad(gep);
2112 insert = builder.CreateInsertElement(insert, load, i);
2113 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002114
alan-bakerf795f392019-06-11 18:24:34 -04002115 CI->replaceAllUsesWith(insert);
Derek Chowcfd368b2017-10-19 20:58:45 -07002116 ToRemoves.push_back(CI);
2117 }
2118 }
2119
2120 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07002121 for (auto V : ToRemoves) {
2122 V->eraseFromParent();
2123 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002124 F->eraseFromParent();
Derek Chowcfd368b2017-10-19 20:58:45 -07002125 }
2126 }
2127
2128 return Changed;
2129}
2130
David Neto22f144c2017-06-12 14:26:21 -04002131bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
2132 bool Changed = false;
2133
2134 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
2135 "_Z10vload_halfjPU3AS2KDh"};
2136
2137 for (auto Name : Map) {
2138 // If we find a function with the matching name.
2139 if (auto F = M.getFunction(Name)) {
2140 SmallVector<Instruction *, 4> ToRemoves;
2141
2142 // Walk the users of the function.
2143 for (auto &U : F->uses()) {
2144 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2145 // The index argument from vload_half.
2146 auto Arg0 = CI->getOperand(0);
2147
2148 // The pointer argument from vload_half.
2149 auto Arg1 = CI->getOperand(1);
2150
David Neto22f144c2017-06-12 14:26:21 -04002151 auto IntTy = Type::getInt32Ty(M.getContext());
2152 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002153 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2154
David Neto22f144c2017-06-12 14:26:21 -04002155 // Our intrinsic to unpack a float2 from an int.
2156 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2157
2158 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2159
David Neto482550a2018-03-24 05:21:07 -07002160 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04002161 auto ShortTy = Type::getInt16Ty(M.getContext());
2162 auto ShortPointerTy = PointerType::get(
2163 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002164
David Netoac825b82017-05-30 12:49:01 -04002165 // Cast the half* pointer to short*.
2166 auto Cast =
2167 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002168
David Netoac825b82017-05-30 12:49:01 -04002169 // Index into the correct address of the casted pointer.
2170 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
2171
2172 // Load from the short* we casted to.
2173 auto Load = new LoadInst(Index, "", CI);
2174
2175 // ZExt the short -> int.
2176 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
2177
2178 // Get our float2.
2179 auto Call = CallInst::Create(NewF, ZExt, "", CI);
2180
2181 // Extract out the bottom element which is our float result.
2182 auto Extract = ExtractElementInst::Create(
2183 Call, ConstantInt::get(IntTy, 0), "", CI);
2184
2185 CI->replaceAllUsesWith(Extract);
2186 } else {
2187 // Assume the pointer argument points to storage aligned to 32bits
2188 // or more.
2189 // TODO(dneto): Do more analysis to make sure this is true?
2190 //
2191 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
2192 // with:
2193 //
2194 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
2195 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
2196 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
2197 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
2198 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
2199 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
2200 // x float> %converted, %index_is_odd32
2201
2202 auto IntPointerTy = PointerType::get(
2203 IntTy, Arg1->getType()->getPointerAddressSpace());
2204
David Neto973e6a82017-05-30 13:48:18 -04002205 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04002206 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04002207 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04002208 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
2209
2210 auto One = ConstantInt::get(IntTy, 1);
2211 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
2212 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
2213
2214 // Index into the correct address of the casted pointer.
2215 auto Ptr =
2216 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
2217
2218 // Load from the int* we casted to.
2219 auto Load = new LoadInst(Ptr, "", CI);
2220
2221 // Get our float2.
2222 auto Call = CallInst::Create(NewF, Load, "", CI);
2223
2224 // Extract out the float result, where the element number is
2225 // determined by whether the original index was even or odd.
2226 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
2227
2228 CI->replaceAllUsesWith(Extract);
2229 }
David Neto22f144c2017-06-12 14:26:21 -04002230
2231 // Lastly, remember to remove the user.
2232 ToRemoves.push_back(CI);
2233 }
2234 }
2235
2236 Changed = !ToRemoves.empty();
2237
2238 // And cleanup the calls we don't use anymore.
2239 for (auto V : ToRemoves) {
2240 V->eraseFromParent();
2241 }
2242
2243 // And remove the function we don't need either too.
2244 F->eraseFromParent();
2245 }
2246 }
2247
2248 return Changed;
2249}
2250
2251bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002252
Kévin Petite8edce32019-04-10 14:23:32 +01002253 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002254 "_Z11vload_half2jPU3AS1KDh",
2255 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
2256 "_Z11vload_half2jPU3AS2KDh",
2257 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
2258 };
David Neto22f144c2017-06-12 14:26:21 -04002259
Kévin Petite8edce32019-04-10 14:23:32 +01002260 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2261 // The index argument from vload_half.
2262 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002263
Kévin Petite8edce32019-04-10 14:23:32 +01002264 // The pointer argument from vload_half.
2265 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002266
Kévin Petite8edce32019-04-10 14:23:32 +01002267 auto IntTy = Type::getInt32Ty(M.getContext());
2268 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002269 auto NewPointerTy =
2270 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002271 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002272
Kévin Petite8edce32019-04-10 14:23:32 +01002273 // Cast the half* pointer to int*.
2274 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002275
Kévin Petite8edce32019-04-10 14:23:32 +01002276 // Index into the correct address of the casted pointer.
2277 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002278
Kévin Petite8edce32019-04-10 14:23:32 +01002279 // Load from the int* we casted to.
2280 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002281
Kévin Petite8edce32019-04-10 14:23:32 +01002282 // Our intrinsic to unpack a float2 from an int.
2283 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002284
Kévin Petite8edce32019-04-10 14:23:32 +01002285 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002286
Kévin Petite8edce32019-04-10 14:23:32 +01002287 // Get our float2.
2288 return CallInst::Create(NewF, Load, "", CI);
2289 });
David Neto22f144c2017-06-12 14:26:21 -04002290}
2291
2292bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002293
Kévin Petite8edce32019-04-10 14:23:32 +01002294 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002295 "_Z11vload_half4jPU3AS1KDh",
2296 "_Z12vloada_half4jPU3AS1KDh",
2297 "_Z11vload_half4jPU3AS2KDh",
2298 "_Z12vloada_half4jPU3AS2KDh",
2299 };
David Neto22f144c2017-06-12 14:26:21 -04002300
Kévin Petite8edce32019-04-10 14:23:32 +01002301 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2302 // The index argument from vload_half.
2303 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002304
Kévin Petite8edce32019-04-10 14:23:32 +01002305 // The pointer argument from vload_half.
2306 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002307
Kévin Petite8edce32019-04-10 14:23:32 +01002308 auto IntTy = Type::getInt32Ty(M.getContext());
2309 auto Int2Ty = VectorType::get(IntTy, 2);
2310 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002311 auto NewPointerTy =
2312 PointerType::get(Int2Ty, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002313 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002314
Kévin Petite8edce32019-04-10 14:23:32 +01002315 // Cast the half* pointer to int2*.
2316 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002317
Kévin Petite8edce32019-04-10 14:23:32 +01002318 // Index into the correct address of the casted pointer.
2319 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002320
Kévin Petite8edce32019-04-10 14:23:32 +01002321 // Load from the int2* we casted to.
2322 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002323
Kévin Petite8edce32019-04-10 14:23:32 +01002324 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002325 auto X =
2326 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2327 auto Y =
2328 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002329
Kévin Petite8edce32019-04-10 14:23:32 +01002330 // Our intrinsic to unpack a float2 from an int.
2331 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002332
Kévin Petite8edce32019-04-10 14:23:32 +01002333 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002334
Kévin Petite8edce32019-04-10 14:23:32 +01002335 // Get the lower (x & y) components of our final float4.
2336 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002337
Kévin Petite8edce32019-04-10 14:23:32 +01002338 // Get the higher (z & w) components of our final float4.
2339 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002340
Kévin Petite8edce32019-04-10 14:23:32 +01002341 Constant *ShuffleMask[4] = {
2342 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2343 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002344
Kévin Petite8edce32019-04-10 14:23:32 +01002345 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002346 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2347 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002348 });
David Neto22f144c2017-06-12 14:26:21 -04002349}
2350
David Neto6ad93232018-06-07 15:42:58 -07002351bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002352
2353 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2354 //
2355 // %u = load i32 %ptr
2356 // %fxy = call <2 x float> Unpack2xHalf(u)
2357 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002358 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002359 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2360 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2361 "_Z20__clspv_vloada_half2jPKj", // private
2362 };
2363
Kévin Petite8edce32019-04-10 14:23:32 +01002364 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2365 auto Index = CI->getOperand(0);
2366 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002367
Kévin Petite8edce32019-04-10 14:23:32 +01002368 auto IntTy = Type::getInt32Ty(M.getContext());
2369 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2370 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002371
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002372 auto IndexedPtr = GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002373 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002374
Kévin Petite8edce32019-04-10 14:23:32 +01002375 // Our intrinsic to unpack a float2 from an int.
2376 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002377
Kévin Petite8edce32019-04-10 14:23:32 +01002378 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002379
Kévin Petite8edce32019-04-10 14:23:32 +01002380 // Get our final float2.
2381 return CallInst::Create(NewF, Load, "", CI);
2382 });
David Neto6ad93232018-06-07 15:42:58 -07002383}
2384
2385bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002386
2387 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2388 //
2389 // %u2 = load <2 x i32> %ptr
2390 // %u2xy = extractelement %u2, 0
2391 // %u2zw = extractelement %u2, 1
2392 // %fxy = call <2 x float> Unpack2xHalf(uint)
2393 // %fzw = call <2 x float> Unpack2xHalf(uint)
2394 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002395 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002396 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2397 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2398 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2399 };
2400
Kévin Petite8edce32019-04-10 14:23:32 +01002401 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2402 auto Index = CI->getOperand(0);
2403 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002404
Kévin Petite8edce32019-04-10 14:23:32 +01002405 auto IntTy = Type::getInt32Ty(M.getContext());
2406 auto Int2Ty = VectorType::get(IntTy, 2);
2407 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2408 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002409
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002410 auto IndexedPtr = GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002411 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002412
Kévin Petite8edce32019-04-10 14:23:32 +01002413 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002414 auto X =
2415 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2416 auto Y =
2417 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002418
Kévin Petite8edce32019-04-10 14:23:32 +01002419 // Our intrinsic to unpack a float2 from an int.
2420 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002421
Kévin Petite8edce32019-04-10 14:23:32 +01002422 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002423
Kévin Petite8edce32019-04-10 14:23:32 +01002424 // Get the lower (x & y) components of our final float4.
2425 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002426
Kévin Petite8edce32019-04-10 14:23:32 +01002427 // Get the higher (z & w) components of our final float4.
2428 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002429
Kévin Petite8edce32019-04-10 14:23:32 +01002430 Constant *ShuffleMask[4] = {
2431 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2432 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto6ad93232018-06-07 15:42:58 -07002433
Kévin Petite8edce32019-04-10 14:23:32 +01002434 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002435 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2436 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002437 });
David Neto6ad93232018-06-07 15:42:58 -07002438}
2439
David Neto22f144c2017-06-12 14:26:21 -04002440bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002441
Kévin Petite8edce32019-04-10 14:23:32 +01002442 const std::vector<const char *> Names = {"_Z11vstore_halffjPU3AS1Dh",
2443 "_Z15vstore_half_rtefjPU3AS1Dh",
2444 "_Z15vstore_half_rtzfjPU3AS1Dh"};
David Neto22f144c2017-06-12 14:26:21 -04002445
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002446 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002447 // The value to store.
2448 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002449
Kévin Petite8edce32019-04-10 14:23:32 +01002450 // The index argument from vstore_half.
2451 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002452
Kévin Petite8edce32019-04-10 14:23:32 +01002453 // The pointer argument from vstore_half.
2454 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002455
Kévin Petite8edce32019-04-10 14:23:32 +01002456 auto IntTy = Type::getInt32Ty(M.getContext());
2457 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2458 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2459 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002460
Kévin Petite8edce32019-04-10 14:23:32 +01002461 // Our intrinsic to pack a float2 to an int.
2462 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002463
Kévin Petite8edce32019-04-10 14:23:32 +01002464 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002465
Kévin Petite8edce32019-04-10 14:23:32 +01002466 // Insert our value into a float2 so that we can pack it.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002467 auto TempVec = InsertElementInst::Create(
2468 UndefValue::get(Float2Ty), Arg0, ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002469
Kévin Petite8edce32019-04-10 14:23:32 +01002470 // Pack the float2 -> half2 (in an int).
2471 auto X = CallInst::Create(NewF, TempVec, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002472
Kévin Petite8edce32019-04-10 14:23:32 +01002473 Value *Ret;
2474 if (clspv::Option::F16BitStorage()) {
2475 auto ShortTy = Type::getInt16Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002476 auto ShortPointerTy =
2477 PointerType::get(ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002478
Kévin Petite8edce32019-04-10 14:23:32 +01002479 // Truncate our i32 to an i16.
2480 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002481
Kévin Petite8edce32019-04-10 14:23:32 +01002482 // Cast the half* pointer to short*.
2483 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002484
Kévin Petite8edce32019-04-10 14:23:32 +01002485 // Index into the correct address of the casted pointer.
2486 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002487
Kévin Petite8edce32019-04-10 14:23:32 +01002488 // Store to the int* we casted to.
2489 Ret = new StoreInst(Trunc, Index, CI);
2490 } else {
2491 // We can only write to 32-bit aligned words.
2492 //
2493 // Assuming base is aligned to 32-bits, replace the equivalent of
2494 // vstore_half(value, index, base)
2495 // with:
2496 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2497 // uint32_t write_to_upper_half = index & 1u;
2498 // uint32_t shift = write_to_upper_half << 4;
2499 //
2500 // // Pack the float value as a half number in bottom 16 bits
2501 // // of an i32.
2502 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2503 //
2504 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2505 // ^ ((packed & 0xffff) << shift)
2506 // // We only need relaxed consistency, but OpenCL 1.2 only has
2507 // // sequentially consistent atomics.
2508 // // TODO(dneto): Use relaxed consistency.
2509 // atomic_xor(target_ptr, xor_value)
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002510 auto IntPointerTy =
2511 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002512
Kévin Petite8edce32019-04-10 14:23:32 +01002513 auto Four = ConstantInt::get(IntTy, 4);
2514 auto FFFF = ConstantInt::get(IntTy, 0xffff);
David Neto17852de2017-05-29 17:29:31 -04002515
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002516 auto IndexIsOdd =
2517 BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002518 // Compute index / 2
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002519 auto IndexIntoI32 =
2520 BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2521 auto BaseI32Ptr =
2522 CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2523 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32,
2524 "base_i32_ptr", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002525 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2526 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002527 auto MaskBitsToWrite =
2528 BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2529 auto MaskedCurrent = BinaryOperator::CreateAnd(
2530 MaskBitsToWrite, CurrentValue, "masked_current", CI);
David Neto17852de2017-05-29 17:29:31 -04002531
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002532 auto XLowerBits =
2533 BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2534 auto NewBitsToWrite =
2535 BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2536 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite,
2537 "value_to_xor", CI);
David Neto17852de2017-05-29 17:29:31 -04002538
Kévin Petite8edce32019-04-10 14:23:32 +01002539 // Generate the call to atomi_xor.
2540 SmallVector<Type *, 5> ParamTypes;
2541 // The pointer type.
2542 ParamTypes.push_back(IntPointerTy);
2543 // The Types for memory scope, semantics, and value.
2544 ParamTypes.push_back(IntTy);
2545 ParamTypes.push_back(IntTy);
2546 ParamTypes.push_back(IntTy);
2547 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2548 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
David Neto17852de2017-05-29 17:29:31 -04002549
Kévin Petite8edce32019-04-10 14:23:32 +01002550 const auto ConstantScopeDevice =
2551 ConstantInt::get(IntTy, spv::ScopeDevice);
2552 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2553 // (SPIR-V Workgroup).
2554 const auto AddrSpaceSemanticsBits =
2555 IntPointerTy->getPointerAddressSpace() == 1
2556 ? spv::MemorySemanticsUniformMemoryMask
2557 : spv::MemorySemanticsWorkgroupMemoryMask;
David Neto17852de2017-05-29 17:29:31 -04002558
Kévin Petite8edce32019-04-10 14:23:32 +01002559 // We're using relaxed consistency here.
2560 const auto ConstantMemorySemantics =
2561 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2562 AddrSpaceSemanticsBits);
David Neto17852de2017-05-29 17:29:31 -04002563
Kévin Petite8edce32019-04-10 14:23:32 +01002564 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2565 ConstantMemorySemantics, ValueToXor};
2566 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2567 Ret = nullptr;
David Neto22f144c2017-06-12 14:26:21 -04002568 }
David Neto22f144c2017-06-12 14:26:21 -04002569
Kévin Petite8edce32019-04-10 14:23:32 +01002570 return Ret;
2571 });
David Neto22f144c2017-06-12 14:26:21 -04002572}
2573
2574bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002575
Kévin Petite8edce32019-04-10 14:23:32 +01002576 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002577 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2578 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2579 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2580 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2581 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2582 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2583 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2584 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2585 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2586 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2587 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2588 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2589 };
David Neto22f144c2017-06-12 14:26:21 -04002590
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002591 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002592 // The value to store.
2593 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002594
Kévin Petite8edce32019-04-10 14:23:32 +01002595 // The index argument from vstore_half.
2596 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002597
Kévin Petite8edce32019-04-10 14:23:32 +01002598 // The pointer argument from vstore_half.
2599 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002600
Kévin Petite8edce32019-04-10 14:23:32 +01002601 auto IntTy = Type::getInt32Ty(M.getContext());
2602 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002603 auto NewPointerTy =
2604 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002605 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002606
Kévin Petite8edce32019-04-10 14:23:32 +01002607 // Our intrinsic to pack a float2 to an int.
2608 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002609
Kévin Petite8edce32019-04-10 14:23:32 +01002610 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002611
Kévin Petite8edce32019-04-10 14:23:32 +01002612 // Turn the packed x & y into the final packing.
2613 auto X = CallInst::Create(NewF, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002614
Kévin Petite8edce32019-04-10 14:23:32 +01002615 // Cast the half* pointer to int*.
2616 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002617
Kévin Petite8edce32019-04-10 14:23:32 +01002618 // Index into the correct address of the casted pointer.
2619 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002620
Kévin Petite8edce32019-04-10 14:23:32 +01002621 // Store to the int* we casted to.
2622 return new StoreInst(X, Index, CI);
2623 });
David Neto22f144c2017-06-12 14:26:21 -04002624}
2625
2626bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002627
Kévin Petite8edce32019-04-10 14:23:32 +01002628 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002629 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2630 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2631 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2632 "_Z13vstorea_half4Dv4_fjPDh", // private
2633 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2634 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2635 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2636 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2637 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2638 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2639 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2640 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2641 };
David Neto22f144c2017-06-12 14:26:21 -04002642
Kévin Petite8edce32019-04-10 14:23:32 +01002643 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2644 // The value to store.
2645 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002646
Kévin Petite8edce32019-04-10 14:23:32 +01002647 // The index argument from vstore_half.
2648 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002649
Kévin Petite8edce32019-04-10 14:23:32 +01002650 // The pointer argument from vstore_half.
2651 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002652
Kévin Petite8edce32019-04-10 14:23:32 +01002653 auto IntTy = Type::getInt32Ty(M.getContext());
2654 auto Int2Ty = VectorType::get(IntTy, 2);
2655 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002656 auto NewPointerTy =
2657 PointerType::get(Int2Ty, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002658 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002659
Kévin Petite8edce32019-04-10 14:23:32 +01002660 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2661 ConstantInt::get(IntTy, 1)};
David Neto22f144c2017-06-12 14:26:21 -04002662
Kévin Petite8edce32019-04-10 14:23:32 +01002663 // Extract out the x & y components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002664 auto Lo = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2665 ConstantVector::get(LoShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002666
Kévin Petite8edce32019-04-10 14:23:32 +01002667 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2668 ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002669
Kévin Petite8edce32019-04-10 14:23:32 +01002670 // Extract out the z & w components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002671 auto Hi = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2672 ConstantVector::get(HiShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002673
Kévin Petite8edce32019-04-10 14:23:32 +01002674 // Our intrinsic to pack a float2 to an int.
2675 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002676
Kévin Petite8edce32019-04-10 14:23:32 +01002677 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002678
Kévin Petite8edce32019-04-10 14:23:32 +01002679 // Turn the packed x & y into the final component of our int2.
2680 auto X = CallInst::Create(NewF, Lo, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002681
Kévin Petite8edce32019-04-10 14:23:32 +01002682 // Turn the packed z & w into the final component of our int2.
2683 auto Y = CallInst::Create(NewF, Hi, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002684
Kévin Petite8edce32019-04-10 14:23:32 +01002685 auto Combine = InsertElementInst::Create(
2686 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002687 Combine = InsertElementInst::Create(Combine, Y, ConstantInt::get(IntTy, 1),
2688 "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002689
Kévin Petite8edce32019-04-10 14:23:32 +01002690 // Cast the half* pointer to int2*.
2691 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002692
Kévin Petite8edce32019-04-10 14:23:32 +01002693 // Index into the correct address of the casted pointer.
2694 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002695
Kévin Petite8edce32019-04-10 14:23:32 +01002696 // Store to the int2* we casted to.
2697 return new StoreInst(Combine, Index, CI);
2698 });
David Neto22f144c2017-06-12 14:26:21 -04002699}
2700
alan-bakerf7e17cb2020-01-02 07:29:59 -05002701bool ReplaceOpenCLBuiltinPass::replaceHalfReadImage(Module &M) {
2702 bool Changed = false;
2703 const std::map<const char *, const char *> Map = {
2704 // 1D
2705 {"_Z11read_imageh14ocl_image1d_roi", "_Z11read_imagef14ocl_image1d_roi"},
2706 {"_Z11read_imageh14ocl_image1d_ro11ocl_sampleri",
2707 "_Z11read_imagef14ocl_image1d_ro11ocl_sampleri"},
2708 {"_Z11read_imageh14ocl_image1d_ro11ocl_samplerf",
2709 "_Z11read_imagef14ocl_image1d_ro11ocl_samplerf"},
2710 // TODO 1D array
2711 // 2D
2712 {"_Z11read_imageh14ocl_image2d_roDv2_i",
2713 "_Z11read_imagef14ocl_image2d_roDv2_i"},
2714 {"_Z11read_imageh14ocl_image2d_ro11ocl_samplerDv2_i",
2715 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i"},
2716 {"_Z11read_imageh14ocl_image2d_ro11ocl_samplerDv2_f",
2717 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f"},
2718 // TODO 2D array
2719 // 3D
2720 {"_Z11read_imageh14ocl_image3d_roDv4_i",
2721 "_Z11read_imagef14ocl_image3d_roDv4_i"},
2722 {"_Z11read_imageh14ocl_image3d_ro11ocl_samplerDv4_i",
2723 "_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_i"},
2724 {"_Z11read_imageh14ocl_image3d_ro11ocl_samplerDv4_f",
2725 "_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_f"}};
2726
2727 for (auto Pair : Map) {
2728 // If we find a function with the matching name.
2729 if (auto F = M.getFunction(Pair.first)) {
2730 SmallVector<Instruction *, 4> ToRemoves;
2731
2732 // Walk the users of the function.
2733 for (auto &U : F->uses()) {
2734 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2735 SmallVector<Type *, 3> types;
2736 SmallVector<Value *, 3> args;
2737 for (auto i = 0; i < CI->getNumArgOperands(); ++i) {
2738 types.push_back(CI->getArgOperand(i)->getType());
2739 args.push_back(CI->getArgOperand(i));
2740 }
2741
2742 auto NewFType = FunctionType::get(
2743 VectorType::get(Type::getFloatTy(M.getContext()),
2744 CI->getType()->getVectorNumElements()),
2745 types, false);
2746
2747 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2748
2749 auto NewCI = CallInst::Create(NewF, args, "", CI);
2750
2751 // Convert to the half type.
2752 auto Cast = CastInst::CreateFPCast(NewCI, CI->getType(), "", CI);
2753
2754 CI->replaceAllUsesWith(Cast);
2755
2756 // Lastly, remember to remove the user.
2757 ToRemoves.push_back(CI);
2758 }
2759 }
2760
2761 Changed = !ToRemoves.empty();
2762
2763 // And cleanup the calls we don't use anymore.
2764 for (auto V : ToRemoves) {
2765 V->eraseFromParent();
2766 }
2767
2768 // And remove the function we don't need either too.
2769 F->eraseFromParent();
2770 }
2771 }
2772
2773 return Changed;
2774}
2775
2776bool ReplaceOpenCLBuiltinPass::replaceHalfWriteImage(Module &M) {
2777 bool Changed = false;
2778 const std::map<const char *, const char *> Map = {
2779 // 1D
2780 {"_Z12write_imageh14ocl_image1d_woiDv4_Dh",
2781 "_Z12write_imagef14ocl_image1d_woiDv4_f"},
2782 // TODO 1D array
2783 // 2D
2784 {"_Z12write_imageh14ocl_image2d_woDv2_iDv4_Dh",
2785 "_Z12write_imagef14ocl_image2d_woDv2_iDv4_f"},
2786 // TODO 2D array
2787 // 3D
2788 {"_Z12write_imageh14ocl_image3d_woDv4_iDv4_Dh",
2789 "_Z12write_imagef14ocl_image3d_woDv4_iDv4_f"}};
2790
2791 for (auto Pair : Map) {
2792 // If we find a function with the matching name.
2793 if (auto F = M.getFunction(Pair.first)) {
2794 SmallVector<Instruction *, 4> ToRemoves;
2795
2796 // Walk the users of the function.
2797 for (auto &U : F->uses()) {
2798 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2799 SmallVector<Type *, 3> types(3);
2800 SmallVector<Value *, 3> args(3);
2801
2802 // Image
2803 types[0] = CI->getArgOperand(0)->getType();
2804 args[0] = CI->getArgOperand(0);
2805
2806 // Coord
2807 types[1] = CI->getArgOperand(1)->getType();
2808 args[1] = CI->getArgOperand(1);
2809
2810 // Data
2811 types[2] = VectorType::get(
2812 Type::getFloatTy(M.getContext()),
2813 CI->getArgOperand(2)->getType()->getVectorNumElements());
2814
2815 auto NewFType =
2816 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
2817
2818 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2819
2820 // Convert data to the float type.
2821 auto Cast =
2822 CastInst::CreateFPCast(CI->getArgOperand(2), types[2], "", CI);
2823 args[2] = Cast;
2824
2825 auto NewCI = CallInst::Create(NewF, args, "", CI);
2826
2827 // Lastly, remember to remove the user.
2828 ToRemoves.push_back(CI);
2829 }
2830 }
2831
2832 Changed = !ToRemoves.empty();
2833
2834 // And cleanup the calls we don't use anymore.
2835 for (auto V : ToRemoves) {
2836 V->eraseFromParent();
2837 }
2838
2839 // And remove the function we don't need either too.
2840 F->eraseFromParent();
2841 }
2842 }
2843
2844 return Changed;
2845}
2846
alan-baker931d18a2019-12-12 08:21:32 -05002847bool ReplaceOpenCLBuiltinPass::replaceUnsampledReadImage(Module &M) {
2848 bool Changed = false;
2849 const std::map<const char *, const char *> Map = {
2850 // 1D
2851 {"_Z11read_imagef14ocl_image1d_roi",
2852 "_Z11read_imagef14ocl_image1d_ro11ocl_sampleri"},
2853 {"_Z11read_imagei14ocl_image1d_roi",
2854 "_Z11read_imagei14ocl_image1d_ro11ocl_sampleri"},
2855 {"_Z12read_imageui14ocl_image1d_roi",
2856 "_Z12read_imageui14ocl_image1d_ro11ocl_sampleri"},
2857 // TODO 1D array
2858 // 2D
2859 {"_Z11read_imagef14ocl_image2d_roDv2_i",
2860 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i"},
2861 {"_Z11read_imagei14ocl_image2d_roDv2_i",
2862 "_Z11read_imagei14ocl_image2d_ro11ocl_samplerDv2_i"},
2863 {"_Z12read_imageui14ocl_image2d_roDv2_i",
2864 "_Z12read_imageui14ocl_image2d_ro11ocl_samplerDv2_i"},
2865 // TODO 2D array
2866 // 3D
2867 {"_Z11read_imagef14ocl_image3d_roDv4_i",
2868 "_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_i"},
2869 {"_Z11read_imagei14ocl_image3d_roDv4_i",
2870 "_Z11read_imagei14ocl_image3d_ro11ocl_samplerDv4_i"},
2871 {"_Z12read_imageui14ocl_image3d_roDv4_i",
2872 "_Z12read_imageui14ocl_image3d_ro11ocl_samplerDv4_i"}};
2873
2874 Function *translate_sampler =
2875 M.getFunction(clspv::TranslateSamplerInitializerFunction());
2876 Type *sampler_type = M.getTypeByName("opencl.sampler_t");
alan-bakerf7e17cb2020-01-02 07:29:59 -05002877 if (sampler_type) {
2878 sampler_type = sampler_type->getPointerTo(clspv::AddressSpace::Constant);
2879 }
alan-baker931d18a2019-12-12 08:21:32 -05002880 for (auto Pair : Map) {
2881 // If we find a function with the matching name.
2882 if (auto F = M.getFunction(Pair.first)) {
2883 SmallVector<Instruction *, 4> ToRemoves;
2884
2885 // Walk the users of the function.
2886 for (auto &U : F->uses()) {
2887 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2888 // The image.
2889 auto Image = CI->getOperand(0);
2890
2891 // The coordinate.
2892 auto Coord = CI->getOperand(1);
2893
2894 // Create the sampler translation function if necessary.
2895 if (!translate_sampler) {
2896 // Create the sampler type if necessary.
2897 if (!sampler_type) {
2898 sampler_type =
2899 StructType::create(M.getContext(), "opencl.sampler_t");
2900 sampler_type =
2901 sampler_type->getPointerTo(clspv::AddressSpace::Constant);
2902 }
2903 auto fn_type = FunctionType::get(
2904 sampler_type, {Type::getInt32Ty(M.getContext())}, false);
2905 auto callee = M.getOrInsertFunction(
2906 clspv::TranslateSamplerInitializerFunction(), fn_type);
2907 translate_sampler = cast<Function>(callee.getCallee());
2908 }
2909
2910 auto NewFType = FunctionType::get(
2911 CI->getType(), {Image->getType(), sampler_type, Coord->getType()},
2912 false);
2913
2914 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2915
James Pricec05f6052020-01-14 13:37:20 -05002916 const uint64_t data_mask =
2917 clspv::version0::CLK_ADDRESS_NONE |
2918 clspv::version0::CLK_FILTER_NEAREST |
2919 clspv::version0::CLK_NORMALIZED_COORDS_FALSE;
alan-baker931d18a2019-12-12 08:21:32 -05002920 auto NewSamplerCI = CallInst::Create(
2921 translate_sampler,
2922 {ConstantInt::get(Type::getInt32Ty(M.getContext()), data_mask)},
2923 "", CI);
2924 auto NewCI =
2925 CallInst::Create(NewF, {Image, NewSamplerCI, Coord}, "", CI);
2926
2927 CI->replaceAllUsesWith(NewCI);
2928
2929 // Lastly, remember to remove the user.
2930 ToRemoves.push_back(CI);
2931 }
2932 }
2933
2934 Changed = !ToRemoves.empty();
2935
2936 // And cleanup the calls we don't use anymore.
2937 for (auto V : ToRemoves) {
2938 V->eraseFromParent();
2939 }
2940
2941 // And remove the function we don't need either too.
2942 F->eraseFromParent();
2943 }
2944 }
2945
2946 return Changed;
2947}
2948
Kévin Petit06517a12019-12-09 19:40:31 +00002949bool ReplaceOpenCLBuiltinPass::replaceSampledReadImageWithIntCoords(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002950 bool Changed = false;
2951
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002952 const std::map<const char *, const char *> Map = {
alan-bakerf906d2b2019-12-10 11:26:23 -05002953 // 1D
2954 {"_Z11read_imagei14ocl_image1d_ro11ocl_sampleri",
2955 "_Z11read_imagei14ocl_image1d_ro11ocl_samplerf"},
2956 {"_Z12read_imageui14ocl_image1d_ro11ocl_sampleri",
2957 "_Z12read_imageui14ocl_image1d_ro11ocl_samplerf"},
2958 {"_Z11read_imagef14ocl_image1d_ro11ocl_sampleri",
2959 "_Z11read_imagef14ocl_image1d_ro11ocl_samplerf"},
2960 // TODO 1Darray
Kévin Petit06517a12019-12-09 19:40:31 +00002961 // 2D
2962 {"_Z11read_imagei14ocl_image2d_ro11ocl_samplerDv2_i",
2963 "_Z11read_imagei14ocl_image2d_ro11ocl_samplerDv2_f"},
2964 {"_Z12read_imageui14ocl_image2d_ro11ocl_samplerDv2_i",
2965 "_Z12read_imageui14ocl_image2d_ro11ocl_samplerDv2_f"},
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002966 {"_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i",
2967 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f"},
Kévin Petit06517a12019-12-09 19:40:31 +00002968 // TODO 2D array
2969 // 3D
2970 {"_Z11read_imagei14ocl_image3d_ro11ocl_samplerDv4_i",
2971 "_Z11read_imagei14ocl_image3d_ro11ocl_samplerDv4_f"},
2972 {"_Z12read_imageui14ocl_image3d_ro11ocl_samplerDv4_i",
2973 "_Z12read_imageui14ocl_image3d_ro11ocl_samplerDv4_f"},
2974 {"_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_i",
2975 "_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_f"}};
David Neto22f144c2017-06-12 14:26:21 -04002976
2977 for (auto Pair : Map) {
2978 // If we find a function with the matching name.
2979 if (auto F = M.getFunction(Pair.first)) {
2980 SmallVector<Instruction *, 4> ToRemoves;
2981
2982 // Walk the users of the function.
2983 for (auto &U : F->uses()) {
2984 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2985 // The image.
2986 auto Arg0 = CI->getOperand(0);
2987
2988 // The sampler.
2989 auto Arg1 = CI->getOperand(1);
2990
2991 // The coordinate (integer type that we can't handle).
2992 auto Arg2 = CI->getOperand(2);
2993
alan-bakerf906d2b2019-12-10 11:26:23 -05002994 uint32_t dim = clspv::ImageDimensionality(Arg0->getType());
2995 // TODO(alan-baker): when arrayed images are supported fix component
2996 // calculation.
2997 uint32_t components = dim;
2998 Type *float_ty = nullptr;
2999 if (components == 1) {
3000 float_ty = Type::getFloatTy(M.getContext());
3001 } else {
3002 float_ty = VectorType::get(Type::getFloatTy(M.getContext()),
3003 Arg2->getType()->getVectorNumElements());
3004 }
David Neto22f144c2017-06-12 14:26:21 -04003005
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003006 auto NewFType = FunctionType::get(
alan-bakerf906d2b2019-12-10 11:26:23 -05003007 CI->getType(), {Arg0->getType(), Arg1->getType(), float_ty},
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003008 false);
David Neto22f144c2017-06-12 14:26:21 -04003009
3010 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
3011
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003012 auto Cast =
alan-bakerf906d2b2019-12-10 11:26:23 -05003013 CastInst::Create(Instruction::SIToFP, Arg2, float_ty, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04003014
3015 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
3016
3017 CI->replaceAllUsesWith(NewCI);
3018
3019 // Lastly, remember to remove the user.
3020 ToRemoves.push_back(CI);
3021 }
3022 }
3023
3024 Changed = !ToRemoves.empty();
3025
3026 // And cleanup the calls we don't use anymore.
3027 for (auto V : ToRemoves) {
3028 V->eraseFromParent();
3029 }
3030
3031 // And remove the function we don't need either too.
3032 F->eraseFromParent();
3033 }
3034 }
3035
3036 return Changed;
3037}
3038
3039bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
3040 bool Changed = false;
3041
Kévin Petit9b340262019-06-19 18:31:11 +01003042 const std::map<const char *, spv::Op> Map = {
3043 {"_Z8atom_incPU3AS1Vi", spv::OpAtomicIIncrement},
3044 {"_Z8atom_incPU3AS3Vi", spv::OpAtomicIIncrement},
3045 {"_Z8atom_incPU3AS1Vj", spv::OpAtomicIIncrement},
3046 {"_Z8atom_incPU3AS3Vj", spv::OpAtomicIIncrement},
3047 {"_Z8atom_decPU3AS1Vi", spv::OpAtomicIDecrement},
3048 {"_Z8atom_decPU3AS3Vi", spv::OpAtomicIDecrement},
3049 {"_Z8atom_decPU3AS1Vj", spv::OpAtomicIDecrement},
3050 {"_Z8atom_decPU3AS3Vj", spv::OpAtomicIDecrement},
3051 {"_Z12atom_cmpxchgPU3AS1Viii", spv::OpAtomicCompareExchange},
3052 {"_Z12atom_cmpxchgPU3AS3Viii", spv::OpAtomicCompareExchange},
3053 {"_Z12atom_cmpxchgPU3AS1Vjjj", spv::OpAtomicCompareExchange},
3054 {"_Z12atom_cmpxchgPU3AS3Vjjj", spv::OpAtomicCompareExchange},
3055 {"_Z10atomic_incPU3AS1Vi", spv::OpAtomicIIncrement},
3056 {"_Z10atomic_incPU3AS3Vi", spv::OpAtomicIIncrement},
3057 {"_Z10atomic_incPU3AS1Vj", spv::OpAtomicIIncrement},
3058 {"_Z10atomic_incPU3AS3Vj", spv::OpAtomicIIncrement},
3059 {"_Z10atomic_decPU3AS1Vi", spv::OpAtomicIDecrement},
3060 {"_Z10atomic_decPU3AS3Vi", spv::OpAtomicIDecrement},
3061 {"_Z10atomic_decPU3AS1Vj", spv::OpAtomicIDecrement},
3062 {"_Z10atomic_decPU3AS3Vj", spv::OpAtomicIDecrement},
3063 {"_Z14atomic_cmpxchgPU3AS1Viii", spv::OpAtomicCompareExchange},
3064 {"_Z14atomic_cmpxchgPU3AS3Viii", spv::OpAtomicCompareExchange},
3065 {"_Z14atomic_cmpxchgPU3AS1Vjjj", spv::OpAtomicCompareExchange},
3066 {"_Z14atomic_cmpxchgPU3AS3Vjjj", spv::OpAtomicCompareExchange}};
David Neto22f144c2017-06-12 14:26:21 -04003067
3068 for (auto Pair : Map) {
3069 // If we find a function with the matching name.
3070 if (auto F = M.getFunction(Pair.first)) {
3071 SmallVector<Instruction *, 4> ToRemoves;
3072
3073 // Walk the users of the function.
3074 for (auto &U : F->uses()) {
3075 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -04003076
3077 auto IntTy = Type::getInt32Ty(M.getContext());
3078
David Neto22f144c2017-06-12 14:26:21 -04003079 // We need to map the OpenCL constants to the SPIR-V equivalents.
3080 const auto ConstantScopeDevice =
3081 ConstantInt::get(IntTy, spv::ScopeDevice);
3082 const auto ConstantMemorySemantics = ConstantInt::get(
3083 IntTy, spv::MemorySemanticsUniformMemoryMask |
3084 spv::MemorySemanticsSequentiallyConsistentMask);
3085
3086 SmallVector<Value *, 5> Params;
3087
3088 // The pointer.
3089 Params.push_back(CI->getArgOperand(0));
3090
3091 // The memory scope.
3092 Params.push_back(ConstantScopeDevice);
3093
3094 // The memory semantics.
3095 Params.push_back(ConstantMemorySemantics);
3096
3097 if (2 < CI->getNumArgOperands()) {
3098 // The unequal memory semantics.
3099 Params.push_back(ConstantMemorySemantics);
3100
3101 // The value.
3102 Params.push_back(CI->getArgOperand(2));
3103
3104 // The comparator.
3105 Params.push_back(CI->getArgOperand(1));
3106 } else if (1 < CI->getNumArgOperands()) {
3107 // The value.
3108 Params.push_back(CI->getArgOperand(1));
3109 }
3110
Kévin Petit9b340262019-06-19 18:31:11 +01003111 auto NewCI =
3112 clspv::InsertSPIRVOp(CI, Pair.second, {}, CI->getType(), Params);
David Neto22f144c2017-06-12 14:26:21 -04003113
3114 CI->replaceAllUsesWith(NewCI);
3115
3116 // Lastly, remember to remove the user.
3117 ToRemoves.push_back(CI);
3118 }
3119 }
3120
3121 Changed = !ToRemoves.empty();
3122
3123 // And cleanup the calls we don't use anymore.
3124 for (auto V : ToRemoves) {
3125 V->eraseFromParent();
3126 }
3127
3128 // And remove the function we don't need either too.
3129 F->eraseFromParent();
3130 }
3131 }
3132
Neil Henning39672102017-09-29 14:33:13 +01003133 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003134 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00003135 {"_Z8atom_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003136 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00003137 {"_Z8atom_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003138 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00003139 {"_Z8atom_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003140 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00003141 {"_Z8atom_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003142 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00003143 {"_Z9atom_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003144 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00003145 {"_Z9atom_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003146 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00003147 {"_Z8atom_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003148 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00003149 {"_Z8atom_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003150 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00003151 {"_Z8atom_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003152 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00003153 {"_Z8atom_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003154 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00003155 {"_Z8atom_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003156 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00003157 {"_Z8atom_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003158 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00003159 {"_Z7atom_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003160 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00003161 {"_Z7atom_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003162 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00003163 {"_Z8atom_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00003164 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00003165 {"_Z8atom_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01003166 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00003167 {"_Z10atomic_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01003168 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00003169 {"_Z10atomic_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01003170 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00003171 {"_Z10atomic_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01003172 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00003173 {"_Z10atomic_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01003174 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00003175 {"_Z11atomic_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01003176 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00003177 {"_Z11atomic_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01003178 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00003179 {"_Z10atomic_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Neil Henning39672102017-09-29 14:33:13 +01003180 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00003181 {"_Z10atomic_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Neil Henning39672102017-09-29 14:33:13 +01003182 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00003183 {"_Z10atomic_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Neil Henning39672102017-09-29 14:33:13 +01003184 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00003185 {"_Z10atomic_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Neil Henning39672102017-09-29 14:33:13 +01003186 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00003187 {"_Z10atomic_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01003188 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00003189 {"_Z10atomic_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01003190 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00003191 {"_Z9atomic_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01003192 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00003193 {"_Z9atomic_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01003194 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00003195 {"_Z10atomic_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
3196 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
3197 {"_Z10atomic_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor}};
Neil Henning39672102017-09-29 14:33:13 +01003198
3199 for (auto Pair : Map2) {
3200 // If we find a function with the matching name.
3201 if (auto F = M.getFunction(Pair.first)) {
3202 SmallVector<Instruction *, 4> ToRemoves;
3203
3204 // Walk the users of the function.
3205 for (auto &U : F->uses()) {
3206 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3207 auto AtomicOp = new AtomicRMWInst(
3208 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
3209 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
3210
3211 CI->replaceAllUsesWith(AtomicOp);
3212
3213 // Lastly, remember to remove the user.
3214 ToRemoves.push_back(CI);
3215 }
3216 }
3217
3218 Changed = !ToRemoves.empty();
3219
3220 // And cleanup the calls we don't use anymore.
3221 for (auto V : ToRemoves) {
3222 V->eraseFromParent();
3223 }
3224
3225 // And remove the function we don't need either too.
3226 F->eraseFromParent();
3227 }
3228 }
3229
David Neto22f144c2017-06-12 14:26:21 -04003230 return Changed;
3231}
3232
3233bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04003234
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003235 std::vector<const char *> Names = {
3236 "_Z5crossDv4_fS_",
Kévin Petite8edce32019-04-10 14:23:32 +01003237 };
3238
3239 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
David Neto22f144c2017-06-12 14:26:21 -04003240 auto IntTy = Type::getInt32Ty(M.getContext());
3241 auto FloatTy = Type::getFloatTy(M.getContext());
3242
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003243 Constant *DownShuffleMask[3] = {ConstantInt::get(IntTy, 0),
3244 ConstantInt::get(IntTy, 1),
3245 ConstantInt::get(IntTy, 2)};
David Neto22f144c2017-06-12 14:26:21 -04003246
3247 Constant *UpShuffleMask[4] = {
3248 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
3249 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
3250
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003251 Constant *FloatVec[3] = {ConstantFP::get(FloatTy, 0.0f),
3252 UndefValue::get(FloatTy),
3253 UndefValue::get(FloatTy)};
David Neto22f144c2017-06-12 14:26:21 -04003254
Kévin Petite8edce32019-04-10 14:23:32 +01003255 auto Vec4Ty = CI->getArgOperand(0)->getType();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003256 auto Arg0 =
3257 new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty),
3258 ConstantVector::get(DownShuffleMask), "", CI);
3259 auto Arg1 =
3260 new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty),
3261 ConstantVector::get(DownShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01003262 auto Vec3Ty = Arg0->getType();
David Neto22f144c2017-06-12 14:26:21 -04003263
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003264 auto NewFType = FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
David Neto22f144c2017-06-12 14:26:21 -04003265
Kévin Petite8edce32019-04-10 14:23:32 +01003266 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
David Neto22f144c2017-06-12 14:26:21 -04003267
Kévin Petite8edce32019-04-10 14:23:32 +01003268 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04003269
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003270 return new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec),
3271 ConstantVector::get(UpShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01003272 });
David Neto22f144c2017-06-12 14:26:21 -04003273}
David Neto62653202017-10-16 19:05:18 -04003274
3275bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
3276 bool Changed = false;
3277
3278 // OpenCL's float result = fract(float x, float* ptr)
3279 //
3280 // In the LLVM domain:
3281 //
3282 // %floor_result = call spir_func float @floor(float %x)
3283 // store float %floor_result, float * %ptr
3284 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
3285 // %result = call spir_func float
3286 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
3287 //
3288 // Becomes in the SPIR-V domain, where translations of floor, fmin,
3289 // and clspv.fract occur in the SPIR-V generator pass:
3290 //
3291 // %glsl_ext = OpExtInstImport "GLSL.std.450"
3292 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
3293 // ...
3294 // %floor_result = OpExtInst %float %glsl_ext Floor %x
3295 // OpStore %ptr %floor_result
3296 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
3297 // %fract_result = OpExtInst %float
3298 // %glsl_ext Fmin %fract_intermediate %just_under_1
3299
David Neto62653202017-10-16 19:05:18 -04003300 using std::string;
3301
3302 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
3303 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003304 using QuadType =
3305 std::tuple<const char *, const char *, const char *, const char *>;
David Neto62653202017-10-16 19:05:18 -04003306 auto make_quad = [](const char *a, const char *b, const char *c,
3307 const char *d) {
3308 return std::tuple<const char *, const char *, const char *, const char *>(
3309 a, b, c, d);
3310 };
3311 const std::vector<QuadType> Functions = {
3312 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003313 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff",
3314 "clspv.fract.v2f"),
3315 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff",
3316 "clspv.fract.v3f"),
3317 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff",
3318 "clspv.fract.v4f"),
David Neto62653202017-10-16 19:05:18 -04003319 };
3320
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003321 for (auto &quad : Functions) {
David Neto62653202017-10-16 19:05:18 -04003322 const StringRef fract_name(std::get<0>(quad));
3323
3324 // If we find a function with the matching name.
3325 if (auto F = M.getFunction(fract_name)) {
3326 if (F->use_begin() == F->use_end())
3327 continue;
3328
3329 // We have some uses.
3330 Changed = true;
3331
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003332 auto &Context = M.getContext();
David Neto62653202017-10-16 19:05:18 -04003333
3334 const StringRef floor_name(std::get<1>(quad));
3335 const StringRef fmin_name(std::get<2>(quad));
3336 const StringRef clspv_fract_name(std::get<3>(quad));
3337
3338 // This is either float or a float vector. All the float-like
3339 // types are this type.
3340 auto result_ty = F->getReturnType();
3341
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003342 Function *fmin_fn = M.getFunction(fmin_name);
David Neto62653202017-10-16 19:05:18 -04003343 if (!fmin_fn) {
3344 // Make the fmin function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003345 FunctionType *fn_ty =
3346 FunctionType::get(result_ty, {result_ty, result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003347 fmin_fn =
3348 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003349 fmin_fn->addFnAttr(Attribute::ReadNone);
3350 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
3351 }
3352
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003353 Function *floor_fn = M.getFunction(floor_name);
David Neto62653202017-10-16 19:05:18 -04003354 if (!floor_fn) {
3355 // Make the floor function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003356 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003357 floor_fn = cast<Function>(
3358 M.getOrInsertFunction(floor_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003359 floor_fn->addFnAttr(Attribute::ReadNone);
3360 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
3361 }
3362
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003363 Function *clspv_fract_fn = M.getFunction(clspv_fract_name);
David Neto62653202017-10-16 19:05:18 -04003364 if (!clspv_fract_fn) {
3365 // Make the clspv_fract function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003366 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003367 clspv_fract_fn = cast<Function>(
3368 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003369 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
3370 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
3371 }
3372
3373 // Number of significant significand bits, whether represented or not.
3374 unsigned num_significand_bits;
3375 switch (result_ty->getScalarType()->getTypeID()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003376 case Type::HalfTyID:
3377 num_significand_bits = 11;
3378 break;
3379 case Type::FloatTyID:
3380 num_significand_bits = 24;
3381 break;
3382 case Type::DoubleTyID:
3383 num_significand_bits = 53;
3384 break;
3385 default:
3386 assert(false && "Unhandled float type when processing fract builtin");
3387 break;
David Neto62653202017-10-16 19:05:18 -04003388 }
3389 // Beware that the disassembler displays this value as
3390 // OpConstant %float 1
3391 // which is not quite right.
3392 const double kJustUnderOneScalar =
3393 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3394
3395 Constant *just_under_one =
3396 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3397 if (result_ty->isVectorTy()) {
3398 just_under_one = ConstantVector::getSplat(
3399 result_ty->getVectorNumElements(), just_under_one);
3400 }
3401
3402 IRBuilder<> Builder(Context);
3403
3404 SmallVector<Instruction *, 4> ToRemoves;
3405
3406 // Walk the users of the function.
3407 for (auto &U : F->uses()) {
3408 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3409
3410 Builder.SetInsertPoint(CI);
3411 auto arg = CI->getArgOperand(0);
3412 auto ptr = CI->getArgOperand(1);
3413
3414 // Compute floor result and store it.
3415 auto floor = Builder.CreateCall(floor_fn, {arg});
3416 Builder.CreateStore(floor, ptr);
3417
3418 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003419 auto fract_result =
3420 Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
David Neto62653202017-10-16 19:05:18 -04003421
3422 CI->replaceAllUsesWith(fract_result);
3423
3424 // Lastly, remember to remove the user.
3425 ToRemoves.push_back(CI);
3426 }
3427 }
3428
3429 // And cleanup the calls we don't use anymore.
3430 for (auto V : ToRemoves) {
3431 V->eraseFromParent();
3432 }
3433
3434 // And remove the function we don't need either too.
3435 F->eraseFromParent();
3436 }
3437 }
3438
3439 return Changed;
3440}