blob: ab9d09dfccc7d20d2444a5d53a1cd7c7c62760b8 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
David Neto118188e2018-08-24 11:27:54 -040021#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040022#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040023#include "llvm/IR/Module.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000024#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040025#include "llvm/Pass.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040029
David Neto118188e2018-08-24 11:27:54 -040030#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040031
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040032#include "clspv/Option.h"
David Neto482550a2018-03-24 05:21:07 -070033
Diego Novilloa4c44fa2019-04-11 10:56:15 -040034#include "Passes.h"
35#include "SPIRVOp.h"
36
David Neto22f144c2017-06-12 14:26:21 -040037using namespace llvm;
38
39#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
40
41namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000042
// Signedness recorded for a single parameter of a mangled builtin name.
struct ArgTypeInfo {
  // None is what the name parser assigns to 'f' (float) parameters; Signed
  // and Unsigned are assigned to the integer type codes.
  enum class SignedNess { None, Unsigned, Signed };

  // Defaults to None so that a default-constructed ArgTypeInfo never carries
  // an indeterminate value.
  SignedNess signedness = SignedNess::None;
};
47
48struct FunctionInfo {
Kévin Petit9d1a9d12019-03-25 15:23:46 +000049 StringRef name;
Kévin Petit8a560882019-03-21 15:24:34 +000050 std::vector<ArgTypeInfo> argTypeInfos;
Kévin Petit8a560882019-03-21 15:24:34 +000051
Kévin Petit91bc72e2019-04-08 15:17:46 +010052 bool isArgSigned(size_t arg) const {
53 assert(argTypeInfos.size() > arg);
54 return argTypeInfos[arg].signedness == ArgTypeInfo::SignedNess::Signed;
Kévin Petit8a560882019-03-21 15:24:34 +000055 }
56
Kévin Petit91bc72e2019-04-08 15:17:46 +010057 static FunctionInfo getFromMangledName(StringRef name) {
58 FunctionInfo fi;
59 if (!getFromMangledNameCheck(name, &fi)) {
60 llvm_unreachable("Can't parse mangled function name!");
Kévin Petit8a560882019-03-21 15:24:34 +000061 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010062 return fi;
63 }
Kévin Petit8a560882019-03-21 15:24:34 +000064
Kévin Petit91bc72e2019-04-08 15:17:46 +010065 static bool getFromMangledNameCheck(StringRef name, FunctionInfo *finfo) {
66 if (!name.consume_front("_Z")) {
67 return false;
68 }
69 size_t nameLen;
70 if (name.consumeInteger(10, nameLen)) {
Kévin Petit8a560882019-03-21 15:24:34 +000071 return false;
72 }
73
Kévin Petit91bc72e2019-04-08 15:17:46 +010074 finfo->name = name.take_front(nameLen);
75 name = name.drop_front(nameLen);
Kévin Petit8a560882019-03-21 15:24:34 +000076
Kévin Petit91bc72e2019-04-08 15:17:46 +010077 ArgTypeInfo prev_ti;
Kévin Petit8a560882019-03-21 15:24:34 +000078
Kévin Petit91bc72e2019-04-08 15:17:46 +010079 while (name.size() != 0) {
80
81 ArgTypeInfo ti;
82
83 // Try parsing a vector prefix
84 if (name.consume_front("Dv")) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040085 int numElems;
86 if (name.consumeInteger(10, numElems)) {
87 return false;
88 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010089
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040090 if (!name.consume_front("_")) {
91 return false;
92 }
Kévin Petit91bc72e2019-04-08 15:17:46 +010093 }
94
95 // Parse the base type
96 char typeCode = name.front();
97 name = name.drop_front(1);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040098 switch (typeCode) {
Kévin Petit91bc72e2019-04-08 15:17:46 +010099 case 'c': // char
100 case 'a': // signed char
101 case 's': // short
102 case 'i': // int
103 case 'l': // long
104 ti.signedness = ArgTypeInfo::SignedNess::Signed;
105 break;
106 case 'h': // unsigned char
107 case 't': // unsigned short
108 case 'j': // unsigned int
109 case 'm': // unsigned long
110 ti.signedness = ArgTypeInfo::SignedNess::Unsigned;
111 break;
112 case 'f':
113 ti.signedness = ArgTypeInfo::SignedNess::None;
114 break;
115 case 'S':
116 ti = prev_ti;
117 if (!name.consume_front("_")) {
118 return false;
119 }
120 break;
121 default:
122 return false;
123 }
124
125 finfo->argTypeInfos.push_back(ti);
126
127 prev_ti = ti;
128 }
129
130 return true;
131 };
Kévin Petit8a560882019-03-21 15:24:34 +0000132};
133
// Returns the bit index of the most significant set bit of `v`, i.e.
// floor(log2(v)); returns 0 when `v` is 0. Note that despite its name this is
// not a count of leading zeros: the callers in this file only ever use the
// difference between two results to obtain a shift distance.
uint32_t clz(uint32_t v) {
  uint32_t index = 0;

  // Binary search over the bit position: test the upper half of the remaining
  // window (16, 8, 4, 2, then 1 bits wide) and descend into it when it holds
  // a set bit.
  for (uint32_t width = 16; width != 0; width >>= 1) {
    if ((v >> width) != 0) {
      v >>= width;
      index |= width;
    }
  }

  return index;
}
153
154Type *getBoolOrBoolVectorTy(LLVMContext &C, unsigned elements) {
155 if (1 == elements) {
156 return Type::getInt1Ty(C);
157 } else {
158 return VectorType::get(Type::getInt1Ty(C), elements);
159 }
160}
161
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100162Type *getIntOrIntVectorTyForCast(LLVMContext &C, Type *Ty) {
163 Type *IntTy = Type::getIntNTy(C, Ty->getScalarSizeInBits());
164 if (Ty->isVectorTy()) {
165 IntTy = VectorType::get(IntTy, Ty->getVectorNumElements());
166 }
167 return IntTy;
168}
169
David Neto22f144c2017-06-12 14:26:21 -0400170struct ReplaceOpenCLBuiltinPass final : public ModulePass {
171 static char ID;
172 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
173
174 bool runOnModule(Module &M) override;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000175 bool replaceAbs(Module &M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100176 bool replaceAbsDiff(Module &M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100177 bool replaceCopysign(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400178 bool replaceRecip(Module &M);
179 bool replaceDivide(Module &M);
Kévin Petit1329a002019-06-15 05:54:05 +0100180 bool replaceDot(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400181 bool replaceExp10(Module &M);
Kévin Petit0644a9c2019-06-20 21:08:46 +0100182 bool replaceFmod(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400183 bool replaceLog10(Module &M);
184 bool replaceBarrier(Module &M);
185 bool replaceMemFence(Module &M);
186 bool replaceRelational(Module &M);
187 bool replaceIsInfAndIsNan(Module &M);
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100188 bool replaceIsFinite(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400189 bool replaceAllAndAny(Module &M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000190 bool replaceUpsample(Module &M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000191 bool replaceRotate(Module &M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000192 bool replaceConvert(Module &M);
Kévin Petit8a560882019-03-21 15:24:34 +0000193 bool replaceMulHiMadHi(Module &M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000194 bool replaceSelect(Module &M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000195 bool replaceBitSelect(Module &M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000196 bool replaceStepSmoothStep(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400197 bool replaceSignbit(Module &M);
198 bool replaceMadandMad24andMul24(Module &M);
199 bool replaceVloadHalf(Module &M);
200 bool replaceVloadHalf2(Module &M);
201 bool replaceVloadHalf4(Module &M);
David Neto6ad93232018-06-07 15:42:58 -0700202 bool replaceClspvVloadaHalf2(Module &M);
203 bool replaceClspvVloadaHalf4(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400204 bool replaceVstoreHalf(Module &M);
205 bool replaceVstoreHalf2(Module &M);
206 bool replaceVstoreHalf4(Module &M);
Kévin Petit06517a12019-12-09 19:40:31 +0000207 bool replaceSampledReadImageWithIntCoords(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400208 bool replaceAtomics(Module &M);
209 bool replaceCross(Module &M);
David Neto62653202017-10-16 19:05:18 -0400210 bool replaceFract(Module &M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700211 bool replaceVload(Module &M);
212 bool replaceVstore(Module &M);
David Neto22f144c2017-06-12 14:26:21 -0400213};
Kévin Petit91bc72e2019-04-08 15:17:46 +0100214} // namespace
David Neto22f144c2017-06-12 14:26:21 -0400215
216char ReplaceOpenCLBuiltinPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400217INITIALIZE_PASS(ReplaceOpenCLBuiltinPass, "ReplaceOpenCLBuiltin",
218 "Replace OpenCL Builtins Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -0400219
220namespace clspv {
221ModulePass *createReplaceOpenCLBuiltinPass() {
222 return new ReplaceOpenCLBuiltinPass();
223}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400224} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -0400225
226bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
227 bool Changed = false;
228
Kévin Petit2444e9b2018-11-09 14:14:37 +0000229 Changed |= replaceAbs(M);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100230 Changed |= replaceAbsDiff(M);
Kévin Petit8c1be282019-04-02 19:34:25 +0100231 Changed |= replaceCopysign(M);
David Neto22f144c2017-06-12 14:26:21 -0400232 Changed |= replaceRecip(M);
233 Changed |= replaceDivide(M);
Kévin Petit1329a002019-06-15 05:54:05 +0100234 Changed |= replaceDot(M);
David Neto22f144c2017-06-12 14:26:21 -0400235 Changed |= replaceExp10(M);
Kévin Petit0644a9c2019-06-20 21:08:46 +0100236 Changed |= replaceFmod(M);
David Neto22f144c2017-06-12 14:26:21 -0400237 Changed |= replaceLog10(M);
238 Changed |= replaceBarrier(M);
239 Changed |= replaceMemFence(M);
240 Changed |= replaceRelational(M);
241 Changed |= replaceIsInfAndIsNan(M);
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100242 Changed |= replaceIsFinite(M);
David Neto22f144c2017-06-12 14:26:21 -0400243 Changed |= replaceAllAndAny(M);
Kévin Petitbf0036c2019-03-06 13:57:10 +0000244 Changed |= replaceUpsample(M);
Kévin Petitd44eef52019-03-08 13:22:14 +0000245 Changed |= replaceRotate(M);
Kévin Petit9d1a9d12019-03-25 15:23:46 +0000246 Changed |= replaceConvert(M);
Kévin Petit8a560882019-03-21 15:24:34 +0000247 Changed |= replaceMulHiMadHi(M);
Kévin Petitf5b78a22018-10-25 14:32:17 +0000248 Changed |= replaceSelect(M);
Kévin Petite7d0cce2018-10-31 12:38:56 +0000249 Changed |= replaceBitSelect(M);
Kévin Petit6b0a9532018-10-30 20:00:39 +0000250 Changed |= replaceStepSmoothStep(M);
David Neto22f144c2017-06-12 14:26:21 -0400251 Changed |= replaceSignbit(M);
252 Changed |= replaceMadandMad24andMul24(M);
253 Changed |= replaceVloadHalf(M);
254 Changed |= replaceVloadHalf2(M);
255 Changed |= replaceVloadHalf4(M);
David Neto6ad93232018-06-07 15:42:58 -0700256 Changed |= replaceClspvVloadaHalf2(M);
257 Changed |= replaceClspvVloadaHalf4(M);
David Neto22f144c2017-06-12 14:26:21 -0400258 Changed |= replaceVstoreHalf(M);
259 Changed |= replaceVstoreHalf2(M);
260 Changed |= replaceVstoreHalf4(M);
Kévin Petit06517a12019-12-09 19:40:31 +0000261 Changed |= replaceSampledReadImageWithIntCoords(M);
David Neto22f144c2017-06-12 14:26:21 -0400262 Changed |= replaceAtomics(M);
263 Changed |= replaceCross(M);
David Neto62653202017-10-16 19:05:18 -0400264 Changed |= replaceFract(M);
Derek Chowcfd368b2017-10-19 20:58:45 -0700265 Changed |= replaceVload(M);
266 Changed |= replaceVstore(M);
David Neto22f144c2017-06-12 14:26:21 -0400267
268 return Changed;
269}
270
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400271bool replaceCallsWithValue(Module &M, std::vector<const char *> Names,
272 std::function<Value *(CallInst *)> Replacer) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000273
Kévin Petite8edce32019-04-10 14:23:32 +0100274 bool Changed = false;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000275
276 for (auto Name : Names) {
277 // If we find a function with the matching name.
278 if (auto F = M.getFunction(Name)) {
279 SmallVector<Instruction *, 4> ToRemoves;
280
281 // Walk the users of the function.
282 for (auto &U : F->uses()) {
283 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
Kévin Petit2444e9b2018-11-09 14:14:37 +0000284
Kévin Petite8edce32019-04-10 14:23:32 +0100285 auto NewValue = Replacer(CI);
286
287 if (NewValue != nullptr) {
288 CI->replaceAllUsesWith(NewValue);
289 }
Kévin Petit2444e9b2018-11-09 14:14:37 +0000290
291 // Lastly, remember to remove the user.
292 ToRemoves.push_back(CI);
293 }
294 }
295
296 Changed = !ToRemoves.empty();
297
298 // And cleanup the calls we don't use anymore.
299 for (auto V : ToRemoves) {
300 V->eraseFromParent();
301 }
302
303 // And remove the function we don't need either too.
304 F->eraseFromParent();
305 }
306 }
307
308 return Changed;
309}
310
Kévin Petite8edce32019-04-10 14:23:32 +0100311bool ReplaceOpenCLBuiltinPass::replaceAbs(Module &M) {
Kévin Petit91bc72e2019-04-08 15:17:46 +0100312
Kévin Petite8edce32019-04-10 14:23:32 +0100313 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400314 "_Z3absh", "_Z3absDv2_h", "_Z3absDv3_h", "_Z3absDv4_h",
315 "_Z3abst", "_Z3absDv2_t", "_Z3absDv3_t", "_Z3absDv4_t",
316 "_Z3absj", "_Z3absDv2_j", "_Z3absDv3_j", "_Z3absDv4_j",
317 "_Z3absm", "_Z3absDv2_m", "_Z3absDv3_m", "_Z3absDv4_m",
Kévin Petite8edce32019-04-10 14:23:32 +0100318 };
319
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400320 return replaceCallsWithValue(M, Names,
321 [](CallInst *CI) { return CI->getOperand(0); });
Kévin Petite8edce32019-04-10 14:23:32 +0100322}
323
324bool ReplaceOpenCLBuiltinPass::replaceAbsDiff(Module &M) {
325
326 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400327 "_Z8abs_diffcc", "_Z8abs_diffDv2_cS_", "_Z8abs_diffDv3_cS_",
328 "_Z8abs_diffDv4_cS_", "_Z8abs_diffhh", "_Z8abs_diffDv2_hS_",
329 "_Z8abs_diffDv3_hS_", "_Z8abs_diffDv4_hS_", "_Z8abs_diffss",
330 "_Z8abs_diffDv2_sS_", "_Z8abs_diffDv3_sS_", "_Z8abs_diffDv4_sS_",
331 "_Z8abs_difftt", "_Z8abs_diffDv2_tS_", "_Z8abs_diffDv3_tS_",
332 "_Z8abs_diffDv4_tS_", "_Z8abs_diffii", "_Z8abs_diffDv2_iS_",
333 "_Z8abs_diffDv3_iS_", "_Z8abs_diffDv4_iS_", "_Z8abs_diffjj",
334 "_Z8abs_diffDv2_jS_", "_Z8abs_diffDv3_jS_", "_Z8abs_diffDv4_jS_",
335 "_Z8abs_diffll", "_Z8abs_diffDv2_lS_", "_Z8abs_diffDv3_lS_",
336 "_Z8abs_diffDv4_lS_", "_Z8abs_diffmm", "_Z8abs_diffDv2_mS_",
337 "_Z8abs_diffDv3_mS_", "_Z8abs_diffDv4_mS_",
Kévin Petit91bc72e2019-04-08 15:17:46 +0100338 };
339
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400340 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100341 auto XValue = CI->getOperand(0);
342 auto YValue = CI->getOperand(1);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100343
Kévin Petite8edce32019-04-10 14:23:32 +0100344 IRBuilder<> Builder(CI);
345 auto XmY = Builder.CreateSub(XValue, YValue);
346 auto YmX = Builder.CreateSub(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100347
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400348 Value *Cmp;
Kévin Petite8edce32019-04-10 14:23:32 +0100349 auto F = CI->getCalledFunction();
350 auto finfo = FunctionInfo::getFromMangledName(F->getName());
351 if (finfo.isArgSigned(0)) {
352 Cmp = Builder.CreateICmpSGT(YValue, XValue);
353 } else {
354 Cmp = Builder.CreateICmpUGT(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100355 }
Kévin Petit91bc72e2019-04-08 15:17:46 +0100356
Kévin Petite8edce32019-04-10 14:23:32 +0100357 return Builder.CreateSelect(Cmp, YmX, XmY);
358 });
Kévin Petit91bc72e2019-04-08 15:17:46 +0100359}
360
Kévin Petit8c1be282019-04-02 19:34:25 +0100361bool ReplaceOpenCLBuiltinPass::replaceCopysign(Module &M) {
Kévin Petit8c1be282019-04-02 19:34:25 +0100362
Kévin Petite8edce32019-04-10 14:23:32 +0100363 std::vector<const char *> Names = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400364 "_Z8copysignff",
365 "_Z8copysignDv2_fS_",
366 "_Z8copysignDv3_fS_",
367 "_Z8copysignDv4_fS_",
Kévin Petit8c1be282019-04-02 19:34:25 +0100368 };
369
Kévin Petite8edce32019-04-10 14:23:32 +0100370 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
371 auto XValue = CI->getOperand(0);
372 auto YValue = CI->getOperand(1);
Kévin Petit8c1be282019-04-02 19:34:25 +0100373
Kévin Petite8edce32019-04-10 14:23:32 +0100374 auto Ty = XValue->getType();
Kévin Petit8c1be282019-04-02 19:34:25 +0100375
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400376 Type *IntTy = Type::getIntNTy(M.getContext(), Ty->getScalarSizeInBits());
Kévin Petite8edce32019-04-10 14:23:32 +0100377 if (Ty->isVectorTy()) {
378 IntTy = VectorType::get(IntTy, Ty->getVectorNumElements());
Kévin Petit8c1be282019-04-02 19:34:25 +0100379 }
Kévin Petit8c1be282019-04-02 19:34:25 +0100380
Kévin Petite8edce32019-04-10 14:23:32 +0100381 // Return X with the sign of Y
382
383 // Sign bit masks
384 auto SignBit = IntTy->getScalarSizeInBits() - 1;
385 auto SignBitMask = 1 << SignBit;
386 auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
387 auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);
388
389 IRBuilder<> Builder(CI);
390
391 // Extract sign of Y
392 auto YInt = Builder.CreateBitCast(YValue, IntTy);
393 auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);
394
395 // Clear sign bit in X
396 auto XInt = Builder.CreateBitCast(XValue, IntTy);
397 XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);
398
399 // Insert sign bit of Y into X
400 auto NewXInt = Builder.CreateOr(XInt, YSign);
401
402 // And cast back to floating-point
403 return Builder.CreateBitCast(NewXInt, Ty);
404 });
Kévin Petit8c1be282019-04-02 19:34:25 +0100405}
406
David Neto22f144c2017-06-12 14:26:21 -0400407bool ReplaceOpenCLBuiltinPass::replaceRecip(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400408
Kévin Petite8edce32019-04-10 14:23:32 +0100409 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400410 "_Z10half_recipf", "_Z12native_recipf", "_Z10half_recipDv2_f",
411 "_Z12native_recipDv2_f", "_Z10half_recipDv3_f", "_Z12native_recipDv3_f",
412 "_Z10half_recipDv4_f", "_Z12native_recipDv4_f",
413 };
414
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400415 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100416 // Recip has one arg.
417 auto Arg = CI->getOperand(0);
418 auto Cst1 = ConstantFP::get(Arg->getType(), 1.0);
419 return BinaryOperator::Create(Instruction::FDiv, Cst1, Arg, "", CI);
420 });
David Neto22f144c2017-06-12 14:26:21 -0400421}
422
423bool ReplaceOpenCLBuiltinPass::replaceDivide(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400424
Kévin Petite8edce32019-04-10 14:23:32 +0100425 std::vector<const char *> Names = {
David Neto22f144c2017-06-12 14:26:21 -0400426 "_Z11half_divideff", "_Z13native_divideff",
427 "_Z11half_divideDv2_fS_", "_Z13native_divideDv2_fS_",
428 "_Z11half_divideDv3_fS_", "_Z13native_divideDv3_fS_",
429 "_Z11half_divideDv4_fS_", "_Z13native_divideDv4_fS_",
430 };
431
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400432 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100433 auto Op0 = CI->getOperand(0);
434 auto Op1 = CI->getOperand(1);
435 return BinaryOperator::Create(Instruction::FDiv, Op0, Op1, "", CI);
436 });
David Neto22f144c2017-06-12 14:26:21 -0400437}
438
Kévin Petit1329a002019-06-15 05:54:05 +0100439bool ReplaceOpenCLBuiltinPass::replaceDot(Module &M) {
440
441 std::vector<const char *> Names = {
442 "_Z3dotff",
443 "_Z3dotDv2_fS_",
444 "_Z3dotDv3_fS_",
445 "_Z3dotDv4_fS_",
446 };
447
448 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
449 auto Op0 = CI->getOperand(0);
450 auto Op1 = CI->getOperand(1);
451
452 Value *V;
453 if (Op0->getType()->isVectorTy()) {
454 V = clspv::InsertSPIRVOp(CI, spv::OpDot, {Attribute::ReadNone},
455 CI->getType(), {Op0, Op1});
456 } else {
457 V = BinaryOperator::Create(Instruction::FMul, Op0, Op1, "", CI);
458 }
459
460 return V;
461 });
462}
463
David Neto22f144c2017-06-12 14:26:21 -0400464bool ReplaceOpenCLBuiltinPass::replaceExp10(Module &M) {
465 bool Changed = false;
466
467 const std::map<const char *, const char *> Map = {
468 {"_Z5exp10f", "_Z3expf"},
469 {"_Z10half_exp10f", "_Z8half_expf"},
470 {"_Z12native_exp10f", "_Z10native_expf"},
471 {"_Z5exp10Dv2_f", "_Z3expDv2_f"},
472 {"_Z10half_exp10Dv2_f", "_Z8half_expDv2_f"},
473 {"_Z12native_exp10Dv2_f", "_Z10native_expDv2_f"},
474 {"_Z5exp10Dv3_f", "_Z3expDv3_f"},
475 {"_Z10half_exp10Dv3_f", "_Z8half_expDv3_f"},
476 {"_Z12native_exp10Dv3_f", "_Z10native_expDv3_f"},
477 {"_Z5exp10Dv4_f", "_Z3expDv4_f"},
478 {"_Z10half_exp10Dv4_f", "_Z8half_expDv4_f"},
479 {"_Z12native_exp10Dv4_f", "_Z10native_expDv4_f"}};
480
481 for (auto Pair : Map) {
482 // If we find a function with the matching name.
483 if (auto F = M.getFunction(Pair.first)) {
484 SmallVector<Instruction *, 4> ToRemoves;
485
486 // Walk the users of the function.
487 for (auto &U : F->uses()) {
488 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
489 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
490
491 auto Arg = CI->getOperand(0);
492
493 // Constant of the natural log of 10 (ln(10)).
494 const double Ln10 =
495 2.302585092994045684017991454684364207601101488628772976033;
496
497 auto Mul = BinaryOperator::Create(
498 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "",
499 CI);
500
501 auto NewCI = CallInst::Create(NewF, Mul, "", CI);
502
503 CI->replaceAllUsesWith(NewCI);
504
505 // Lastly, remember to remove the user.
506 ToRemoves.push_back(CI);
507 }
508 }
509
510 Changed = !ToRemoves.empty();
511
512 // And cleanup the calls we don't use anymore.
513 for (auto V : ToRemoves) {
514 V->eraseFromParent();
515 }
516
517 // And remove the function we don't need either too.
518 F->eraseFromParent();
519 }
520 }
521
522 return Changed;
523}
524
Kévin Petit0644a9c2019-06-20 21:08:46 +0100525bool ReplaceOpenCLBuiltinPass::replaceFmod(Module &M) {
526
527 std::vector<const char *> Names = {
528 "_Z4fmodff",
529 "_Z4fmodDv2_fS_",
530 "_Z4fmodDv3_fS_",
531 "_Z4fmodDv4_fS_",
532 };
533
534 // OpenCL fmod(x,y) is x - y * trunc(x/y)
535 // The sign for a non-zero result is taken from x.
536 // (Try an example.)
537 // So translate to FRem
538 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
539 auto Op0 = CI->getOperand(0);
540 auto Op1 = CI->getOperand(1);
541 return BinaryOperator::Create(Instruction::FRem, Op0, Op1, "", CI);
542 });
543}
544
David Neto22f144c2017-06-12 14:26:21 -0400545bool ReplaceOpenCLBuiltinPass::replaceLog10(Module &M) {
546 bool Changed = false;
547
548 const std::map<const char *, const char *> Map = {
549 {"_Z5log10f", "_Z3logf"},
550 {"_Z10half_log10f", "_Z8half_logf"},
551 {"_Z12native_log10f", "_Z10native_logf"},
552 {"_Z5log10Dv2_f", "_Z3logDv2_f"},
553 {"_Z10half_log10Dv2_f", "_Z8half_logDv2_f"},
554 {"_Z12native_log10Dv2_f", "_Z10native_logDv2_f"},
555 {"_Z5log10Dv3_f", "_Z3logDv3_f"},
556 {"_Z10half_log10Dv3_f", "_Z8half_logDv3_f"},
557 {"_Z12native_log10Dv3_f", "_Z10native_logDv3_f"},
558 {"_Z5log10Dv4_f", "_Z3logDv4_f"},
559 {"_Z10half_log10Dv4_f", "_Z8half_logDv4_f"},
560 {"_Z12native_log10Dv4_f", "_Z10native_logDv4_f"}};
561
562 for (auto Pair : Map) {
563 // If we find a function with the matching name.
564 if (auto F = M.getFunction(Pair.first)) {
565 SmallVector<Instruction *, 4> ToRemoves;
566
567 // Walk the users of the function.
568 for (auto &U : F->uses()) {
569 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
570 auto NewF = M.getOrInsertFunction(Pair.second, F->getFunctionType());
571
572 auto Arg = CI->getOperand(0);
573
574 // Constant of the reciprocal of the natural log of 10 (ln(10)).
575 const double Ln10 =
576 0.434294481903251827651128918916605082294397005803666566114;
577
578 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
579
580 auto Mul = BinaryOperator::Create(
581 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), NewCI,
582 "", CI);
583
584 CI->replaceAllUsesWith(Mul);
585
586 // Lastly, remember to remove the user.
587 ToRemoves.push_back(CI);
588 }
589 }
590
591 Changed = !ToRemoves.empty();
592
593 // And cleanup the calls we don't use anymore.
594 for (auto V : ToRemoves) {
595 V->eraseFromParent();
596 }
597
598 // And remove the function we don't need either too.
599 F->eraseFromParent();
600 }
601 }
602
603 return Changed;
604}
605
606bool ReplaceOpenCLBuiltinPass::replaceBarrier(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -0400607
608 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
609
Kévin Petitc4643922019-06-17 19:32:05 +0100610 const std::vector<const char *> Names = {
alan-bakerf3bce4a2019-06-28 16:01:15 -0400611 "_Z7barrierj",
Kévin Petitc4643922019-06-17 19:32:05 +0100612 };
David Neto22f144c2017-06-12 14:26:21 -0400613
Kévin Petitc4643922019-06-17 19:32:05 +0100614 return replaceCallsWithValue(M, Names, [](CallInst *CI) {
615 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400616
Kévin Petitc4643922019-06-17 19:32:05 +0100617 // We need to map the OpenCL constants to the SPIR-V equivalents.
618 const auto LocalMemFence =
619 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
620 const auto GlobalMemFence =
621 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
622 const auto ConstantSequentiallyConsistent = ConstantInt::get(
623 Arg->getType(), spv::MemorySemanticsSequentiallyConsistentMask);
624 const auto ConstantScopeDevice =
625 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
626 const auto ConstantScopeWorkgroup =
627 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
David Neto22f144c2017-06-12 14:26:21 -0400628
Kévin Petitc4643922019-06-17 19:32:05 +0100629 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
630 const auto LocalMemFenceMask =
631 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
632 const auto WorkgroupShiftAmount =
633 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
634 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
635 Instruction::Shl, LocalMemFenceMask,
636 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400637
Kévin Petitc4643922019-06-17 19:32:05 +0100638 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
639 const auto GlobalMemFenceMask =
640 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
641 const auto UniformShiftAmount =
642 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
643 const auto MemorySemanticsUniform = BinaryOperator::Create(
644 Instruction::Shl, GlobalMemFenceMask,
645 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400646
Kévin Petitc4643922019-06-17 19:32:05 +0100647 // And combine the above together, also adding in
648 // MemorySemanticsSequentiallyConsistentMask.
649 auto MemorySemantics =
650 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
651 ConstantSequentiallyConsistent, "", CI);
652 MemorySemantics = BinaryOperator::Create(Instruction::Or, MemorySemantics,
653 MemorySemanticsUniform, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400654
Kévin Petitc4643922019-06-17 19:32:05 +0100655 // For Memory Scope if we used CLK_GLOBAL_MEM_FENCE, we need to use
656 // Device Scope, otherwise Workgroup Scope.
657 const auto Cmp =
658 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, GlobalMemFenceMask,
659 GlobalMemFence, "", CI);
660 const auto MemoryScope = SelectInst::Create(Cmp, ConstantScopeDevice,
661 ConstantScopeWorkgroup, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400662
Kévin Petitc4643922019-06-17 19:32:05 +0100663 // Lastly, the Execution Scope is always Workgroup Scope.
664 const auto ExecutionScope = ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400665
Kévin Petitc4643922019-06-17 19:32:05 +0100666 return clspv::InsertSPIRVOp(CI, spv::OpControlBarrier,
667 {Attribute::NoDuplicate}, CI->getType(),
668 {ExecutionScope, MemoryScope, MemorySemantics});
669 });
David Neto22f144c2017-06-12 14:26:21 -0400670}
671
672bool ReplaceOpenCLBuiltinPass::replaceMemFence(Module &M) {
673 bool Changed = false;
674
675 enum { CLK_LOCAL_MEM_FENCE = 0x01, CLK_GLOBAL_MEM_FENCE = 0x02 };
676
Kévin Petitc4643922019-06-17 19:32:05 +0100677 using Tuple = std::tuple<spv::Op, unsigned>;
Neil Henning39672102017-09-29 14:33:13 +0100678 const std::map<const char *, Tuple> Map = {
Kévin Petitc4643922019-06-17 19:32:05 +0100679 {"_Z9mem_fencej", Tuple(spv::OpMemoryBarrier,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400680 spv::MemorySemanticsSequentiallyConsistentMask)},
Neil Henning39672102017-09-29 14:33:13 +0100681 {"_Z14read_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100682 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsAcquireMask)},
Neil Henning39672102017-09-29 14:33:13 +0100683 {"_Z15write_mem_fencej",
Kévin Petitc4643922019-06-17 19:32:05 +0100684 Tuple(spv::OpMemoryBarrier, spv::MemorySemanticsReleaseMask)}};
David Neto22f144c2017-06-12 14:26:21 -0400685
686 for (auto Pair : Map) {
687 // If we find a function with the matching name.
688 if (auto F = M.getFunction(Pair.first)) {
689 SmallVector<Instruction *, 4> ToRemoves;
690
691 // Walk the users of the function.
692 for (auto &U : F->uses()) {
693 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -0400694
695 auto Arg = CI->getOperand(0);
696
697 // We need to map the OpenCL constants to the SPIR-V equivalents.
698 const auto LocalMemFence =
699 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
700 const auto GlobalMemFence =
701 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
702 const auto ConstantMemorySemantics =
Neil Henning39672102017-09-29 14:33:13 +0100703 ConstantInt::get(Arg->getType(), std::get<1>(Pair.second));
David Neto22f144c2017-06-12 14:26:21 -0400704 const auto ConstantScopeDevice =
705 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
706
707 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
708 const auto LocalMemFenceMask = BinaryOperator::Create(
709 Instruction::And, LocalMemFence, Arg, "", CI);
710 const auto WorkgroupShiftAmount =
711 clz(spv::MemorySemanticsWorkgroupMemoryMask) -
712 clz(CLK_LOCAL_MEM_FENCE);
713 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
714 Instruction::Shl, LocalMemFenceMask,
715 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
716
717 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
718 const auto GlobalMemFenceMask = BinaryOperator::Create(
719 Instruction::And, GlobalMemFence, Arg, "", CI);
720 const auto UniformShiftAmount =
721 clz(spv::MemorySemanticsUniformMemoryMask) -
722 clz(CLK_GLOBAL_MEM_FENCE);
723 const auto MemorySemanticsUniform = BinaryOperator::Create(
724 Instruction::Shl, GlobalMemFenceMask,
725 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
726
727 // And combine the above together, also adding in
728 // MemorySemanticsSequentiallyConsistentMask.
729 auto MemorySemantics =
730 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
731 ConstantMemorySemantics, "", CI);
732 MemorySemantics = BinaryOperator::Create(
733 Instruction::Or, MemorySemantics, MemorySemanticsUniform, "", CI);
734
735 // Memory Scope is always device.
736 const auto MemoryScope = ConstantScopeDevice;
737
Kévin Petitc4643922019-06-17 19:32:05 +0100738 const auto SPIRVOp = std::get<0>(Pair.second);
739 auto NewCI = clspv::InsertSPIRVOp(CI, SPIRVOp, {}, CI->getType(),
740 {MemoryScope, MemorySemantics});
David Neto22f144c2017-06-12 14:26:21 -0400741
742 CI->replaceAllUsesWith(NewCI);
743
744 // Lastly, remember to remove the user.
745 ToRemoves.push_back(CI);
746 }
747 }
748
749 Changed = !ToRemoves.empty();
750
751 // And cleanup the calls we don't use anymore.
752 for (auto V : ToRemoves) {
753 V->eraseFromParent();
754 }
755
756 // And remove the function we don't need either too.
757 F->eraseFromParent();
758 }
759 }
760
761 return Changed;
762}
763
764bool ReplaceOpenCLBuiltinPass::replaceRelational(Module &M) {
765 bool Changed = false;
766
767 const std::map<const char *, std::pair<CmpInst::Predicate, int32_t>> Map = {
768 {"_Z7isequalff", {CmpInst::FCMP_OEQ, 1}},
769 {"_Z7isequalDv2_fS_", {CmpInst::FCMP_OEQ, -1}},
770 {"_Z7isequalDv3_fS_", {CmpInst::FCMP_OEQ, -1}},
771 {"_Z7isequalDv4_fS_", {CmpInst::FCMP_OEQ, -1}},
772 {"_Z9isgreaterff", {CmpInst::FCMP_OGT, 1}},
773 {"_Z9isgreaterDv2_fS_", {CmpInst::FCMP_OGT, -1}},
774 {"_Z9isgreaterDv3_fS_", {CmpInst::FCMP_OGT, -1}},
775 {"_Z9isgreaterDv4_fS_", {CmpInst::FCMP_OGT, -1}},
776 {"_Z14isgreaterequalff", {CmpInst::FCMP_OGE, 1}},
777 {"_Z14isgreaterequalDv2_fS_", {CmpInst::FCMP_OGE, -1}},
778 {"_Z14isgreaterequalDv3_fS_", {CmpInst::FCMP_OGE, -1}},
779 {"_Z14isgreaterequalDv4_fS_", {CmpInst::FCMP_OGE, -1}},
780 {"_Z6islessff", {CmpInst::FCMP_OLT, 1}},
781 {"_Z6islessDv2_fS_", {CmpInst::FCMP_OLT, -1}},
782 {"_Z6islessDv3_fS_", {CmpInst::FCMP_OLT, -1}},
783 {"_Z6islessDv4_fS_", {CmpInst::FCMP_OLT, -1}},
784 {"_Z11islessequalff", {CmpInst::FCMP_OLE, 1}},
785 {"_Z11islessequalDv2_fS_", {CmpInst::FCMP_OLE, -1}},
786 {"_Z11islessequalDv3_fS_", {CmpInst::FCMP_OLE, -1}},
787 {"_Z11islessequalDv4_fS_", {CmpInst::FCMP_OLE, -1}},
788 {"_Z10isnotequalff", {CmpInst::FCMP_ONE, 1}},
789 {"_Z10isnotequalDv2_fS_", {CmpInst::FCMP_ONE, -1}},
790 {"_Z10isnotequalDv3_fS_", {CmpInst::FCMP_ONE, -1}},
791 {"_Z10isnotequalDv4_fS_", {CmpInst::FCMP_ONE, -1}},
792 };
793
794 for (auto Pair : Map) {
795 // If we find a function with the matching name.
796 if (auto F = M.getFunction(Pair.first)) {
797 SmallVector<Instruction *, 4> ToRemoves;
798
799 // Walk the users of the function.
800 for (auto &U : F->uses()) {
801 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
802 // The predicate to use in the CmpInst.
803 auto Predicate = Pair.second.first;
804
805 // The value to return for true.
806 auto TrueValue =
807 ConstantInt::getSigned(CI->getType(), Pair.second.second);
808
809 // The value to return for false.
810 auto FalseValue = Constant::getNullValue(CI->getType());
811
812 auto Arg1 = CI->getOperand(0);
813 auto Arg2 = CI->getOperand(1);
814
815 const auto Cmp =
816 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
817
818 const auto Select =
819 SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
820
821 CI->replaceAllUsesWith(Select);
822
823 // Lastly, remember to remove the user.
824 ToRemoves.push_back(CI);
825 }
826 }
827
828 Changed = !ToRemoves.empty();
829
830 // And cleanup the calls we don't use anymore.
831 for (auto V : ToRemoves) {
832 V->eraseFromParent();
833 }
834
835 // And remove the function we don't need either too.
836 F->eraseFromParent();
837 }
838 }
839
840 return Changed;
841}
842
843bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Module &M) {
844 bool Changed = false;
845
Kévin Petitff03aee2019-06-12 19:39:03 +0100846 const std::map<const char *, std::pair<spv::Op, int32_t>> Map = {
847 {"_Z5isinff", {spv::OpIsInf, 1}},
848 {"_Z5isinfDv2_f", {spv::OpIsInf, -1}},
849 {"_Z5isinfDv3_f", {spv::OpIsInf, -1}},
850 {"_Z5isinfDv4_f", {spv::OpIsInf, -1}},
851 {"_Z5isnanf", {spv::OpIsNan, 1}},
852 {"_Z5isnanDv2_f", {spv::OpIsNan, -1}},
853 {"_Z5isnanDv3_f", {spv::OpIsNan, -1}},
854 {"_Z5isnanDv4_f", {spv::OpIsNan, -1}},
David Neto22f144c2017-06-12 14:26:21 -0400855 };
856
857 for (auto Pair : Map) {
858 // If we find a function with the matching name.
859 if (auto F = M.getFunction(Pair.first)) {
860 SmallVector<Instruction *, 4> ToRemoves;
861
862 // Walk the users of the function.
863 for (auto &U : F->uses()) {
864 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
865 const auto CITy = CI->getType();
866
Kévin Petitff03aee2019-06-12 19:39:03 +0100867 auto SPIRVOp = Pair.second.first;
David Neto22f144c2017-06-12 14:26:21 -0400868
869 // The value to return for true.
870 auto TrueValue = ConstantInt::getSigned(CITy, Pair.second.second);
871
872 // The value to return for false.
873 auto FalseValue = Constant::getNullValue(CITy);
874
875 const auto CorrespondingBoolTy = getBoolOrBoolVectorTy(
876 M.getContext(),
877 CITy->isVectorTy() ? CITy->getVectorNumElements() : 1);
878
Kévin Petitff03aee2019-06-12 19:39:03 +0100879 auto NewCI =
880 clspv::InsertSPIRVOp(CI, SPIRVOp, {Attribute::ReadNone},
881 CorrespondingBoolTy, {CI->getOperand(0)});
David Neto22f144c2017-06-12 14:26:21 -0400882
883 const auto Select =
884 SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
885
886 CI->replaceAllUsesWith(Select);
887
888 // Lastly, remember to remove the user.
889 ToRemoves.push_back(CI);
890 }
891 }
892
893 Changed = !ToRemoves.empty();
894
895 // And cleanup the calls we don't use anymore.
896 for (auto V : ToRemoves) {
897 V->eraseFromParent();
898 }
899
900 // And remove the function we don't need either too.
901 F->eraseFromParent();
902 }
903 }
904
905 return Changed;
906}
907
Kévin Petitfdfa92e2019-09-25 14:20:58 +0100908bool ReplaceOpenCLBuiltinPass::replaceIsFinite(Module &M) {
909 std::vector<const char *> Names = {
910 "_Z8isfiniteh", "_Z8isfiniteDv2_h", "_Z8isfiniteDv3_h",
911 "_Z8isfiniteDv4_h", "_Z8isfinitef", "_Z8isfiniteDv2_f",
912 "_Z8isfiniteDv3_f", "_Z8isfiniteDv4_f", "_Z8isfinited",
913 "_Z8isfiniteDv2_d", "_Z8isfiniteDv3_d", "_Z8isfiniteDv4_d",
914 };
915
916 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
917 auto &C = M.getContext();
918 auto Val = CI->getOperand(0);
919 auto ValTy = Val->getType();
920 auto RetTy = CI->getType();
921
922 // Get a suitable integer type to represent the number
923 auto IntTy = getIntOrIntVectorTyForCast(C, ValTy);
924
925 // Create Mask
926 auto ScalarSize = ValTy->getScalarSizeInBits();
927 Value *InfMask;
928 switch (ScalarSize) {
929 case 16:
930 InfMask = ConstantInt::get(IntTy, 0x7C00U);
931 break;
932 case 32:
933 InfMask = ConstantInt::get(IntTy, 0x7F800000U);
934 break;
935 case 64:
936 InfMask = ConstantInt::get(IntTy, 0x7FF0000000000000ULL);
937 break;
938 default:
939 llvm_unreachable("Unsupported floating-point type");
940 }
941
942 IRBuilder<> Builder(CI);
943
944 // Bitcast to int
945 auto ValInt = Builder.CreateBitCast(Val, IntTy);
946
947 // Mask and compare
948 auto InfBits = Builder.CreateAnd(InfMask, ValInt);
949 auto Cmp = Builder.CreateICmp(CmpInst::ICMP_EQ, InfBits, InfMask);
950
951 auto RetFalse = ConstantInt::get(RetTy, 0);
952 Value *RetTrue;
953 if (ValTy->isVectorTy()) {
954 RetTrue = ConstantInt::getSigned(RetTy, -1);
955 } else {
956 RetTrue = ConstantInt::get(RetTy, 1);
957 }
958 return Builder.CreateSelect(Cmp, RetFalse, RetTrue);
959 });
960}
961
David Neto22f144c2017-06-12 14:26:21 -0400962bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Module &M) {
963 bool Changed = false;
964
Kévin Petitff03aee2019-06-12 19:39:03 +0100965 const std::map<const char *, spv::Op> Map = {
Kévin Petitfd27cca2018-10-31 13:00:17 +0000966 // all
Kévin Petitff03aee2019-06-12 19:39:03 +0100967 {"_Z3allc", spv::OpNop},
968 {"_Z3allDv2_c", spv::OpAll},
969 {"_Z3allDv3_c", spv::OpAll},
970 {"_Z3allDv4_c", spv::OpAll},
971 {"_Z3alls", spv::OpNop},
972 {"_Z3allDv2_s", spv::OpAll},
973 {"_Z3allDv3_s", spv::OpAll},
974 {"_Z3allDv4_s", spv::OpAll},
975 {"_Z3alli", spv::OpNop},
976 {"_Z3allDv2_i", spv::OpAll},
977 {"_Z3allDv3_i", spv::OpAll},
978 {"_Z3allDv4_i", spv::OpAll},
979 {"_Z3alll", spv::OpNop},
980 {"_Z3allDv2_l", spv::OpAll},
981 {"_Z3allDv3_l", spv::OpAll},
982 {"_Z3allDv4_l", spv::OpAll},
Kévin Petitfd27cca2018-10-31 13:00:17 +0000983
984 // any
Kévin Petitff03aee2019-06-12 19:39:03 +0100985 {"_Z3anyc", spv::OpNop},
986 {"_Z3anyDv2_c", spv::OpAny},
987 {"_Z3anyDv3_c", spv::OpAny},
988 {"_Z3anyDv4_c", spv::OpAny},
989 {"_Z3anys", spv::OpNop},
990 {"_Z3anyDv2_s", spv::OpAny},
991 {"_Z3anyDv3_s", spv::OpAny},
992 {"_Z3anyDv4_s", spv::OpAny},
993 {"_Z3anyi", spv::OpNop},
994 {"_Z3anyDv2_i", spv::OpAny},
995 {"_Z3anyDv3_i", spv::OpAny},
996 {"_Z3anyDv4_i", spv::OpAny},
997 {"_Z3anyl", spv::OpNop},
998 {"_Z3anyDv2_l", spv::OpAny},
999 {"_Z3anyDv3_l", spv::OpAny},
1000 {"_Z3anyDv4_l", spv::OpAny},
David Neto22f144c2017-06-12 14:26:21 -04001001 };
1002
1003 for (auto Pair : Map) {
1004 // If we find a function with the matching name.
1005 if (auto F = M.getFunction(Pair.first)) {
1006 SmallVector<Instruction *, 4> ToRemoves;
1007
1008 // Walk the users of the function.
1009 for (auto &U : F->uses()) {
1010 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -04001011
1012 auto Arg = CI->getOperand(0);
1013
1014 Value *V;
1015
Kévin Petitfd27cca2018-10-31 13:00:17 +00001016 // If the argument is a 32-bit int, just use a shift
1017 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
1018 V = BinaryOperator::Create(Instruction::LShr, Arg,
1019 ConstantInt::get(Arg->getType(), 31), "",
1020 CI);
1021 } else {
David Neto22f144c2017-06-12 14:26:21 -04001022 // The value for zero to compare against.
1023 const auto ZeroValue = Constant::getNullValue(Arg->getType());
1024
David Neto22f144c2017-06-12 14:26:21 -04001025 // The value to return for true.
1026 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
1027
1028 // The value to return for false.
1029 const auto FalseValue = Constant::getNullValue(CI->getType());
1030
Kévin Petitfd27cca2018-10-31 13:00:17 +00001031 const auto Cmp = CmpInst::Create(
1032 Instruction::ICmp, CmpInst::ICMP_SLT, Arg, ZeroValue, "", CI);
1033
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001034 Value *SelectSource;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001035
1036 // If we have a function to call, call it!
Kévin Petitff03aee2019-06-12 19:39:03 +01001037 const auto SPIRVOp = Pair.second;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001038
Kévin Petitff03aee2019-06-12 19:39:03 +01001039 if (SPIRVOp != spv::OpNop) {
Kévin Petitfd27cca2018-10-31 13:00:17 +00001040
Kévin Petitff03aee2019-06-12 19:39:03 +01001041 const auto BoolTy = Type::getInt1Ty(M.getContext());
Kévin Petitfd27cca2018-10-31 13:00:17 +00001042
Kévin Petitff03aee2019-06-12 19:39:03 +01001043 const auto NewCI = clspv::InsertSPIRVOp(
1044 CI, SPIRVOp, {Attribute::ReadNone}, BoolTy, {Cmp});
Kévin Petitfd27cca2018-10-31 13:00:17 +00001045 SelectSource = NewCI;
1046
1047 } else {
1048 SelectSource = Cmp;
1049 }
1050
1051 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001052 }
1053
1054 CI->replaceAllUsesWith(V);
1055
1056 // Lastly, remember to remove the user.
1057 ToRemoves.push_back(CI);
1058 }
1059 }
1060
1061 Changed = !ToRemoves.empty();
1062
1063 // And cleanup the calls we don't use anymore.
1064 for (auto V : ToRemoves) {
1065 V->eraseFromParent();
1066 }
1067
1068 // And remove the function we don't need either too.
1069 F->eraseFromParent();
1070 }
1071 }
1072
1073 return Changed;
1074}
1075
Kévin Petitbf0036c2019-03-06 13:57:10 +00001076bool ReplaceOpenCLBuiltinPass::replaceUpsample(Module &M) {
1077 bool Changed = false;
1078
1079 for (auto const &SymVal : M.getValueSymbolTable()) {
1080 // Skip symbols whose name doesn't match
1081 if (!SymVal.getKey().startswith("_Z8upsample")) {
1082 continue;
1083 }
1084 // Is there a function going by that name?
1085 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1086
1087 SmallVector<Instruction *, 4> ToRemoves;
1088
1089 // Walk the users of the function.
1090 for (auto &U : F->uses()) {
1091 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1092
1093 // Get arguments
1094 auto HiValue = CI->getOperand(0);
1095 auto LoValue = CI->getOperand(1);
1096
1097 // Don't touch overloads that aren't in OpenCL C
1098 auto HiType = HiValue->getType();
1099 auto LoType = LoValue->getType();
1100
1101 if (HiType != LoType) {
1102 continue;
1103 }
1104
1105 if (!HiType->isIntOrIntVectorTy()) {
1106 continue;
1107 }
1108
1109 if (HiType->getScalarSizeInBits() * 2 !=
1110 CI->getType()->getScalarSizeInBits()) {
1111 continue;
1112 }
1113
1114 if ((HiType->getScalarSizeInBits() != 8) &&
1115 (HiType->getScalarSizeInBits() != 16) &&
1116 (HiType->getScalarSizeInBits() != 32)) {
1117 continue;
1118 }
1119
1120 if (HiType->isVectorTy()) {
1121 if ((HiType->getVectorNumElements() != 2) &&
1122 (HiType->getVectorNumElements() != 3) &&
1123 (HiType->getVectorNumElements() != 4) &&
1124 (HiType->getVectorNumElements() != 8) &&
1125 (HiType->getVectorNumElements() != 16)) {
1126 continue;
1127 }
1128 }
1129
1130 // Convert both operands to the result type
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001131 auto HiCast =
1132 CastInst::CreateZExtOrBitCast(HiValue, CI->getType(), "", CI);
1133 auto LoCast =
1134 CastInst::CreateZExtOrBitCast(LoValue, CI->getType(), "", CI);
Kévin Petitbf0036c2019-03-06 13:57:10 +00001135
1136 // Shift high operand
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001137 auto ShiftAmount =
1138 ConstantInt::get(CI->getType(), HiType->getScalarSizeInBits());
Kévin Petitbf0036c2019-03-06 13:57:10 +00001139 auto HiShifted = BinaryOperator::Create(Instruction::Shl, HiCast,
1140 ShiftAmount, "", CI);
1141
1142 // OR both results
1143 Value *V = BinaryOperator::Create(Instruction::Or, HiShifted, LoCast,
1144 "", CI);
1145
1146 // Replace call with the expression
1147 CI->replaceAllUsesWith(V);
1148
1149 // Lastly, remember to remove the user.
1150 ToRemoves.push_back(CI);
1151 }
1152 }
1153
1154 Changed = !ToRemoves.empty();
1155
1156 // And cleanup the calls we don't use anymore.
1157 for (auto V : ToRemoves) {
1158 V->eraseFromParent();
1159 }
1160
1161 // And remove the function we don't need either too.
1162 F->eraseFromParent();
1163 }
1164 }
1165
1166 return Changed;
1167}
1168
Kévin Petitd44eef52019-03-08 13:22:14 +00001169bool ReplaceOpenCLBuiltinPass::replaceRotate(Module &M) {
1170 bool Changed = false;
1171
1172 for (auto const &SymVal : M.getValueSymbolTable()) {
1173 // Skip symbols whose name doesn't match
1174 if (!SymVal.getKey().startswith("_Z6rotate")) {
1175 continue;
1176 }
1177 // Is there a function going by that name?
1178 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1179
1180 SmallVector<Instruction *, 4> ToRemoves;
1181
1182 // Walk the users of the function.
1183 for (auto &U : F->uses()) {
1184 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1185
1186 // Get arguments
1187 auto SrcValue = CI->getOperand(0);
1188 auto RotAmount = CI->getOperand(1);
1189
1190 // Don't touch overloads that aren't in OpenCL C
1191 auto SrcType = SrcValue->getType();
1192 auto RotType = RotAmount->getType();
1193
1194 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1195 continue;
1196 }
1197
1198 if (!SrcType->isIntOrIntVectorTy()) {
1199 continue;
1200 }
1201
1202 if ((SrcType->getScalarSizeInBits() != 8) &&
1203 (SrcType->getScalarSizeInBits() != 16) &&
1204 (SrcType->getScalarSizeInBits() != 32) &&
1205 (SrcType->getScalarSizeInBits() != 64)) {
1206 continue;
1207 }
1208
1209 if (SrcType->isVectorTy()) {
1210 if ((SrcType->getVectorNumElements() != 2) &&
1211 (SrcType->getVectorNumElements() != 3) &&
1212 (SrcType->getVectorNumElements() != 4) &&
1213 (SrcType->getVectorNumElements() != 8) &&
1214 (SrcType->getVectorNumElements() != 16)) {
1215 continue;
1216 }
1217 }
1218
1219 // The approach used is to shift the top bits down, the bottom bits up
1220 // and OR the two shifted values.
1221
1222 // The rotation amount is to be treated modulo the element size.
1223 // Since SPIR-V shift ops don't support this, let's apply the
1224 // modulo ahead of shifting. The element size is always a power of
1225 // two so we can just AND with a mask.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001226 auto ModMask =
1227 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits() - 1);
Kévin Petitd44eef52019-03-08 13:22:14 +00001228 RotAmount = BinaryOperator::Create(Instruction::And, RotAmount,
1229 ModMask, "", CI);
1230
1231 // Let's calc the amount by which to shift top bits down
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001232 auto ScalarSize =
1233 ConstantInt::get(SrcType, SrcType->getScalarSizeInBits());
Kévin Petitd44eef52019-03-08 13:22:14 +00001234 auto DownAmount = BinaryOperator::Create(Instruction::Sub, ScalarSize,
1235 RotAmount, "", CI);
1236
1237 // Now shift the bottom bits up and the top bits down
1238 auto LoRotated = BinaryOperator::Create(Instruction::Shl, SrcValue,
1239 RotAmount, "", CI);
1240 auto HiRotated = BinaryOperator::Create(Instruction::LShr, SrcValue,
1241 DownAmount, "", CI);
1242
1243 // Finally OR the two shifted values
1244 Value *V = BinaryOperator::Create(Instruction::Or, LoRotated,
1245 HiRotated, "", CI);
1246
1247 // Replace call with the expression
1248 CI->replaceAllUsesWith(V);
1249
1250 // Lastly, remember to remove the user.
1251 ToRemoves.push_back(CI);
1252 }
1253 }
1254
1255 Changed = !ToRemoves.empty();
1256
1257 // And cleanup the calls we don't use anymore.
1258 for (auto V : ToRemoves) {
1259 V->eraseFromParent();
1260 }
1261
1262 // And remove the function we don't need either too.
1263 F->eraseFromParent();
1264 }
1265 }
1266
1267 return Changed;
1268}
1269
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001270bool ReplaceOpenCLBuiltinPass::replaceConvert(Module &M) {
1271 bool Changed = false;
1272
1273 for (auto const &SymVal : M.getValueSymbolTable()) {
1274
1275 // Skip symbols whose name obviously doesn't match
1276 if (!SymVal.getKey().contains("convert_")) {
1277 continue;
1278 }
1279
1280 // Is there a function going by that name?
1281 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1282
1283 // Get info from the mangled name
1284 FunctionInfo finfo;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001285 bool parsed = FunctionInfo::getFromMangledNameCheck(F->getName(), &finfo);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001286
1287 // All functions of interest are handled by our mangled name parser
1288 if (!parsed) {
1289 continue;
1290 }
1291
1292 // Move on if this isn't a call to convert_
1293 if (!finfo.name.startswith("convert_")) {
1294 continue;
1295 }
1296
1297 // Extract the destination type from the function name
1298 StringRef DstTypeName = finfo.name;
1299 DstTypeName.consume_front("convert_");
1300
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001301 auto DstSignedNess =
1302 StringSwitch<ArgTypeInfo::SignedNess>(DstTypeName)
1303 .StartsWith("char", ArgTypeInfo::SignedNess::Signed)
1304 .StartsWith("short", ArgTypeInfo::SignedNess::Signed)
1305 .StartsWith("int", ArgTypeInfo::SignedNess::Signed)
1306 .StartsWith("long", ArgTypeInfo::SignedNess::Signed)
1307 .StartsWith("uchar", ArgTypeInfo::SignedNess::Unsigned)
1308 .StartsWith("ushort", ArgTypeInfo::SignedNess::Unsigned)
1309 .StartsWith("uint", ArgTypeInfo::SignedNess::Unsigned)
1310 .StartsWith("ulong", ArgTypeInfo::SignedNess::Unsigned)
1311 .Default(ArgTypeInfo::SignedNess::None);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001312
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001313 bool DstIsSigned = DstSignedNess == ArgTypeInfo::SignedNess::Signed;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001314 bool SrcIsSigned = finfo.isArgSigned(0);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001315
1316 SmallVector<Instruction *, 4> ToRemoves;
1317
1318 // Walk the users of the function.
1319 for (auto &U : F->uses()) {
1320 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1321
1322 // Get arguments
1323 auto SrcValue = CI->getOperand(0);
1324
1325 // Don't touch overloads that aren't in OpenCL C
1326 auto SrcType = SrcValue->getType();
1327 auto DstType = CI->getType();
1328
1329 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1330 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1331 continue;
1332 }
1333
1334 if (SrcType->isVectorTy()) {
1335
1336 if (SrcType->getVectorNumElements() !=
1337 DstType->getVectorNumElements()) {
1338 continue;
1339 }
1340
1341 if ((SrcType->getVectorNumElements() != 2) &&
1342 (SrcType->getVectorNumElements() != 3) &&
1343 (SrcType->getVectorNumElements() != 4) &&
1344 (SrcType->getVectorNumElements() != 8) &&
1345 (SrcType->getVectorNumElements() != 16)) {
1346 continue;
1347 }
1348 }
1349
1350 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1351 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1352
1353 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1354 bool DstIsInt = DstType->isIntOrIntVectorTy();
1355
1356 Value *V;
1357 if (SrcIsFloat && DstIsFloat) {
1358 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1359 } else if (SrcIsFloat && DstIsInt) {
1360 if (DstIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001361 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "",
1362 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001363 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001364 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "",
1365 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001366 }
1367 } else if (SrcIsInt && DstIsFloat) {
1368 if (SrcIsSigned) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001369 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "",
1370 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001371 } else {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001372 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "",
1373 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001374 }
1375 } else if (SrcIsInt && DstIsInt) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001376 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "",
1377 CI);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001378 } else {
1379 // Not something we're supposed to handle, just move on
1380 continue;
1381 }
1382
1383 // Replace call with the expression
1384 CI->replaceAllUsesWith(V);
1385
1386 // Lastly, remember to remove the user.
1387 ToRemoves.push_back(CI);
1388 }
1389 }
1390
1391 Changed = !ToRemoves.empty();
1392
1393 // And cleanup the calls we don't use anymore.
1394 for (auto V : ToRemoves) {
1395 V->eraseFromParent();
1396 }
1397
1398 // And remove the function we don't need either too.
1399 F->eraseFromParent();
1400 }
1401 }
1402
1403 return Changed;
1404}
1405
Kévin Petit8a560882019-03-21 15:24:34 +00001406bool ReplaceOpenCLBuiltinPass::replaceMulHiMadHi(Module &M) {
1407 bool Changed = false;
1408
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001409 SmallVector<Function *, 4> FnWorklist;
Kévin Petit8a560882019-03-21 15:24:34 +00001410
Kévin Petit617a76d2019-04-04 13:54:16 +01001411 for (auto const &SymVal : M.getValueSymbolTable()) {
Kévin Petit8a560882019-03-21 15:24:34 +00001412 bool isMad = SymVal.getKey().startswith("_Z6mad_hi");
1413 bool isMul = SymVal.getKey().startswith("_Z6mul_hi");
1414
1415 // Skip symbols whose name doesn't match
1416 if (!isMad && !isMul) {
1417 continue;
1418 }
1419
1420 // Is there a function going by that name?
1421 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001422 FnWorklist.push_back(F);
Kévin Petit8a560882019-03-21 15:24:34 +00001423 }
1424 }
1425
Kévin Petit617a76d2019-04-04 13:54:16 +01001426 for (auto F : FnWorklist) {
1427 SmallVector<Instruction *, 4> ToRemoves;
1428
1429 bool isMad = F->getName().startswith("_Z6mad_hi");
1430 // Walk the users of the function.
1431 for (auto &U : F->uses()) {
1432 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1433
1434 // Get arguments
1435 auto AValue = CI->getOperand(0);
1436 auto BValue = CI->getOperand(1);
1437 auto CValue = CI->getOperand(2);
1438
1439 // Don't touch overloads that aren't in OpenCL C
1440 auto AType = AValue->getType();
1441 auto BType = BValue->getType();
1442 auto CType = CValue->getType();
1443
1444 if ((AType != BType) || (CI->getType() != AType) ||
1445 (isMad && (AType != CType))) {
1446 continue;
1447 }
1448
1449 if (!AType->isIntOrIntVectorTy()) {
1450 continue;
1451 }
1452
1453 if ((AType->getScalarSizeInBits() != 8) &&
1454 (AType->getScalarSizeInBits() != 16) &&
1455 (AType->getScalarSizeInBits() != 32) &&
1456 (AType->getScalarSizeInBits() != 64)) {
1457 continue;
1458 }
1459
1460 if (AType->isVectorTy()) {
1461 if ((AType->getVectorNumElements() != 2) &&
1462 (AType->getVectorNumElements() != 3) &&
1463 (AType->getVectorNumElements() != 4) &&
1464 (AType->getVectorNumElements() != 8) &&
1465 (AType->getVectorNumElements() != 16)) {
1466 continue;
1467 }
1468 }
1469
1470 // Get infos from the mangled OpenCL built-in function name
Kévin Petit91bc72e2019-04-08 15:17:46 +01001471 auto finfo = FunctionInfo::getFromMangledName(F->getName());
Kévin Petit617a76d2019-04-04 13:54:16 +01001472
1473 // Select the appropriate signed/unsigned SPIR-V op
1474 spv::Op opcode;
Kévin Petit91bc72e2019-04-08 15:17:46 +01001475 if (finfo.isArgSigned(0)) {
Kévin Petit617a76d2019-04-04 13:54:16 +01001476 opcode = spv::OpSMulExtended;
1477 } else {
1478 opcode = spv::OpUMulExtended;
1479 }
1480
1481 // Our SPIR-V op returns a struct, create a type for it
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001482 SmallVector<Type *, 2> TwoValueType = {AType, AType};
Kévin Petit617a76d2019-04-04 13:54:16 +01001483 auto ExMulRetType = StructType::create(TwoValueType);
1484
1485 // Call the SPIR-V op
1486 auto Call = clspv::InsertSPIRVOp(CI, opcode, {Attribute::ReadNone},
1487 ExMulRetType, {AValue, BValue});
1488
1489 // Get the high part of the result
1490 unsigned Idxs[] = {1};
1491 Value *V = ExtractValueInst::Create(Call, Idxs, "", CI);
1492
1493 // If we're handling a mad_hi, add the third argument to the result
1494 if (isMad) {
1495 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
1496 }
1497
1498 // Replace call with the expression
1499 CI->replaceAllUsesWith(V);
1500
1501 // Lastly, remember to remove the user.
1502 ToRemoves.push_back(CI);
1503 }
1504 }
1505
1506 Changed = !ToRemoves.empty();
1507
1508 // And cleanup the calls we don't use anymore.
1509 for (auto V : ToRemoves) {
1510 V->eraseFromParent();
1511 }
1512
1513 // And remove the function we don't need either too.
1514 F->eraseFromParent();
1515 }
1516
Kévin Petit8a560882019-03-21 15:24:34 +00001517 return Changed;
1518}
1519
Kévin Petitf5b78a22018-10-25 14:32:17 +00001520bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
1521 bool Changed = false;
1522
1523 for (auto const &SymVal : M.getValueSymbolTable()) {
1524 // Skip symbols whose name doesn't match
1525 if (!SymVal.getKey().startswith("_Z6select")) {
1526 continue;
1527 }
1528 // Is there a function going by that name?
1529 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1530
1531 SmallVector<Instruction *, 4> ToRemoves;
1532
1533 // Walk the users of the function.
1534 for (auto &U : F->uses()) {
1535 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1536
1537 // Get arguments
1538 auto FalseValue = CI->getOperand(0);
1539 auto TrueValue = CI->getOperand(1);
1540 auto PredicateValue = CI->getOperand(2);
1541
1542 // Don't touch overloads that aren't in OpenCL C
1543 auto FalseType = FalseValue->getType();
1544 auto TrueType = TrueValue->getType();
1545 auto PredicateType = PredicateValue->getType();
1546
1547 if (FalseType != TrueType) {
1548 continue;
1549 }
1550
1551 if (!PredicateType->isIntOrIntVectorTy()) {
1552 continue;
1553 }
1554
1555 if (!FalseType->isIntOrIntVectorTy() &&
1556 !FalseType->getScalarType()->isFloatingPointTy()) {
1557 continue;
1558 }
1559
1560 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1561 continue;
1562 }
1563
1564 if (FalseType->getScalarSizeInBits() !=
1565 PredicateType->getScalarSizeInBits()) {
1566 continue;
1567 }
1568
1569 if (FalseType->isVectorTy()) {
1570 if (FalseType->getVectorNumElements() !=
1571 PredicateType->getVectorNumElements()) {
1572 continue;
1573 }
1574
1575 if ((FalseType->getVectorNumElements() != 2) &&
1576 (FalseType->getVectorNumElements() != 3) &&
1577 (FalseType->getVectorNumElements() != 4) &&
1578 (FalseType->getVectorNumElements() != 8) &&
1579 (FalseType->getVectorNumElements() != 16)) {
1580 continue;
1581 }
1582 }
1583
1584 // Create constant
1585 const auto ZeroValue = Constant::getNullValue(PredicateType);
1586
1587 // Scalar and vector are to be treated differently
1588 CmpInst::Predicate Pred;
1589 if (PredicateType->isVectorTy()) {
1590 Pred = CmpInst::ICMP_SLT;
1591 } else {
1592 Pred = CmpInst::ICMP_NE;
1593 }
1594
1595 // Create comparison instruction
1596 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1597 ZeroValue, "", CI);
1598
1599 // Create select
1600 Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1601
1602 // Replace call with the selection
1603 CI->replaceAllUsesWith(V);
1604
1605 // Lastly, remember to remove the user.
1606 ToRemoves.push_back(CI);
1607 }
1608 }
1609
1610 Changed = !ToRemoves.empty();
1611
1612 // And cleanup the calls we don't use anymore.
1613 for (auto V : ToRemoves) {
1614 V->eraseFromParent();
1615 }
1616
1617 // And remove the function we don't need either too.
1618 F->eraseFromParent();
1619 }
1620 }
1621
1622 return Changed;
1623}
1624
Kévin Petite7d0cce2018-10-31 12:38:56 +00001625bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Module &M) {
1626 bool Changed = false;
1627
1628 for (auto const &SymVal : M.getValueSymbolTable()) {
1629 // Skip symbols whose name doesn't match
1630 if (!SymVal.getKey().startswith("_Z9bitselect")) {
1631 continue;
1632 }
1633 // Is there a function going by that name?
1634 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
1635
1636 SmallVector<Instruction *, 4> ToRemoves;
1637
1638 // Walk the users of the function.
1639 for (auto &U : F->uses()) {
1640 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1641
1642 if (CI->getNumOperands() != 4) {
1643 continue;
1644 }
1645
1646 // Get arguments
1647 auto FalseValue = CI->getOperand(0);
1648 auto TrueValue = CI->getOperand(1);
1649 auto PredicateValue = CI->getOperand(2);
1650
1651 // Don't touch overloads that aren't in OpenCL C
1652 auto FalseType = FalseValue->getType();
1653 auto TrueType = TrueValue->getType();
1654 auto PredicateType = PredicateValue->getType();
1655
1656 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1657 continue;
1658 }
1659
1660 if (TrueType->isVectorTy()) {
1661 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1662 !TrueType->getScalarType()->isIntegerTy()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001663 continue;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001664 }
1665 if ((TrueType->getVectorNumElements() != 2) &&
1666 (TrueType->getVectorNumElements() != 3) &&
1667 (TrueType->getVectorNumElements() != 4) &&
1668 (TrueType->getVectorNumElements() != 8) &&
1669 (TrueType->getVectorNumElements() != 16)) {
1670 continue;
1671 }
1672 }
1673
1674 // Remember the type of the operands
1675 auto OpType = TrueType;
1676
1677 // The actual bit selection will always be done on an integer type,
1678 // declare it here
1679 Type *BitType;
1680
1681 // If the operands are float, then bitcast them to int
1682 if (OpType->getScalarType()->isFloatingPointTy()) {
1683
1684 // First create the new type
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001685 BitType = getIntOrIntVectorTyForCast(M.getContext(), OpType);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001686
1687 // Then bitcast all operands
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001688 PredicateValue =
1689 CastInst::CreateZExtOrBitCast(PredicateValue, BitType, "", CI);
1690 FalseValue =
1691 CastInst::CreateZExtOrBitCast(FalseValue, BitType, "", CI);
1692 TrueValue =
1693 CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001694
1695 } else {
1696 // The operands have an integer type, use it directly
1697 BitType = OpType;
1698 }
1699
1700 // All the operands are now always integers
1701 // implement as (c & b) | (~c & a)
1702
1703 // Create our negated predicate value
1704 auto AllOnes = Constant::getAllOnesValue(BitType);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001705 auto NotPredicateValue = BinaryOperator::Create(
1706 Instruction::Xor, PredicateValue, AllOnes, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001707
1708 // Then put everything together
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001709 auto BitsFalse = BinaryOperator::Create(
1710 Instruction::And, NotPredicateValue, FalseValue, "", CI);
1711 auto BitsTrue = BinaryOperator::Create(
1712 Instruction::And, PredicateValue, TrueValue, "", CI);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001713
1714 Value *V = BinaryOperator::Create(Instruction::Or, BitsFalse,
1715 BitsTrue, "", CI);
1716
1717 // If we were dealing with a floating point type, we must bitcast
1718 // the result back to that
1719 if (OpType->getScalarType()->isFloatingPointTy()) {
1720 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1721 }
1722
1723 // Replace call with our new code
1724 CI->replaceAllUsesWith(V);
1725
1726 // Lastly, remember to remove the user.
1727 ToRemoves.push_back(CI);
1728 }
1729 }
1730
1731 Changed = !ToRemoves.empty();
1732
1733 // And cleanup the calls we don't use anymore.
1734 for (auto V : ToRemoves) {
1735 V->eraseFromParent();
1736 }
1737
1738 // And remove the function we don't need either too.
1739 F->eraseFromParent();
1740 }
1741 }
1742
1743 return Changed;
1744}
1745
Kévin Petit6b0a9532018-10-30 20:00:39 +00001746bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
1747 bool Changed = false;
1748
1749 const std::map<const char *, const char *> Map = {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001750 {"_Z4stepfDv2_f", "_Z4stepDv2_fS_"},
1751 {"_Z4stepfDv3_f", "_Z4stepDv3_fS_"},
1752 {"_Z4stepfDv4_f", "_Z4stepDv4_fS_"},
1753 {"_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_"},
1754 {"_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_"},
1755 {"_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_"},
Kévin Petit6b0a9532018-10-30 20:00:39 +00001756 };
1757
1758 for (auto Pair : Map) {
1759 // If we find a function with the matching name.
1760 if (auto F = M.getFunction(Pair.first)) {
1761 SmallVector<Instruction *, 4> ToRemoves;
1762
1763 // Walk the users of the function.
1764 for (auto &U : F->uses()) {
1765 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1766
1767 auto ReplacementFn = Pair.second;
1768
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001769 SmallVector<Value *, 2> ArgsToSplat = {CI->getOperand(0)};
Kévin Petit6b0a9532018-10-30 20:00:39 +00001770 Value *VectorArg;
1771
1772 // First figure out which function we're dealing with
1773 if (F->getName().startswith("_Z10smoothstep")) {
1774 ArgsToSplat.push_back(CI->getOperand(1));
1775 VectorArg = CI->getOperand(2);
1776 } else {
1777 VectorArg = CI->getOperand(1);
1778 }
1779
1780 // Splat arguments that need to be
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001781 SmallVector<Value *, 2> SplatArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001782 auto VecType = VectorArg->getType();
1783
1784 for (auto arg : ArgsToSplat) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001785 Value *NewVectorArg = UndefValue::get(VecType);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001786 for (auto i = 0; i < VecType->getVectorNumElements(); i++) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001787 auto index =
1788 ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1789 NewVectorArg =
1790 InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001791 }
1792 SplatArgs.push_back(NewVectorArg);
1793 }
1794
1795 // Replace the call with the vector/vector flavour
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001796 SmallVector<Type *, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1797 const auto NewFType =
1798 FunctionType::get(CI->getType(), NewArgTypes, false);
Kévin Petit6b0a9532018-10-30 20:00:39 +00001799
1800 const auto NewF = M.getOrInsertFunction(ReplacementFn, NewFType);
1801
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001802 SmallVector<Value *, 3> NewArgs;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001803 for (auto arg : SplatArgs) {
1804 NewArgs.push_back(arg);
1805 }
1806 NewArgs.push_back(VectorArg);
1807
1808 const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
1809
1810 CI->replaceAllUsesWith(NewCI);
1811
1812 // Lastly, remember to remove the user.
1813 ToRemoves.push_back(CI);
1814 }
1815 }
1816
1817 Changed = !ToRemoves.empty();
1818
1819 // And cleanup the calls we don't use anymore.
1820 for (auto V : ToRemoves) {
1821 V->eraseFromParent();
1822 }
1823
1824 // And remove the function we don't need either too.
1825 F->eraseFromParent();
1826 }
1827 }
1828
1829 return Changed;
1830}
1831
David Neto22f144c2017-06-12 14:26:21 -04001832bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
1833 bool Changed = false;
1834
1835 const std::map<const char *, Instruction::BinaryOps> Map = {
1836 {"_Z7signbitf", Instruction::LShr},
1837 {"_Z7signbitDv2_f", Instruction::AShr},
1838 {"_Z7signbitDv3_f", Instruction::AShr},
1839 {"_Z7signbitDv4_f", Instruction::AShr},
1840 };
1841
1842 for (auto Pair : Map) {
1843 // If we find a function with the matching name.
1844 if (auto F = M.getFunction(Pair.first)) {
1845 SmallVector<Instruction *, 4> ToRemoves;
1846
1847 // Walk the users of the function.
1848 for (auto &U : F->uses()) {
1849 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1850 auto Arg = CI->getOperand(0);
1851
1852 auto Bitcast =
1853 CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
1854
1855 auto Shr = BinaryOperator::Create(Pair.second, Bitcast,
1856 ConstantInt::get(CI->getType(), 31),
1857 "", CI);
1858
1859 CI->replaceAllUsesWith(Shr);
1860
1861 // Lastly, remember to remove the user.
1862 ToRemoves.push_back(CI);
1863 }
1864 }
1865
1866 Changed = !ToRemoves.empty();
1867
1868 // And cleanup the calls we don't use anymore.
1869 for (auto V : ToRemoves) {
1870 V->eraseFromParent();
1871 }
1872
1873 // And remove the function we don't need either too.
1874 F->eraseFromParent();
1875 }
1876 }
1877
1878 return Changed;
1879}
1880
1881bool ReplaceOpenCLBuiltinPass::replaceMadandMad24andMul24(Module &M) {
1882 bool Changed = false;
1883
1884 const std::map<const char *,
1885 std::pair<Instruction::BinaryOps, Instruction::BinaryOps>>
1886 Map = {
1887 {"_Z3madfff", {Instruction::FMul, Instruction::FAdd}},
1888 {"_Z3madDv2_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1889 {"_Z3madDv3_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1890 {"_Z3madDv4_fS_S_", {Instruction::FMul, Instruction::FAdd}},
1891 {"_Z5mad24iii", {Instruction::Mul, Instruction::Add}},
1892 {"_Z5mad24Dv2_iS_S_", {Instruction::Mul, Instruction::Add}},
1893 {"_Z5mad24Dv3_iS_S_", {Instruction::Mul, Instruction::Add}},
1894 {"_Z5mad24Dv4_iS_S_", {Instruction::Mul, Instruction::Add}},
1895 {"_Z5mad24jjj", {Instruction::Mul, Instruction::Add}},
1896 {"_Z5mad24Dv2_jS_S_", {Instruction::Mul, Instruction::Add}},
1897 {"_Z5mad24Dv3_jS_S_", {Instruction::Mul, Instruction::Add}},
1898 {"_Z5mad24Dv4_jS_S_", {Instruction::Mul, Instruction::Add}},
1899 {"_Z5mul24ii", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1900 {"_Z5mul24Dv2_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1901 {"_Z5mul24Dv3_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1902 {"_Z5mul24Dv4_iS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1903 {"_Z5mul24jj", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1904 {"_Z5mul24Dv2_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1905 {"_Z5mul24Dv3_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1906 {"_Z5mul24Dv4_jS_", {Instruction::Mul, Instruction::BinaryOpsEnd}},
1907 };
1908
1909 for (auto Pair : Map) {
1910 // If we find a function with the matching name.
1911 if (auto F = M.getFunction(Pair.first)) {
1912 SmallVector<Instruction *, 4> ToRemoves;
1913
1914 // Walk the users of the function.
1915 for (auto &U : F->uses()) {
1916 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
1917 // The multiply instruction to use.
1918 auto MulInst = Pair.second.first;
1919
1920 // The add instruction to use.
1921 auto AddInst = Pair.second.second;
1922
1923 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
1924
1925 auto I = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1926 CI->getArgOperand(1), "", CI);
1927
1928 if (Instruction::BinaryOpsEnd != AddInst) {
1929 I = BinaryOperator::Create(AddInst, I, CI->getArgOperand(2), "",
1930 CI);
1931 }
1932
1933 CI->replaceAllUsesWith(I);
1934
1935 // Lastly, remember to remove the user.
1936 ToRemoves.push_back(CI);
1937 }
1938 }
1939
1940 Changed = !ToRemoves.empty();
1941
1942 // And cleanup the calls we don't use anymore.
1943 for (auto V : ToRemoves) {
1944 V->eraseFromParent();
1945 }
1946
1947 // And remove the function we don't need either too.
1948 F->eraseFromParent();
1949 }
1950 }
1951
1952 return Changed;
1953}
1954
Derek Chowcfd368b2017-10-19 20:58:45 -07001955bool ReplaceOpenCLBuiltinPass::replaceVstore(Module &M) {
1956 bool Changed = false;
1957
alan-bakerf795f392019-06-11 18:24:34 -04001958 for (auto const &SymVal : M.getValueSymbolTable()) {
1959 if (!SymVal.getKey().contains("vstore"))
1960 continue;
1961 if (SymVal.getKey().contains("vstore_"))
1962 continue;
1963 if (SymVal.getKey().contains("vstorea"))
1964 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001965
alan-bakerf795f392019-06-11 18:24:34 -04001966 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07001967 SmallVector<Instruction *, 4> ToRemoves;
1968
alan-bakerf795f392019-06-11 18:24:34 -04001969 auto fname = F->getName();
1970 if (!fname.consume_front("_Z"))
1971 continue;
1972 size_t name_len;
1973 if (fname.consumeInteger(10, name_len))
1974 continue;
1975 std::string name = fname.take_front(name_len);
1976
1977 bool ok = StringSwitch<bool>(name)
1978 .Case("vstore2", true)
1979 .Case("vstore3", true)
1980 .Case("vstore4", true)
1981 .Case("vstore8", true)
1982 .Case("vstore16", true)
1983 .Default(false);
1984 if (!ok)
1985 continue;
1986
Derek Chowcfd368b2017-10-19 20:58:45 -07001987 for (auto &U : F->uses()) {
1988 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04001989 auto data = CI->getOperand(0);
Derek Chowcfd368b2017-10-19 20:58:45 -07001990
alan-bakerf795f392019-06-11 18:24:34 -04001991 auto data_type = data->getType();
1992 if (!data_type->isVectorTy())
1993 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001994
alan-bakerf795f392019-06-11 18:24:34 -04001995 auto elems = data_type->getVectorNumElements();
1996 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
1997 elems != 16)
1998 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07001999
alan-bakerf795f392019-06-11 18:24:34 -04002000 auto offset = CI->getOperand(1);
2001 auto ptr = CI->getOperand(2);
2002 auto ptr_type = ptr->getType();
2003 auto pointee_type = ptr_type->getPointerElementType();
2004 if (pointee_type != data_type->getVectorElementType())
2005 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002006
alan-bakerf795f392019-06-11 18:24:34 -04002007 // Avoid pointer casts. Instead generate the correct number of stores
2008 // and rely on drivers to coalesce appropriately.
2009 IRBuilder<> builder(CI);
2010 auto elems_const = builder.getInt32(elems);
2011 auto adjust = builder.CreateMul(offset, elems_const);
2012 for (auto i = 0; i < elems; ++i) {
2013 auto idx = builder.getInt32(i);
2014 auto add = builder.CreateAdd(adjust, idx);
2015 auto gep = builder.CreateGEP(ptr, add);
2016 auto extract = builder.CreateExtractElement(data, i);
2017 auto store = builder.CreateStore(extract, gep);
2018 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002019
Derek Chowcfd368b2017-10-19 20:58:45 -07002020 ToRemoves.push_back(CI);
2021 }
2022 }
2023
2024 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07002025 for (auto V : ToRemoves) {
2026 V->eraseFromParent();
2027 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002028 F->eraseFromParent();
2029 }
2030 }
2031
2032 return Changed;
2033}
2034
2035bool ReplaceOpenCLBuiltinPass::replaceVload(Module &M) {
2036 bool Changed = false;
2037
alan-bakerf795f392019-06-11 18:24:34 -04002038 for (auto const &SymVal : M.getValueSymbolTable()) {
2039 if (!SymVal.getKey().contains("vload"))
2040 continue;
2041 if (SymVal.getKey().contains("vload_"))
2042 continue;
2043 if (SymVal.getKey().contains("vloada"))
2044 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002045
alan-bakerf795f392019-06-11 18:24:34 -04002046 if (auto F = dyn_cast<Function>(SymVal.getValue())) {
Derek Chowcfd368b2017-10-19 20:58:45 -07002047 SmallVector<Instruction *, 4> ToRemoves;
2048
alan-bakerf795f392019-06-11 18:24:34 -04002049 auto fname = F->getName();
2050 if (!fname.consume_front("_Z"))
2051 continue;
2052 size_t name_len;
2053 if (fname.consumeInteger(10, name_len))
2054 continue;
2055 std::string name = fname.take_front(name_len);
2056
2057 bool ok = StringSwitch<bool>(name)
2058 .Case("vload2", true)
2059 .Case("vload3", true)
2060 .Case("vload4", true)
2061 .Case("vload8", true)
2062 .Case("vload16", true)
2063 .Default(false);
2064 if (!ok)
2065 continue;
2066
Derek Chowcfd368b2017-10-19 20:58:45 -07002067 for (auto &U : F->uses()) {
2068 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
alan-bakerf795f392019-06-11 18:24:34 -04002069 auto ret_type = F->getReturnType();
2070 if (!ret_type->isVectorTy())
2071 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002072
alan-bakerf795f392019-06-11 18:24:34 -04002073 auto elems = ret_type->getVectorNumElements();
2074 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 &&
2075 elems != 16)
2076 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002077
alan-bakerf795f392019-06-11 18:24:34 -04002078 auto offset = CI->getOperand(0);
2079 auto ptr = CI->getOperand(1);
2080 auto ptr_type = ptr->getType();
2081 auto pointee_type = ptr_type->getPointerElementType();
2082 if (pointee_type != ret_type->getVectorElementType())
2083 continue;
Derek Chowcfd368b2017-10-19 20:58:45 -07002084
alan-bakerf795f392019-06-11 18:24:34 -04002085 // Avoid pointer casts. Instead generate the correct number of loads
2086 // and rely on drivers to coalesce appropriately.
2087 IRBuilder<> builder(CI);
2088 auto elems_const = builder.getInt32(elems);
2089 Value *insert = UndefValue::get(ret_type);
2090 auto adjust = builder.CreateMul(offset, elems_const);
2091 for (auto i = 0; i < elems; ++i) {
2092 auto idx = builder.getInt32(i);
2093 auto add = builder.CreateAdd(adjust, idx);
2094 auto gep = builder.CreateGEP(ptr, add);
2095 auto load = builder.CreateLoad(gep);
2096 insert = builder.CreateInsertElement(insert, load, i);
2097 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002098
alan-bakerf795f392019-06-11 18:24:34 -04002099 CI->replaceAllUsesWith(insert);
Derek Chowcfd368b2017-10-19 20:58:45 -07002100 ToRemoves.push_back(CI);
2101 }
2102 }
2103
2104 Changed = !ToRemoves.empty();
Derek Chowcfd368b2017-10-19 20:58:45 -07002105 for (auto V : ToRemoves) {
2106 V->eraseFromParent();
2107 }
Derek Chowcfd368b2017-10-19 20:58:45 -07002108 F->eraseFromParent();
Derek Chowcfd368b2017-10-19 20:58:45 -07002109 }
2110 }
2111
2112 return Changed;
2113}
2114
David Neto22f144c2017-06-12 14:26:21 -04002115bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Module &M) {
2116 bool Changed = false;
2117
2118 const std::vector<const char *> Map = {"_Z10vload_halfjPU3AS1KDh",
2119 "_Z10vload_halfjPU3AS2KDh"};
2120
2121 for (auto Name : Map) {
2122 // If we find a function with the matching name.
2123 if (auto F = M.getFunction(Name)) {
2124 SmallVector<Instruction *, 4> ToRemoves;
2125
2126 // Walk the users of the function.
2127 for (auto &U : F->uses()) {
2128 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2129 // The index argument from vload_half.
2130 auto Arg0 = CI->getOperand(0);
2131
2132 // The pointer argument from vload_half.
2133 auto Arg1 = CI->getOperand(1);
2134
David Neto22f144c2017-06-12 14:26:21 -04002135 auto IntTy = Type::getInt32Ty(M.getContext());
2136 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
David Neto22f144c2017-06-12 14:26:21 -04002137 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
2138
David Neto22f144c2017-06-12 14:26:21 -04002139 // Our intrinsic to unpack a float2 from an int.
2140 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
2141
2142 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
2143
David Neto482550a2018-03-24 05:21:07 -07002144 if (clspv::Option::F16BitStorage()) {
David Netoac825b82017-05-30 12:49:01 -04002145 auto ShortTy = Type::getInt16Ty(M.getContext());
2146 auto ShortPointerTy = PointerType::get(
2147 ShortTy, Arg1->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002148
David Netoac825b82017-05-30 12:49:01 -04002149 // Cast the half* pointer to short*.
2150 auto Cast =
2151 CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002152
David Netoac825b82017-05-30 12:49:01 -04002153 // Index into the correct address of the casted pointer.
2154 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
2155
2156 // Load from the short* we casted to.
2157 auto Load = new LoadInst(Index, "", CI);
2158
2159 // ZExt the short -> int.
2160 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
2161
2162 // Get our float2.
2163 auto Call = CallInst::Create(NewF, ZExt, "", CI);
2164
2165 // Extract out the bottom element which is our float result.
2166 auto Extract = ExtractElementInst::Create(
2167 Call, ConstantInt::get(IntTy, 0), "", CI);
2168
2169 CI->replaceAllUsesWith(Extract);
2170 } else {
2171 // Assume the pointer argument points to storage aligned to 32bits
2172 // or more.
2173 // TODO(dneto): Do more analysis to make sure this is true?
2174 //
2175 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
2176 // with:
2177 //
2178 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
2179 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
2180 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
2181 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
2182 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
2183 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
2184 // x float> %converted, %index_is_odd32
2185
2186 auto IntPointerTy = PointerType::get(
2187 IntTy, Arg1->getType()->getPointerAddressSpace());
2188
David Neto973e6a82017-05-30 13:48:18 -04002189 // Cast the base pointer to int*.
David Netoac825b82017-05-30 12:49:01 -04002190 // In a valid call (according to assumptions), this should get
David Neto973e6a82017-05-30 13:48:18 -04002191 // optimized away in the simplify GEP pass.
David Netoac825b82017-05-30 12:49:01 -04002192 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
2193
2194 auto One = ConstantInt::get(IntTy, 1);
2195 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
2196 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
2197
2198 // Index into the correct address of the casted pointer.
2199 auto Ptr =
2200 GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
2201
2202 // Load from the int* we casted to.
2203 auto Load = new LoadInst(Ptr, "", CI);
2204
2205 // Get our float2.
2206 auto Call = CallInst::Create(NewF, Load, "", CI);
2207
2208 // Extract out the float result, where the element number is
2209 // determined by whether the original index was even or odd.
2210 auto Extract = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
2211
2212 CI->replaceAllUsesWith(Extract);
2213 }
David Neto22f144c2017-06-12 14:26:21 -04002214
2215 // Lastly, remember to remove the user.
2216 ToRemoves.push_back(CI);
2217 }
2218 }
2219
2220 Changed = !ToRemoves.empty();
2221
2222 // And cleanup the calls we don't use anymore.
2223 for (auto V : ToRemoves) {
2224 V->eraseFromParent();
2225 }
2226
2227 // And remove the function we don't need either too.
2228 F->eraseFromParent();
2229 }
2230 }
2231
2232 return Changed;
2233}
2234
2235bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002236
Kévin Petite8edce32019-04-10 14:23:32 +01002237 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002238 "_Z11vload_half2jPU3AS1KDh",
2239 "_Z12vloada_half2jPU3AS1KDh", // vloada_half2 global
2240 "_Z11vload_half2jPU3AS2KDh",
2241 "_Z12vloada_half2jPU3AS2KDh", // vloada_half2 constant
2242 };
David Neto22f144c2017-06-12 14:26:21 -04002243
Kévin Petite8edce32019-04-10 14:23:32 +01002244 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2245 // The index argument from vload_half.
2246 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002247
Kévin Petite8edce32019-04-10 14:23:32 +01002248 // The pointer argument from vload_half.
2249 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002250
Kévin Petite8edce32019-04-10 14:23:32 +01002251 auto IntTy = Type::getInt32Ty(M.getContext());
2252 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002253 auto NewPointerTy =
2254 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002255 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002256
Kévin Petite8edce32019-04-10 14:23:32 +01002257 // Cast the half* pointer to int*.
2258 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002259
Kévin Petite8edce32019-04-10 14:23:32 +01002260 // Index into the correct address of the casted pointer.
2261 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002262
Kévin Petite8edce32019-04-10 14:23:32 +01002263 // Load from the int* we casted to.
2264 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002265
Kévin Petite8edce32019-04-10 14:23:32 +01002266 // Our intrinsic to unpack a float2 from an int.
2267 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002268
Kévin Petite8edce32019-04-10 14:23:32 +01002269 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002270
Kévin Petite8edce32019-04-10 14:23:32 +01002271 // Get our float2.
2272 return CallInst::Create(NewF, Load, "", CI);
2273 });
David Neto22f144c2017-06-12 14:26:21 -04002274}
2275
2276bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002277
Kévin Petite8edce32019-04-10 14:23:32 +01002278 const std::vector<const char *> Names = {
David Neto556c7e62018-06-08 13:45:55 -07002279 "_Z11vload_half4jPU3AS1KDh",
2280 "_Z12vloada_half4jPU3AS1KDh",
2281 "_Z11vload_half4jPU3AS2KDh",
2282 "_Z12vloada_half4jPU3AS2KDh",
2283 };
David Neto22f144c2017-06-12 14:26:21 -04002284
Kévin Petite8edce32019-04-10 14:23:32 +01002285 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2286 // The index argument from vload_half.
2287 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002288
Kévin Petite8edce32019-04-10 14:23:32 +01002289 // The pointer argument from vload_half.
2290 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002291
Kévin Petite8edce32019-04-10 14:23:32 +01002292 auto IntTy = Type::getInt32Ty(M.getContext());
2293 auto Int2Ty = VectorType::get(IntTy, 2);
2294 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002295 auto NewPointerTy =
2296 PointerType::get(Int2Ty, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002297 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04002298
Kévin Petite8edce32019-04-10 14:23:32 +01002299 // Cast the half* pointer to int2*.
2300 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002301
Kévin Petite8edce32019-04-10 14:23:32 +01002302 // Index into the correct address of the casted pointer.
2303 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002304
Kévin Petite8edce32019-04-10 14:23:32 +01002305 // Load from the int2* we casted to.
2306 auto Load = new LoadInst(Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002307
Kévin Petite8edce32019-04-10 14:23:32 +01002308 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002309 auto X =
2310 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2311 auto Y =
2312 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002313
Kévin Petite8edce32019-04-10 14:23:32 +01002314 // Our intrinsic to unpack a float2 from an int.
2315 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002316
Kévin Petite8edce32019-04-10 14:23:32 +01002317 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002318
Kévin Petite8edce32019-04-10 14:23:32 +01002319 // Get the lower (x & y) components of our final float4.
2320 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002321
Kévin Petite8edce32019-04-10 14:23:32 +01002322 // Get the higher (z & w) components of our final float4.
2323 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002324
Kévin Petite8edce32019-04-10 14:23:32 +01002325 Constant *ShuffleMask[4] = {
2326 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2327 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002328
Kévin Petite8edce32019-04-10 14:23:32 +01002329 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002330 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2331 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002332 });
David Neto22f144c2017-06-12 14:26:21 -04002333}
2334
David Neto6ad93232018-06-07 15:42:58 -07002335bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002336
2337 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
2338 //
2339 // %u = load i32 %ptr
2340 // %fxy = call <2 x float> Unpack2xHalf(u)
2341 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002342 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002343 "_Z20__clspv_vloada_half2jPU3AS1Kj", // global
2344 "_Z20__clspv_vloada_half2jPU3AS3Kj", // local
2345 "_Z20__clspv_vloada_half2jPKj", // private
2346 };
2347
Kévin Petite8edce32019-04-10 14:23:32 +01002348 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2349 auto Index = CI->getOperand(0);
2350 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002351
Kévin Petite8edce32019-04-10 14:23:32 +01002352 auto IntTy = Type::getInt32Ty(M.getContext());
2353 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2354 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002355
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002356 auto IndexedPtr = GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002357 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002358
Kévin Petite8edce32019-04-10 14:23:32 +01002359 // Our intrinsic to unpack a float2 from an int.
2360 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002361
Kévin Petite8edce32019-04-10 14:23:32 +01002362 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002363
Kévin Petite8edce32019-04-10 14:23:32 +01002364 // Get our final float2.
2365 return CallInst::Create(NewF, Load, "", CI);
2366 });
David Neto6ad93232018-06-07 15:42:58 -07002367}
2368
2369bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Module &M) {
David Neto6ad93232018-06-07 15:42:58 -07002370
2371 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
2372 //
2373 // %u2 = load <2 x i32> %ptr
2374 // %u2xy = extractelement %u2, 0
2375 // %u2zw = extractelement %u2, 1
2376 // %fxy = call <2 x float> Unpack2xHalf(uint)
2377 // %fzw = call <2 x float> Unpack2xHalf(uint)
2378 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
Kévin Petite8edce32019-04-10 14:23:32 +01002379 const std::vector<const char *> Names = {
David Neto6ad93232018-06-07 15:42:58 -07002380 "_Z20__clspv_vloada_half4jPU3AS1KDv2_j", // global
2381 "_Z20__clspv_vloada_half4jPU3AS3KDv2_j", // local
2382 "_Z20__clspv_vloada_half4jPKDv2_j", // private
2383 };
2384
Kévin Petite8edce32019-04-10 14:23:32 +01002385 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2386 auto Index = CI->getOperand(0);
2387 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07002388
Kévin Petite8edce32019-04-10 14:23:32 +01002389 auto IntTy = Type::getInt32Ty(M.getContext());
2390 auto Int2Ty = VectorType::get(IntTy, 2);
2391 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2392 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07002393
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002394 auto IndexedPtr = GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002395 auto Load = new LoadInst(IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002396
Kévin Petite8edce32019-04-10 14:23:32 +01002397 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002398 auto X =
2399 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
2400 auto Y =
2401 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002402
Kévin Petite8edce32019-04-10 14:23:32 +01002403 // Our intrinsic to unpack a float2 from an int.
2404 auto SPIRVIntrinsic = "spirv.unpack.v2f16";
David Neto6ad93232018-06-07 15:42:58 -07002405
Kévin Petite8edce32019-04-10 14:23:32 +01002406 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07002407
Kévin Petite8edce32019-04-10 14:23:32 +01002408 // Get the lower (x & y) components of our final float4.
2409 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002410
Kévin Petite8edce32019-04-10 14:23:32 +01002411 // Get the higher (z & w) components of our final float4.
2412 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07002413
Kévin Petite8edce32019-04-10 14:23:32 +01002414 Constant *ShuffleMask[4] = {
2415 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2416 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto6ad93232018-06-07 15:42:58 -07002417
Kévin Petite8edce32019-04-10 14:23:32 +01002418 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002419 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
2420 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002421 });
David Neto6ad93232018-06-07 15:42:58 -07002422}
2423
David Neto22f144c2017-06-12 14:26:21 -04002424bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002425
Kévin Petite8edce32019-04-10 14:23:32 +01002426 const std::vector<const char *> Names = {"_Z11vstore_halffjPU3AS1Dh",
2427 "_Z15vstore_half_rtefjPU3AS1Dh",
2428 "_Z15vstore_half_rtzfjPU3AS1Dh"};
David Neto22f144c2017-06-12 14:26:21 -04002429
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002430 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002431 // The value to store.
2432 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002433
Kévin Petite8edce32019-04-10 14:23:32 +01002434 // The index argument from vstore_half.
2435 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002436
Kévin Petite8edce32019-04-10 14:23:32 +01002437 // The pointer argument from vstore_half.
2438 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002439
Kévin Petite8edce32019-04-10 14:23:32 +01002440 auto IntTy = Type::getInt32Ty(M.getContext());
2441 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
2442 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
2443 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04002444
Kévin Petite8edce32019-04-10 14:23:32 +01002445 // Our intrinsic to pack a float2 to an int.
2446 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002447
Kévin Petite8edce32019-04-10 14:23:32 +01002448 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002449
Kévin Petite8edce32019-04-10 14:23:32 +01002450 // Insert our value into a float2 so that we can pack it.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002451 auto TempVec = InsertElementInst::Create(
2452 UndefValue::get(Float2Ty), Arg0, ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002453
Kévin Petite8edce32019-04-10 14:23:32 +01002454 // Pack the float2 -> half2 (in an int).
2455 auto X = CallInst::Create(NewF, TempVec, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002456
Kévin Petite8edce32019-04-10 14:23:32 +01002457 Value *Ret;
2458 if (clspv::Option::F16BitStorage()) {
2459 auto ShortTy = Type::getInt16Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002460 auto ShortPointerTy =
2461 PointerType::get(ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002462
Kévin Petite8edce32019-04-10 14:23:32 +01002463 // Truncate our i32 to an i16.
2464 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002465
Kévin Petite8edce32019-04-10 14:23:32 +01002466 // Cast the half* pointer to short*.
2467 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002468
Kévin Petite8edce32019-04-10 14:23:32 +01002469 // Index into the correct address of the casted pointer.
2470 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002471
Kévin Petite8edce32019-04-10 14:23:32 +01002472 // Store to the int* we casted to.
2473 Ret = new StoreInst(Trunc, Index, CI);
2474 } else {
2475 // We can only write to 32-bit aligned words.
2476 //
2477 // Assuming base is aligned to 32-bits, replace the equivalent of
2478 // vstore_half(value, index, base)
2479 // with:
2480 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2481 // uint32_t write_to_upper_half = index & 1u;
2482 // uint32_t shift = write_to_upper_half << 4;
2483 //
2484 // // Pack the float value as a half number in bottom 16 bits
2485 // // of an i32.
2486 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2487 //
2488 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2489 // ^ ((packed & 0xffff) << shift)
2490 // // We only need relaxed consistency, but OpenCL 1.2 only has
2491 // // sequentially consistent atomics.
2492 // // TODO(dneto): Use relaxed consistency.
2493 // atomic_xor(target_ptr, xor_value)
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002494 auto IntPointerTy =
2495 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002496
Kévin Petite8edce32019-04-10 14:23:32 +01002497 auto Four = ConstantInt::get(IntTy, 4);
2498 auto FFFF = ConstantInt::get(IntTy, 0xffff);
David Neto17852de2017-05-29 17:29:31 -04002499
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002500 auto IndexIsOdd =
2501 BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002502 // Compute index / 2
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002503 auto IndexIntoI32 =
2504 BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2505 auto BaseI32Ptr =
2506 CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2507 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32,
2508 "base_i32_ptr", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002509 auto CurrentValue = new LoadInst(OutPtr, "current_value", CI);
2510 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002511 auto MaskBitsToWrite =
2512 BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2513 auto MaskedCurrent = BinaryOperator::CreateAnd(
2514 MaskBitsToWrite, CurrentValue, "masked_current", CI);
David Neto17852de2017-05-29 17:29:31 -04002515
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002516 auto XLowerBits =
2517 BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2518 auto NewBitsToWrite =
2519 BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2520 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite,
2521 "value_to_xor", CI);
David Neto17852de2017-05-29 17:29:31 -04002522
Kévin Petite8edce32019-04-10 14:23:32 +01002523 // Generate the call to atomi_xor.
2524 SmallVector<Type *, 5> ParamTypes;
2525 // The pointer type.
2526 ParamTypes.push_back(IntPointerTy);
2527 // The Types for memory scope, semantics, and value.
2528 ParamTypes.push_back(IntTy);
2529 ParamTypes.push_back(IntTy);
2530 ParamTypes.push_back(IntTy);
2531 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2532 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
David Neto17852de2017-05-29 17:29:31 -04002533
Kévin Petite8edce32019-04-10 14:23:32 +01002534 const auto ConstantScopeDevice =
2535 ConstantInt::get(IntTy, spv::ScopeDevice);
2536 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2537 // (SPIR-V Workgroup).
2538 const auto AddrSpaceSemanticsBits =
2539 IntPointerTy->getPointerAddressSpace() == 1
2540 ? spv::MemorySemanticsUniformMemoryMask
2541 : spv::MemorySemanticsWorkgroupMemoryMask;
David Neto17852de2017-05-29 17:29:31 -04002542
Kévin Petite8edce32019-04-10 14:23:32 +01002543 // We're using relaxed consistency here.
2544 const auto ConstantMemorySemantics =
2545 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2546 AddrSpaceSemanticsBits);
David Neto17852de2017-05-29 17:29:31 -04002547
Kévin Petite8edce32019-04-10 14:23:32 +01002548 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2549 ConstantMemorySemantics, ValueToXor};
2550 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
2551 Ret = nullptr;
David Neto22f144c2017-06-12 14:26:21 -04002552 }
David Neto22f144c2017-06-12 14:26:21 -04002553
Kévin Petite8edce32019-04-10 14:23:32 +01002554 return Ret;
2555 });
David Neto22f144c2017-06-12 14:26:21 -04002556}
2557
2558bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002559
Kévin Petite8edce32019-04-10 14:23:32 +01002560 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002561 "_Z12vstore_half2Dv2_fjPU3AS1Dh",
2562 "_Z13vstorea_half2Dv2_fjPU3AS1Dh", // vstorea global
2563 "_Z13vstorea_half2Dv2_fjPU3AS3Dh", // vstorea local
2564 "_Z13vstorea_half2Dv2_fjPDh", // vstorea private
2565 "_Z16vstore_half2_rteDv2_fjPU3AS1Dh",
2566 "_Z17vstorea_half2_rteDv2_fjPU3AS1Dh", // vstorea global
2567 "_Z17vstorea_half2_rteDv2_fjPU3AS3Dh", // vstorea local
2568 "_Z17vstorea_half2_rteDv2_fjPDh", // vstorea private
2569 "_Z16vstore_half2_rtzDv2_fjPU3AS1Dh",
2570 "_Z17vstorea_half2_rtzDv2_fjPU3AS1Dh", // vstorea global
2571 "_Z17vstorea_half2_rtzDv2_fjPU3AS3Dh", // vstorea local
2572 "_Z17vstorea_half2_rtzDv2_fjPDh", // vstorea private
2573 };
David Neto22f144c2017-06-12 14:26:21 -04002574
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002575 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002576 // The value to store.
2577 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002578
Kévin Petite8edce32019-04-10 14:23:32 +01002579 // The index argument from vstore_half.
2580 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002581
Kévin Petite8edce32019-04-10 14:23:32 +01002582 // The pointer argument from vstore_half.
2583 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002584
Kévin Petite8edce32019-04-10 14:23:32 +01002585 auto IntTy = Type::getInt32Ty(M.getContext());
2586 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002587 auto NewPointerTy =
2588 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002589 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002590
Kévin Petite8edce32019-04-10 14:23:32 +01002591 // Our intrinsic to pack a float2 to an int.
2592 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002593
Kévin Petite8edce32019-04-10 14:23:32 +01002594 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002595
Kévin Petite8edce32019-04-10 14:23:32 +01002596 // Turn the packed x & y into the final packing.
2597 auto X = CallInst::Create(NewF, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002598
Kévin Petite8edce32019-04-10 14:23:32 +01002599 // Cast the half* pointer to int*.
2600 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002601
Kévin Petite8edce32019-04-10 14:23:32 +01002602 // Index into the correct address of the casted pointer.
2603 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002604
Kévin Petite8edce32019-04-10 14:23:32 +01002605 // Store to the int* we casted to.
2606 return new StoreInst(X, Index, CI);
2607 });
David Neto22f144c2017-06-12 14:26:21 -04002608}
2609
2610bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002611
Kévin Petite8edce32019-04-10 14:23:32 +01002612 const std::vector<const char *> Names = {
David Netoe2871522018-06-08 11:09:54 -07002613 "_Z12vstore_half4Dv4_fjPU3AS1Dh",
2614 "_Z13vstorea_half4Dv4_fjPU3AS1Dh", // global
2615 "_Z13vstorea_half4Dv4_fjPU3AS3Dh", // local
2616 "_Z13vstorea_half4Dv4_fjPDh", // private
2617 "_Z16vstore_half4_rteDv4_fjPU3AS1Dh",
2618 "_Z17vstorea_half4_rteDv4_fjPU3AS1Dh", // global
2619 "_Z17vstorea_half4_rteDv4_fjPU3AS3Dh", // local
2620 "_Z17vstorea_half4_rteDv4_fjPDh", // private
2621 "_Z16vstore_half4_rtzDv4_fjPU3AS1Dh",
2622 "_Z17vstorea_half4_rtzDv4_fjPU3AS1Dh", // global
2623 "_Z17vstorea_half4_rtzDv4_fjPU3AS3Dh", // local
2624 "_Z17vstorea_half4_rtzDv4_fjPDh", // private
2625 };
David Neto22f144c2017-06-12 14:26:21 -04002626
Kévin Petite8edce32019-04-10 14:23:32 +01002627 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
2628 // The value to store.
2629 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002630
Kévin Petite8edce32019-04-10 14:23:32 +01002631 // The index argument from vstore_half.
2632 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002633
Kévin Petite8edce32019-04-10 14:23:32 +01002634 // The pointer argument from vstore_half.
2635 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002636
Kévin Petite8edce32019-04-10 14:23:32 +01002637 auto IntTy = Type::getInt32Ty(M.getContext());
2638 auto Int2Ty = VectorType::get(IntTy, 2);
2639 auto Float2Ty = VectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002640 auto NewPointerTy =
2641 PointerType::get(Int2Ty, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002642 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002643
Kévin Petite8edce32019-04-10 14:23:32 +01002644 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2645 ConstantInt::get(IntTy, 1)};
David Neto22f144c2017-06-12 14:26:21 -04002646
Kévin Petite8edce32019-04-10 14:23:32 +01002647 // Extract out the x & y components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002648 auto Lo = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2649 ConstantVector::get(LoShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002650
Kévin Petite8edce32019-04-10 14:23:32 +01002651 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2652 ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002653
Kévin Petite8edce32019-04-10 14:23:32 +01002654 // Extract out the z & w components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002655 auto Hi = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2656 ConstantVector::get(HiShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002657
Kévin Petite8edce32019-04-10 14:23:32 +01002658 // Our intrinsic to pack a float2 to an int.
2659 auto SPIRVIntrinsic = "spirv.pack.v2f16";
David Neto22f144c2017-06-12 14:26:21 -04002660
Kévin Petite8edce32019-04-10 14:23:32 +01002661 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002662
Kévin Petite8edce32019-04-10 14:23:32 +01002663 // Turn the packed x & y into the final component of our int2.
2664 auto X = CallInst::Create(NewF, Lo, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002665
Kévin Petite8edce32019-04-10 14:23:32 +01002666 // Turn the packed z & w into the final component of our int2.
2667 auto Y = CallInst::Create(NewF, Hi, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002668
Kévin Petite8edce32019-04-10 14:23:32 +01002669 auto Combine = InsertElementInst::Create(
2670 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002671 Combine = InsertElementInst::Create(Combine, Y, ConstantInt::get(IntTy, 1),
2672 "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002673
Kévin Petite8edce32019-04-10 14:23:32 +01002674 // Cast the half* pointer to int2*.
2675 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002676
Kévin Petite8edce32019-04-10 14:23:32 +01002677 // Index into the correct address of the casted pointer.
2678 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002679
Kévin Petite8edce32019-04-10 14:23:32 +01002680 // Store to the int2* we casted to.
2681 return new StoreInst(Combine, Index, CI);
2682 });
David Neto22f144c2017-06-12 14:26:21 -04002683}
2684
Kévin Petit06517a12019-12-09 19:40:31 +00002685bool ReplaceOpenCLBuiltinPass::replaceSampledReadImageWithIntCoords(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002686 bool Changed = false;
2687
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002688 const std::map<const char *, const char *> Map = {
Kévin Petit06517a12019-12-09 19:40:31 +00002689 // TODO 1D, 1Darray
2690 // 2D
2691 {"_Z11read_imagei14ocl_image2d_ro11ocl_samplerDv2_i",
2692 "_Z11read_imagei14ocl_image2d_ro11ocl_samplerDv2_f"},
2693 {"_Z12read_imageui14ocl_image2d_ro11ocl_samplerDv2_i",
2694 "_Z12read_imageui14ocl_image2d_ro11ocl_samplerDv2_f"},
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002695 {"_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_i",
2696 "_Z11read_imagef14ocl_image2d_ro11ocl_samplerDv2_f"},
Kévin Petit06517a12019-12-09 19:40:31 +00002697 // TODO 2D array
2698 // 3D
2699 {"_Z11read_imagei14ocl_image3d_ro11ocl_samplerDv4_i",
2700 "_Z11read_imagei14ocl_image3d_ro11ocl_samplerDv4_f"},
2701 {"_Z12read_imageui14ocl_image3d_ro11ocl_samplerDv4_i",
2702 "_Z12read_imageui14ocl_image3d_ro11ocl_samplerDv4_f"},
2703 {"_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_i",
2704 "_Z11read_imagef14ocl_image3d_ro11ocl_samplerDv4_f"}};
David Neto22f144c2017-06-12 14:26:21 -04002705
2706 for (auto Pair : Map) {
2707 // If we find a function with the matching name.
2708 if (auto F = M.getFunction(Pair.first)) {
2709 SmallVector<Instruction *, 4> ToRemoves;
2710
2711 // Walk the users of the function.
2712 for (auto &U : F->uses()) {
2713 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2714 // The image.
2715 auto Arg0 = CI->getOperand(0);
2716
2717 // The sampler.
2718 auto Arg1 = CI->getOperand(1);
2719
2720 // The coordinate (integer type that we can't handle).
2721 auto Arg2 = CI->getOperand(2);
2722
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002723 auto FloatVecTy =
2724 VectorType::get(Type::getFloatTy(M.getContext()),
2725 Arg2->getType()->getVectorNumElements());
David Neto22f144c2017-06-12 14:26:21 -04002726
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002727 auto NewFType = FunctionType::get(
2728 CI->getType(), {Arg0->getType(), Arg1->getType(), FloatVecTy},
2729 false);
David Neto22f144c2017-06-12 14:26:21 -04002730
2731 auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
2732
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002733 auto Cast =
2734 CastInst::Create(Instruction::SIToFP, Arg2, FloatVecTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002735
2736 auto NewCI = CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2737
2738 CI->replaceAllUsesWith(NewCI);
2739
2740 // Lastly, remember to remove the user.
2741 ToRemoves.push_back(CI);
2742 }
2743 }
2744
2745 Changed = !ToRemoves.empty();
2746
2747 // And cleanup the calls we don't use anymore.
2748 for (auto V : ToRemoves) {
2749 V->eraseFromParent();
2750 }
2751
2752 // And remove the function we don't need either too.
2753 F->eraseFromParent();
2754 }
2755 }
2756
2757 return Changed;
2758}
2759
2760bool ReplaceOpenCLBuiltinPass::replaceAtomics(Module &M) {
2761 bool Changed = false;
2762
Kévin Petit9b340262019-06-19 18:31:11 +01002763 const std::map<const char *, spv::Op> Map = {
2764 {"_Z8atom_incPU3AS1Vi", spv::OpAtomicIIncrement},
2765 {"_Z8atom_incPU3AS3Vi", spv::OpAtomicIIncrement},
2766 {"_Z8atom_incPU3AS1Vj", spv::OpAtomicIIncrement},
2767 {"_Z8atom_incPU3AS3Vj", spv::OpAtomicIIncrement},
2768 {"_Z8atom_decPU3AS1Vi", spv::OpAtomicIDecrement},
2769 {"_Z8atom_decPU3AS3Vi", spv::OpAtomicIDecrement},
2770 {"_Z8atom_decPU3AS1Vj", spv::OpAtomicIDecrement},
2771 {"_Z8atom_decPU3AS3Vj", spv::OpAtomicIDecrement},
2772 {"_Z12atom_cmpxchgPU3AS1Viii", spv::OpAtomicCompareExchange},
2773 {"_Z12atom_cmpxchgPU3AS3Viii", spv::OpAtomicCompareExchange},
2774 {"_Z12atom_cmpxchgPU3AS1Vjjj", spv::OpAtomicCompareExchange},
2775 {"_Z12atom_cmpxchgPU3AS3Vjjj", spv::OpAtomicCompareExchange},
2776 {"_Z10atomic_incPU3AS1Vi", spv::OpAtomicIIncrement},
2777 {"_Z10atomic_incPU3AS3Vi", spv::OpAtomicIIncrement},
2778 {"_Z10atomic_incPU3AS1Vj", spv::OpAtomicIIncrement},
2779 {"_Z10atomic_incPU3AS3Vj", spv::OpAtomicIIncrement},
2780 {"_Z10atomic_decPU3AS1Vi", spv::OpAtomicIDecrement},
2781 {"_Z10atomic_decPU3AS3Vi", spv::OpAtomicIDecrement},
2782 {"_Z10atomic_decPU3AS1Vj", spv::OpAtomicIDecrement},
2783 {"_Z10atomic_decPU3AS3Vj", spv::OpAtomicIDecrement},
2784 {"_Z14atomic_cmpxchgPU3AS1Viii", spv::OpAtomicCompareExchange},
2785 {"_Z14atomic_cmpxchgPU3AS3Viii", spv::OpAtomicCompareExchange},
2786 {"_Z14atomic_cmpxchgPU3AS1Vjjj", spv::OpAtomicCompareExchange},
2787 {"_Z14atomic_cmpxchgPU3AS3Vjjj", spv::OpAtomicCompareExchange}};
David Neto22f144c2017-06-12 14:26:21 -04002788
2789 for (auto Pair : Map) {
2790 // If we find a function with the matching name.
2791 if (auto F = M.getFunction(Pair.first)) {
2792 SmallVector<Instruction *, 4> ToRemoves;
2793
2794 // Walk the users of the function.
2795 for (auto &U : F->uses()) {
2796 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
David Neto22f144c2017-06-12 14:26:21 -04002797
2798 auto IntTy = Type::getInt32Ty(M.getContext());
2799
David Neto22f144c2017-06-12 14:26:21 -04002800 // We need to map the OpenCL constants to the SPIR-V equivalents.
2801 const auto ConstantScopeDevice =
2802 ConstantInt::get(IntTy, spv::ScopeDevice);
2803 const auto ConstantMemorySemantics = ConstantInt::get(
2804 IntTy, spv::MemorySemanticsUniformMemoryMask |
2805 spv::MemorySemanticsSequentiallyConsistentMask);
2806
2807 SmallVector<Value *, 5> Params;
2808
2809 // The pointer.
2810 Params.push_back(CI->getArgOperand(0));
2811
2812 // The memory scope.
2813 Params.push_back(ConstantScopeDevice);
2814
2815 // The memory semantics.
2816 Params.push_back(ConstantMemorySemantics);
2817
2818 if (2 < CI->getNumArgOperands()) {
2819 // The unequal memory semantics.
2820 Params.push_back(ConstantMemorySemantics);
2821
2822 // The value.
2823 Params.push_back(CI->getArgOperand(2));
2824
2825 // The comparator.
2826 Params.push_back(CI->getArgOperand(1));
2827 } else if (1 < CI->getNumArgOperands()) {
2828 // The value.
2829 Params.push_back(CI->getArgOperand(1));
2830 }
2831
Kévin Petit9b340262019-06-19 18:31:11 +01002832 auto NewCI =
2833 clspv::InsertSPIRVOp(CI, Pair.second, {}, CI->getType(), Params);
David Neto22f144c2017-06-12 14:26:21 -04002834
2835 CI->replaceAllUsesWith(NewCI);
2836
2837 // Lastly, remember to remove the user.
2838 ToRemoves.push_back(CI);
2839 }
2840 }
2841
2842 Changed = !ToRemoves.empty();
2843
2844 // And cleanup the calls we don't use anymore.
2845 for (auto V : ToRemoves) {
2846 V->eraseFromParent();
2847 }
2848
2849 // And remove the function we don't need either too.
2850 F->eraseFromParent();
2851 }
2852 }
2853
Neil Henning39672102017-09-29 14:33:13 +01002854 const std::map<const char *, llvm::AtomicRMWInst::BinOp> Map2 = {
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002855 {"_Z8atom_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002856 {"_Z8atom_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002857 {"_Z8atom_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002858 {"_Z8atom_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002859 {"_Z8atom_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002860 {"_Z8atom_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002861 {"_Z8atom_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002862 {"_Z8atom_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002863 {"_Z9atom_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002864 {"_Z9atom_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002865 {"_Z9atom_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002866 {"_Z9atom_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002867 {"_Z8atom_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002868 {"_Z8atom_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002869 {"_Z8atom_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002870 {"_Z8atom_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002871 {"_Z8atom_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002872 {"_Z8atom_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002873 {"_Z8atom_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002874 {"_Z8atom_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002875 {"_Z8atom_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002876 {"_Z8atom_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002877 {"_Z8atom_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002878 {"_Z8atom_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002879 {"_Z7atom_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002880 {"_Z7atom_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002881 {"_Z7atom_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002882 {"_Z7atom_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002883 {"_Z8atom_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002884 {"_Z8atom_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
Kévin Petit4f6c6b02018-10-25 18:56:55 +00002885 {"_Z8atom_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002886 {"_Z8atom_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor},
Neil Henning39672102017-09-29 14:33:13 +01002887 {"_Z10atomic_addPU3AS1Vii", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002888 {"_Z10atomic_addPU3AS3Vii", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002889 {"_Z10atomic_addPU3AS1Vjj", llvm::AtomicRMWInst::Add},
Kévin Petita303dc62019-03-26 21:40:35 +00002890 {"_Z10atomic_addPU3AS3Vjj", llvm::AtomicRMWInst::Add},
Neil Henning39672102017-09-29 14:33:13 +01002891 {"_Z10atomic_subPU3AS1Vii", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002892 {"_Z10atomic_subPU3AS3Vii", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002893 {"_Z10atomic_subPU3AS1Vjj", llvm::AtomicRMWInst::Sub},
Kévin Petita303dc62019-03-26 21:40:35 +00002894 {"_Z10atomic_subPU3AS3Vjj", llvm::AtomicRMWInst::Sub},
Neil Henning39672102017-09-29 14:33:13 +01002895 {"_Z11atomic_xchgPU3AS1Vii", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002896 {"_Z11atomic_xchgPU3AS3Vii", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002897 {"_Z11atomic_xchgPU3AS1Vjj", llvm::AtomicRMWInst::Xchg},
Kévin Petita303dc62019-03-26 21:40:35 +00002898 {"_Z11atomic_xchgPU3AS3Vjj", llvm::AtomicRMWInst::Xchg},
Neil Henning39672102017-09-29 14:33:13 +01002899 {"_Z10atomic_minPU3AS1Vii", llvm::AtomicRMWInst::Min},
Kévin Petita303dc62019-03-26 21:40:35 +00002900 {"_Z10atomic_minPU3AS3Vii", llvm::AtomicRMWInst::Min},
Neil Henning39672102017-09-29 14:33:13 +01002901 {"_Z10atomic_minPU3AS1Vjj", llvm::AtomicRMWInst::UMin},
Kévin Petita303dc62019-03-26 21:40:35 +00002902 {"_Z10atomic_minPU3AS3Vjj", llvm::AtomicRMWInst::UMin},
Neil Henning39672102017-09-29 14:33:13 +01002903 {"_Z10atomic_maxPU3AS1Vii", llvm::AtomicRMWInst::Max},
Kévin Petita303dc62019-03-26 21:40:35 +00002904 {"_Z10atomic_maxPU3AS3Vii", llvm::AtomicRMWInst::Max},
Neil Henning39672102017-09-29 14:33:13 +01002905 {"_Z10atomic_maxPU3AS1Vjj", llvm::AtomicRMWInst::UMax},
Kévin Petita303dc62019-03-26 21:40:35 +00002906 {"_Z10atomic_maxPU3AS3Vjj", llvm::AtomicRMWInst::UMax},
Neil Henning39672102017-09-29 14:33:13 +01002907 {"_Z10atomic_andPU3AS1Vii", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002908 {"_Z10atomic_andPU3AS3Vii", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002909 {"_Z10atomic_andPU3AS1Vjj", llvm::AtomicRMWInst::And},
Kévin Petita303dc62019-03-26 21:40:35 +00002910 {"_Z10atomic_andPU3AS3Vjj", llvm::AtomicRMWInst::And},
Neil Henning39672102017-09-29 14:33:13 +01002911 {"_Z9atomic_orPU3AS1Vii", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002912 {"_Z9atomic_orPU3AS3Vii", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002913 {"_Z9atomic_orPU3AS1Vjj", llvm::AtomicRMWInst::Or},
Kévin Petita303dc62019-03-26 21:40:35 +00002914 {"_Z9atomic_orPU3AS3Vjj", llvm::AtomicRMWInst::Or},
Neil Henning39672102017-09-29 14:33:13 +01002915 {"_Z10atomic_xorPU3AS1Vii", llvm::AtomicRMWInst::Xor},
Kévin Petita303dc62019-03-26 21:40:35 +00002916 {"_Z10atomic_xorPU3AS3Vii", llvm::AtomicRMWInst::Xor},
2917 {"_Z10atomic_xorPU3AS1Vjj", llvm::AtomicRMWInst::Xor},
2918 {"_Z10atomic_xorPU3AS3Vjj", llvm::AtomicRMWInst::Xor}};
Neil Henning39672102017-09-29 14:33:13 +01002919
2920 for (auto Pair : Map2) {
2921 // If we find a function with the matching name.
2922 if (auto F = M.getFunction(Pair.first)) {
2923 SmallVector<Instruction *, 4> ToRemoves;
2924
2925 // Walk the users of the function.
2926 for (auto &U : F->uses()) {
2927 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
2928 auto AtomicOp = new AtomicRMWInst(
2929 Pair.second, CI->getArgOperand(0), CI->getArgOperand(1),
2930 AtomicOrdering::SequentiallyConsistent, SyncScope::System, CI);
2931
2932 CI->replaceAllUsesWith(AtomicOp);
2933
2934 // Lastly, remember to remove the user.
2935 ToRemoves.push_back(CI);
2936 }
2937 }
2938
2939 Changed = !ToRemoves.empty();
2940
2941 // And cleanup the calls we don't use anymore.
2942 for (auto V : ToRemoves) {
2943 V->eraseFromParent();
2944 }
2945
2946 // And remove the function we don't need either too.
2947 F->eraseFromParent();
2948 }
2949 }
2950
David Neto22f144c2017-06-12 14:26:21 -04002951 return Changed;
2952}
2953
2954bool ReplaceOpenCLBuiltinPass::replaceCross(Module &M) {
David Neto22f144c2017-06-12 14:26:21 -04002955
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002956 std::vector<const char *> Names = {
2957 "_Z5crossDv4_fS_",
Kévin Petite8edce32019-04-10 14:23:32 +01002958 };
2959
2960 return replaceCallsWithValue(M, Names, [&M](CallInst *CI) {
David Neto22f144c2017-06-12 14:26:21 -04002961 auto IntTy = Type::getInt32Ty(M.getContext());
2962 auto FloatTy = Type::getFloatTy(M.getContext());
2963
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002964 Constant *DownShuffleMask[3] = {ConstantInt::get(IntTy, 0),
2965 ConstantInt::get(IntTy, 1),
2966 ConstantInt::get(IntTy, 2)};
David Neto22f144c2017-06-12 14:26:21 -04002967
2968 Constant *UpShuffleMask[4] = {
2969 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2970 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2971
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002972 Constant *FloatVec[3] = {ConstantFP::get(FloatTy, 0.0f),
2973 UndefValue::get(FloatTy),
2974 UndefValue::get(FloatTy)};
David Neto22f144c2017-06-12 14:26:21 -04002975
Kévin Petite8edce32019-04-10 14:23:32 +01002976 auto Vec4Ty = CI->getArgOperand(0)->getType();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002977 auto Arg0 =
2978 new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty),
2979 ConstantVector::get(DownShuffleMask), "", CI);
2980 auto Arg1 =
2981 new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty),
2982 ConstantVector::get(DownShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002983 auto Vec3Ty = Arg0->getType();
David Neto22f144c2017-06-12 14:26:21 -04002984
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002985 auto NewFType = FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
David Neto22f144c2017-06-12 14:26:21 -04002986
Kévin Petite8edce32019-04-10 14:23:32 +01002987 auto Cross3Func = M.getOrInsertFunction("_Z5crossDv3_fS_", NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002988
Kévin Petite8edce32019-04-10 14:23:32 +01002989 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002990
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002991 return new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec),
2992 ConstantVector::get(UpShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002993 });
David Neto22f144c2017-06-12 14:26:21 -04002994}
David Neto62653202017-10-16 19:05:18 -04002995
2996bool ReplaceOpenCLBuiltinPass::replaceFract(Module &M) {
2997 bool Changed = false;
2998
2999 // OpenCL's float result = fract(float x, float* ptr)
3000 //
3001 // In the LLVM domain:
3002 //
3003 // %floor_result = call spir_func float @floor(float %x)
3004 // store float %floor_result, float * %ptr
3005 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
3006 // %result = call spir_func float
3007 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
3008 //
3009 // Becomes in the SPIR-V domain, where translations of floor, fmin,
3010 // and clspv.fract occur in the SPIR-V generator pass:
3011 //
3012 // %glsl_ext = OpExtInstImport "GLSL.std.450"
3013 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
3014 // ...
3015 // %floor_result = OpExtInst %float %glsl_ext Floor %x
3016 // OpStore %ptr %floor_result
3017 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
3018 // %fract_result = OpExtInst %float
3019 // %glsl_ext Fmin %fract_intermediate %just_under_1
3020
David Neto62653202017-10-16 19:05:18 -04003021 using std::string;
3022
3023 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
3024 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003025 using QuadType =
3026 std::tuple<const char *, const char *, const char *, const char *>;
David Neto62653202017-10-16 19:05:18 -04003027 auto make_quad = [](const char *a, const char *b, const char *c,
3028 const char *d) {
3029 return std::tuple<const char *, const char *, const char *, const char *>(
3030 a, b, c, d);
3031 };
3032 const std::vector<QuadType> Functions = {
3033 make_quad("_Z5fractfPf", "_Z5floorff", "_Z4fminff", "clspv.fract.f"),
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003034 make_quad("_Z5fractDv2_fPS_", "_Z5floorDv2_f", "_Z4fminDv2_ff",
3035 "clspv.fract.v2f"),
3036 make_quad("_Z5fractDv3_fPS_", "_Z5floorDv3_f", "_Z4fminDv3_ff",
3037 "clspv.fract.v3f"),
3038 make_quad("_Z5fractDv4_fPS_", "_Z5floorDv4_f", "_Z4fminDv4_ff",
3039 "clspv.fract.v4f"),
David Neto62653202017-10-16 19:05:18 -04003040 };
3041
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003042 for (auto &quad : Functions) {
David Neto62653202017-10-16 19:05:18 -04003043 const StringRef fract_name(std::get<0>(quad));
3044
3045 // If we find a function with the matching name.
3046 if (auto F = M.getFunction(fract_name)) {
3047 if (F->use_begin() == F->use_end())
3048 continue;
3049
3050 // We have some uses.
3051 Changed = true;
3052
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003053 auto &Context = M.getContext();
David Neto62653202017-10-16 19:05:18 -04003054
3055 const StringRef floor_name(std::get<1>(quad));
3056 const StringRef fmin_name(std::get<2>(quad));
3057 const StringRef clspv_fract_name(std::get<3>(quad));
3058
3059 // This is either float or a float vector. All the float-like
3060 // types are this type.
3061 auto result_ty = F->getReturnType();
3062
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003063 Function *fmin_fn = M.getFunction(fmin_name);
David Neto62653202017-10-16 19:05:18 -04003064 if (!fmin_fn) {
3065 // Make the fmin function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003066 FunctionType *fn_ty =
3067 FunctionType::get(result_ty, {result_ty, result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003068 fmin_fn =
3069 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003070 fmin_fn->addFnAttr(Attribute::ReadNone);
3071 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
3072 }
3073
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003074 Function *floor_fn = M.getFunction(floor_name);
David Neto62653202017-10-16 19:05:18 -04003075 if (!floor_fn) {
3076 // Make the floor function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003077 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003078 floor_fn = cast<Function>(
3079 M.getOrInsertFunction(floor_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003080 floor_fn->addFnAttr(Attribute::ReadNone);
3081 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
3082 }
3083
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003084 Function *clspv_fract_fn = M.getFunction(clspv_fract_name);
David Neto62653202017-10-16 19:05:18 -04003085 if (!clspv_fract_fn) {
3086 // Make the clspv_fract function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003087 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
alan-bakerbccf62c2019-03-29 10:32:41 -04003088 clspv_fract_fn = cast<Function>(
3089 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
David Neto62653202017-10-16 19:05:18 -04003090 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
3091 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
3092 }
3093
3094 // Number of significant significand bits, whether represented or not.
3095 unsigned num_significand_bits;
3096 switch (result_ty->getScalarType()->getTypeID()) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003097 case Type::HalfTyID:
3098 num_significand_bits = 11;
3099 break;
3100 case Type::FloatTyID:
3101 num_significand_bits = 24;
3102 break;
3103 case Type::DoubleTyID:
3104 num_significand_bits = 53;
3105 break;
3106 default:
3107 assert(false && "Unhandled float type when processing fract builtin");
3108 break;
David Neto62653202017-10-16 19:05:18 -04003109 }
3110 // Beware that the disassembler displays this value as
3111 // OpConstant %float 1
3112 // which is not quite right.
3113 const double kJustUnderOneScalar =
3114 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
3115
3116 Constant *just_under_one =
3117 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
3118 if (result_ty->isVectorTy()) {
3119 just_under_one = ConstantVector::getSplat(
3120 result_ty->getVectorNumElements(), just_under_one);
3121 }
3122
3123 IRBuilder<> Builder(Context);
3124
3125 SmallVector<Instruction *, 4> ToRemoves;
3126
3127 // Walk the users of the function.
3128 for (auto &U : F->uses()) {
3129 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
3130
3131 Builder.SetInsertPoint(CI);
3132 auto arg = CI->getArgOperand(0);
3133 auto ptr = CI->getArgOperand(1);
3134
3135 // Compute floor result and store it.
3136 auto floor = Builder.CreateCall(floor_fn, {arg});
3137 Builder.CreateStore(floor, ptr);
3138
3139 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04003140 auto fract_result =
3141 Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
David Neto62653202017-10-16 19:05:18 -04003142
3143 CI->replaceAllUsesWith(fract_result);
3144
3145 // Lastly, remember to remove the user.
3146 ToRemoves.push_back(CI);
3147 }
3148 }
3149
3150 // And cleanup the calls we don't use anymore.
3151 for (auto V : ToRemoves) {
3152 V->eraseFromParent();
3153 }
3154
3155 // And remove the function we don't need either too.
3156 F->eraseFromParent();
3157 }
3158 }
3159
3160 return Changed;
3161}