blob: 60754dcbfdf8ff634e054a995e097e3095aa6737 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
David Neto118188e2018-08-24 11:27:54 -040021#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040022#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040023#include "llvm/IR/Module.h"
alan-baker4986eff2020-10-29 13:38:00 -040024#include "llvm/IR/Operator.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000025#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040026#include "llvm/Pass.h"
27#include "llvm/Support/CommandLine.h"
28#include "llvm/Support/raw_ostream.h"
alan-baker4986eff2020-10-29 13:38:00 -040029#include "llvm/Transforms/Utils/BasicBlockUtils.h"
David Neto118188e2018-08-24 11:27:54 -040030#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040031
alan-bakere0902602020-03-23 08:43:40 -040032#include "spirv/unified1/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040033
alan-baker931d18a2019-12-12 08:21:32 -050034#include "clspv/AddressSpace.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035#include "clspv/Option.h"
David Neto482550a2018-03-24 05:21:07 -070036
SJW2c317da2020-03-23 07:39:13 -050037#include "Builtins.h"
alan-baker931d18a2019-12-12 08:21:32 -050038#include "Constants.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040039#include "Passes.h"
40#include "SPIRVOp.h"
alan-bakerf906d2b2019-12-10 11:26:23 -050041#include "Types.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040042
SJW2c317da2020-03-23 07:39:13 -050043using namespace clspv;
David Neto22f144c2017-06-12 14:26:21 -040044using namespace llvm;
45
46#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
47
48namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000049
// Returns the zero-based index of the most significant set bit of |v|, i.e.
// floor(log2(v)). Note that, despite its name, this is not a count of leading
// zeros. Both 0 and 1 yield 0.
uint32_t clz(uint32_t v) {
  uint32_t result = 0;

  // Binary search over the bit positions: at each step, check whether any bit
  // is set at or above |step|. If so, discard the low |step| bits and record
  // |step| in the result.
  for (uint32_t step = 16; step != 0; step >>= 1) {
    if ((v >> step) != 0) {
      v >>= step;
      result |= step;
    }
  }

  return result;
}
69
Kévin Petitfdfa92e2019-09-25 14:20:58 +010070Type *getIntOrIntVectorTyForCast(LLVMContext &C, Type *Ty) {
71 Type *IntTy = Type::getIntNTy(C, Ty->getScalarSizeInBits());
James Pricecf53df42020-04-20 14:41:24 -040072 if (auto vec_ty = dyn_cast<VectorType>(Ty)) {
alan-baker5a8c3be2020-09-09 13:44:26 -040073 IntTy = FixedVectorType::get(IntTy,
74 vec_ty->getElementCount().getKnownMinValue());
Kévin Petitfdfa92e2019-09-25 14:20:58 +010075 }
76 return IntTy;
77}
78
alan-baker4986eff2020-10-29 13:38:00 -040079Value *MemoryOrderSemantics(Value *order, bool is_global,
80 Instruction *InsertBefore,
81 spv::MemorySemanticsMask base_semantics) {
82 enum AtomicMemoryOrder : uint32_t {
83 kMemoryOrderRelaxed = 0,
84 kMemoryOrderAcquire = 2,
85 kMemoryOrderRelease = 3,
86 kMemoryOrderAcqRel = 4,
87 kMemoryOrderSeqCst = 5
88 };
89
90 IRBuilder<> builder(InsertBefore);
91
92 // Constants for OpenCL C 2.0 memory_order.
93 const auto relaxed = builder.getInt32(AtomicMemoryOrder::kMemoryOrderRelaxed);
94 const auto acquire = builder.getInt32(AtomicMemoryOrder::kMemoryOrderAcquire);
95 const auto release = builder.getInt32(AtomicMemoryOrder::kMemoryOrderRelease);
96 const auto acq_rel = builder.getInt32(AtomicMemoryOrder::kMemoryOrderAcqRel);
97
98 // Constants for SPIR-V ordering memory semantics.
99 const auto RelaxedSemantics = builder.getInt32(spv::MemorySemanticsMaskNone);
100 const auto AcquireSemantics =
101 builder.getInt32(spv::MemorySemanticsAcquireMask);
102 const auto ReleaseSemantics =
103 builder.getInt32(spv::MemorySemanticsReleaseMask);
104 const auto AcqRelSemantics =
105 builder.getInt32(spv::MemorySemanticsAcquireReleaseMask);
106
107 // Constants for SPIR-V storage class semantics.
108 const auto UniformSemantics =
109 builder.getInt32(spv::MemorySemanticsUniformMemoryMask);
110 const auto WorkgroupSemantics =
111 builder.getInt32(spv::MemorySemanticsWorkgroupMemoryMask);
112
113 // Instead of sequentially consistent, use acquire, release or acquire
114 // release semantics.
115 Value *base_order = nullptr;
116 switch (base_semantics) {
117 case spv::MemorySemanticsAcquireMask:
118 base_order = AcquireSemantics;
119 break;
120 case spv::MemorySemanticsReleaseMask:
121 base_order = ReleaseSemantics;
122 break;
123 default:
124 base_order = AcqRelSemantics;
125 break;
126 }
127
128 Value *storage = is_global ? UniformSemantics : WorkgroupSemantics;
129 if (order == nullptr)
130 return builder.CreateOr({storage, base_order});
131
132 auto is_relaxed = builder.CreateICmpEQ(order, relaxed);
133 auto is_acquire = builder.CreateICmpEQ(order, acquire);
134 auto is_release = builder.CreateICmpEQ(order, release);
135 auto is_acq_rel = builder.CreateICmpEQ(order, acq_rel);
136 auto semantics =
137 builder.CreateSelect(is_relaxed, RelaxedSemantics, base_order);
138 semantics = builder.CreateSelect(is_acquire, AcquireSemantics, semantics);
139 semantics = builder.CreateSelect(is_release, ReleaseSemantics, semantics);
140 semantics = builder.CreateSelect(is_acq_rel, AcqRelSemantics, semantics);
141 return builder.CreateOr({storage, semantics});
142}
143
144Value *MemoryScope(Value *scope, bool is_global, Instruction *InsertBefore) {
145 enum AtomicMemoryScope : uint32_t {
146 kMemoryScopeWorkItem = 0,
147 kMemoryScopeWorkGroup = 1,
148 kMemoryScopeDevice = 2,
149 kMemoryScopeAllSVMDevices = 3, // not supported
150 kMemoryScopeSubGroup = 4
151 };
152
153 IRBuilder<> builder(InsertBefore);
154
155 // Constants for OpenCL C 2.0 memory_scope.
156 const auto work_item =
157 builder.getInt32(AtomicMemoryScope::kMemoryScopeWorkItem);
158 const auto work_group =
159 builder.getInt32(AtomicMemoryScope::kMemoryScopeWorkGroup);
160 const auto sub_group =
161 builder.getInt32(AtomicMemoryScope::kMemoryScopeSubGroup);
162 const auto device = builder.getInt32(AtomicMemoryScope::kMemoryScopeDevice);
163
164 // Constants for SPIR-V memory scopes.
165 const auto InvocationScope = builder.getInt32(spv::ScopeInvocation);
166 const auto WorkgroupScope = builder.getInt32(spv::ScopeWorkgroup);
167 const auto DeviceScope = builder.getInt32(spv::ScopeDevice);
168 const auto SubgroupScope = builder.getInt32(spv::ScopeSubgroup);
169
170 auto base_scope = is_global ? DeviceScope : WorkgroupScope;
171 if (scope == nullptr)
172 return base_scope;
173
174 auto is_work_item = builder.CreateICmpEQ(scope, work_item);
175 auto is_work_group = builder.CreateICmpEQ(scope, work_group);
176 auto is_sub_group = builder.CreateICmpEQ(scope, sub_group);
177 auto is_device = builder.CreateICmpEQ(scope, device);
178
179 scope = builder.CreateSelect(is_work_item, InvocationScope, base_scope);
180 scope = builder.CreateSelect(is_work_group, WorkgroupScope, scope);
181 scope = builder.CreateSelect(is_sub_group, SubgroupScope, scope);
182 scope = builder.CreateSelect(is_device, DeviceScope, scope);
183
184 return scope;
185}
186
SJW2c317da2020-03-23 07:39:13 -0500187bool replaceCallsWithValue(Function &F,
188 std::function<Value *(CallInst *)> Replacer) {
189
190 bool Changed = false;
191
192 SmallVector<Instruction *, 4> ToRemoves;
193
194 // Walk the users of the function.
195 for (auto &U : F.uses()) {
196 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
197
198 auto NewValue = Replacer(CI);
199
200 if (NewValue != nullptr) {
201 CI->replaceAllUsesWith(NewValue);
202
203 // Lastly, remember to remove the user.
204 ToRemoves.push_back(CI);
205 }
206 }
207 }
208
209 Changed = !ToRemoves.empty();
210
211 // And cleanup the calls we don't use anymore.
212 for (auto V : ToRemoves) {
213 V->eraseFromParent();
214 }
215
216 return Changed;
217}
218
David Neto22f144c2017-06-12 14:26:21 -0400219struct ReplaceOpenCLBuiltinPass final : public ModulePass {
220 static char ID;
221 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
222
223 bool runOnModule(Module &M) override;
alan-baker6b9d1ee2020-11-03 23:11:32 -0500224
225private:
SJW2c317da2020-03-23 07:39:13 -0500226 bool runOnFunction(Function &F);
227 bool replaceAbs(Function &F);
228 bool replaceAbsDiff(Function &F, bool is_signed);
229 bool replaceCopysign(Function &F);
230 bool replaceRecip(Function &F);
231 bool replaceDivide(Function &F);
232 bool replaceDot(Function &F);
233 bool replaceFmod(Function &F);
SJW61531372020-06-09 07:31:08 -0500234 bool replaceExp10(Function &F, const std::string &basename);
235 bool replaceLog10(Function &F, const std::string &basename);
gnl21636e7992020-09-09 16:08:16 +0100236 bool replaceLog1p(Function &F);
alan-baker12d2c182020-07-20 08:22:42 -0400237 bool replaceBarrier(Function &F, bool subgroup = false);
SJW2c317da2020-03-23 07:39:13 -0500238 bool replaceMemFence(Function &F, uint32_t semantics);
Kévin Petit1cb45112020-04-27 18:55:48 +0100239 bool replacePrefetch(Function &F);
alan-baker3e217772020-11-07 17:29:40 -0500240 bool replaceRelational(Function &F, CmpInst::Predicate P);
SJW2c317da2020-03-23 07:39:13 -0500241 bool replaceIsInfAndIsNan(Function &F, spv::Op SPIRVOp, int32_t isvec);
242 bool replaceIsFinite(Function &F);
243 bool replaceAllAndAny(Function &F, spv::Op SPIRVOp);
244 bool replaceUpsample(Function &F);
245 bool replaceRotate(Function &F);
246 bool replaceConvert(Function &F, bool SrcIsSigned, bool DstIsSigned);
247 bool replaceMulHi(Function &F, bool is_signed, bool is_mad = false);
248 bool replaceSelect(Function &F);
249 bool replaceBitSelect(Function &F);
SJW61531372020-06-09 07:31:08 -0500250 bool replaceStep(Function &F, bool is_smooth);
SJW2c317da2020-03-23 07:39:13 -0500251 bool replaceSignbit(Function &F, bool is_vec);
252 bool replaceMul(Function &F, bool is_float, bool is_mad);
253 bool replaceVloadHalf(Function &F, const std::string &name, int vec_size);
254 bool replaceVloadHalf(Function &F);
255 bool replaceVloadHalf2(Function &F);
256 bool replaceVloadHalf4(Function &F);
257 bool replaceClspvVloadaHalf2(Function &F);
258 bool replaceClspvVloadaHalf4(Function &F);
259 bool replaceVstoreHalf(Function &F, int vec_size);
260 bool replaceVstoreHalf(Function &F);
261 bool replaceVstoreHalf2(Function &F);
262 bool replaceVstoreHalf4(Function &F);
263 bool replaceHalfReadImage(Function &F);
264 bool replaceHalfWriteImage(Function &F);
265 bool replaceSampledReadImageWithIntCoords(Function &F);
266 bool replaceAtomics(Function &F, spv::Op Op);
267 bool replaceAtomics(Function &F, llvm::AtomicRMWInst::BinOp Op);
alan-baker4986eff2020-10-29 13:38:00 -0400268 bool replaceAtomicLoad(Function &F);
269 bool replaceExplicitAtomics(Function &F, spv::Op Op,
270 spv::MemorySemanticsMask semantics =
271 spv::MemorySemanticsAcquireReleaseMask);
272 bool replaceAtomicCompareExchange(Function &);
SJW2c317da2020-03-23 07:39:13 -0500273 bool replaceCross(Function &F);
274 bool replaceFract(Function &F, int vec_size);
275 bool replaceVload(Function &F);
276 bool replaceVstore(Function &F);
alan-baker3f1bf492020-11-05 09:07:36 -0500277 bool replaceAddSubSat(Function &F, bool is_signed, bool is_add);
Kévin Petit8576f682020-11-02 14:51:32 +0000278 bool replaceHadd(Function &F, bool is_signed,
279 Instruction::BinaryOps join_opcode);
alan-baker2cecaa72020-11-05 14:05:20 -0500280 bool replaceCountZeroes(Function &F, bool leading);
alan-baker6b9d1ee2020-11-03 23:11:32 -0500281 bool replaceMadSat(Function &F, bool is_signed);
alan-baker15106572020-11-06 15:08:10 -0500282 bool replaceOrdered(Function &F, bool is_ordered);
alan-baker497920b2020-11-09 16:41:36 -0500283 bool replaceIsNormal(Function &F);
alan-bakere0406e72020-11-10 12:32:04 -0500284 bool replaceFDim(Function &F);
alan-baker6b9d1ee2020-11-03 23:11:32 -0500285
286 // Caches struct types for { |type|, |type| }. This prevents
287 // getOrInsertFunction from introducing a bitcasts between structs with
288 // identical contents.
289 Type *GetPairStruct(Type *type);
290
291 DenseMap<Type *, Type *> PairStructMap;
David Neto22f144c2017-06-12 14:26:21 -0400292};
SJW2c317da2020-03-23 07:39:13 -0500293
Kévin Petit91bc72e2019-04-08 15:17:46 +0100294} // namespace
David Neto22f144c2017-06-12 14:26:21 -0400295
296char ReplaceOpenCLBuiltinPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400297INITIALIZE_PASS(ReplaceOpenCLBuiltinPass, "ReplaceOpenCLBuiltin",
298 "Replace OpenCL Builtins Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -0400299
300namespace clspv {
301ModulePass *createReplaceOpenCLBuiltinPass() {
302 return new ReplaceOpenCLBuiltinPass();
303}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400304} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -0400305
306bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
SJW2c317da2020-03-23 07:39:13 -0500307 std::list<Function *> func_list;
308 for (auto &F : M.getFunctionList()) {
309 // process only function declarations
310 if (F.isDeclaration() && runOnFunction(F)) {
311 func_list.push_front(&F);
Kévin Petit2444e9b2018-11-09 14:14:37 +0000312 }
313 }
SJW2c317da2020-03-23 07:39:13 -0500314 if (func_list.size() != 0) {
315 // recursively convert functions, but first remove dead
316 for (auto *F : func_list) {
317 if (F->use_empty()) {
318 F->eraseFromParent();
319 }
320 }
321 runOnModule(M);
322 return true;
323 }
324 return false;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000325}
326
SJW2c317da2020-03-23 07:39:13 -0500327bool ReplaceOpenCLBuiltinPass::runOnFunction(Function &F) {
328 auto &FI = Builtins::Lookup(&F);
329 switch (FI.getType()) {
330 case Builtins::kAbs:
331 if (!FI.getParameter(0).is_signed) {
332 return replaceAbs(F);
333 }
334 break;
335 case Builtins::kAbsDiff:
336 return replaceAbsDiff(F, FI.getParameter(0).is_signed);
alan-bakera52b7312020-10-26 08:58:51 -0400337
338 case Builtins::kAddSat:
alan-baker3f1bf492020-11-05 09:07:36 -0500339 return replaceAddSubSat(F, FI.getParameter(0).is_signed, true);
alan-bakera52b7312020-10-26 08:58:51 -0400340
alan-bakercc2bafb2020-11-02 08:30:18 -0500341 case Builtins::kClz:
alan-baker2cecaa72020-11-05 14:05:20 -0500342 return replaceCountZeroes(F, true);
343
344 case Builtins::kCtz:
345 return replaceCountZeroes(F, false);
alan-bakercc2bafb2020-11-02 08:30:18 -0500346
alan-bakerb6da5132020-10-29 15:59:06 -0400347 case Builtins::kHadd:
Kévin Petit8576f682020-11-02 14:51:32 +0000348 return replaceHadd(F, FI.getParameter(0).is_signed, Instruction::And);
alan-bakerb6da5132020-10-29 15:59:06 -0400349 case Builtins::kRhadd:
Kévin Petit8576f682020-11-02 14:51:32 +0000350 return replaceHadd(F, FI.getParameter(0).is_signed, Instruction::Or);
alan-bakerb6da5132020-10-29 15:59:06 -0400351
SJW2c317da2020-03-23 07:39:13 -0500352 case Builtins::kCopysign:
353 return replaceCopysign(F);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100354
SJW2c317da2020-03-23 07:39:13 -0500355 case Builtins::kHalfRecip:
356 case Builtins::kNativeRecip:
357 return replaceRecip(F);
Kévin Petite8edce32019-04-10 14:23:32 +0100358
SJW2c317da2020-03-23 07:39:13 -0500359 case Builtins::kHalfDivide:
360 case Builtins::kNativeDivide:
361 return replaceDivide(F);
362
363 case Builtins::kDot:
364 return replaceDot(F);
365
366 case Builtins::kExp10:
367 case Builtins::kHalfExp10:
SJW61531372020-06-09 07:31:08 -0500368 case Builtins::kNativeExp10:
369 return replaceExp10(F, FI.getName());
SJW2c317da2020-03-23 07:39:13 -0500370
371 case Builtins::kLog10:
372 case Builtins::kHalfLog10:
SJW61531372020-06-09 07:31:08 -0500373 case Builtins::kNativeLog10:
374 return replaceLog10(F, FI.getName());
SJW2c317da2020-03-23 07:39:13 -0500375
gnl21636e7992020-09-09 16:08:16 +0100376 case Builtins::kLog1p:
377 return replaceLog1p(F);
378
alan-bakere0406e72020-11-10 12:32:04 -0500379 case Builtins::kFdim:
380 return replaceFDim(F);
381
SJW2c317da2020-03-23 07:39:13 -0500382 case Builtins::kFmod:
383 return replaceFmod(F);
384
385 case Builtins::kBarrier:
386 case Builtins::kWorkGroupBarrier:
387 return replaceBarrier(F);
388
alan-baker12d2c182020-07-20 08:22:42 -0400389 case Builtins::kSubGroupBarrier:
390 return replaceBarrier(F, true);
391
SJW2c317da2020-03-23 07:39:13 -0500392 case Builtins::kMemFence:
alan-baker12d2c182020-07-20 08:22:42 -0400393 return replaceMemFence(F, spv::MemorySemanticsAcquireReleaseMask);
SJW2c317da2020-03-23 07:39:13 -0500394 case Builtins::kReadMemFence:
395 return replaceMemFence(F, spv::MemorySemanticsAcquireMask);
396 case Builtins::kWriteMemFence:
397 return replaceMemFence(F, spv::MemorySemanticsReleaseMask);
398
399 // Relational
400 case Builtins::kIsequal:
alan-baker3e217772020-11-07 17:29:40 -0500401 return replaceRelational(F, CmpInst::FCMP_OEQ);
SJW2c317da2020-03-23 07:39:13 -0500402 case Builtins::kIsgreater:
alan-baker3e217772020-11-07 17:29:40 -0500403 return replaceRelational(F, CmpInst::FCMP_OGT);
SJW2c317da2020-03-23 07:39:13 -0500404 case Builtins::kIsgreaterequal:
alan-baker3e217772020-11-07 17:29:40 -0500405 return replaceRelational(F, CmpInst::FCMP_OGE);
SJW2c317da2020-03-23 07:39:13 -0500406 case Builtins::kIsless:
alan-baker3e217772020-11-07 17:29:40 -0500407 return replaceRelational(F, CmpInst::FCMP_OLT);
SJW2c317da2020-03-23 07:39:13 -0500408 case Builtins::kIslessequal:
alan-baker3e217772020-11-07 17:29:40 -0500409 return replaceRelational(F, CmpInst::FCMP_OLE);
SJW2c317da2020-03-23 07:39:13 -0500410 case Builtins::kIsnotequal:
alan-baker3e217772020-11-07 17:29:40 -0500411 return replaceRelational(F, CmpInst::FCMP_UNE);
412 case Builtins::kIslessgreater:
413 return replaceRelational(F, CmpInst::FCMP_ONE);
SJW2c317da2020-03-23 07:39:13 -0500414
alan-baker15106572020-11-06 15:08:10 -0500415 case Builtins::kIsordered:
416 return replaceOrdered(F, true);
417
418 case Builtins::kIsunordered:
419 return replaceOrdered(F, false);
420
SJW2c317da2020-03-23 07:39:13 -0500421 case Builtins::kIsinf: {
422 bool is_vec = FI.getParameter(0).vector_size != 0;
423 return replaceIsInfAndIsNan(F, spv::OpIsInf, is_vec ? -1 : 1);
424 }
425 case Builtins::kIsnan: {
426 bool is_vec = FI.getParameter(0).vector_size != 0;
427 return replaceIsInfAndIsNan(F, spv::OpIsNan, is_vec ? -1 : 1);
428 }
429
430 case Builtins::kIsfinite:
431 return replaceIsFinite(F);
432
433 case Builtins::kAll: {
434 bool is_vec = FI.getParameter(0).vector_size != 0;
435 return replaceAllAndAny(F, !is_vec ? spv::OpNop : spv::OpAll);
436 }
437 case Builtins::kAny: {
438 bool is_vec = FI.getParameter(0).vector_size != 0;
439 return replaceAllAndAny(F, !is_vec ? spv::OpNop : spv::OpAny);
440 }
441
alan-baker497920b2020-11-09 16:41:36 -0500442 case Builtins::kIsnormal:
443 return replaceIsNormal(F);
444
SJW2c317da2020-03-23 07:39:13 -0500445 case Builtins::kUpsample:
446 return replaceUpsample(F);
447
448 case Builtins::kRotate:
449 return replaceRotate(F);
450
451 case Builtins::kConvert:
452 return replaceConvert(F, FI.getParameter(0).is_signed,
453 FI.getReturnType().is_signed);
454
alan-baker4986eff2020-10-29 13:38:00 -0400455 // OpenCL 2.0 explicit atomics have different default scopes and semantics
456 // than legacy atomic functions.
457 case Builtins::kAtomicLoad:
458 case Builtins::kAtomicLoadExplicit:
459 return replaceAtomicLoad(F);
460 case Builtins::kAtomicStore:
461 case Builtins::kAtomicStoreExplicit:
462 return replaceExplicitAtomics(F, spv::OpAtomicStore,
463 spv::MemorySemanticsReleaseMask);
464 case Builtins::kAtomicExchange:
465 case Builtins::kAtomicExchangeExplicit:
466 return replaceExplicitAtomics(F, spv::OpAtomicExchange);
467 case Builtins::kAtomicFetchAdd:
468 case Builtins::kAtomicFetchAddExplicit:
469 return replaceExplicitAtomics(F, spv::OpAtomicIAdd);
470 case Builtins::kAtomicFetchSub:
471 case Builtins::kAtomicFetchSubExplicit:
472 return replaceExplicitAtomics(F, spv::OpAtomicISub);
473 case Builtins::kAtomicFetchOr:
474 case Builtins::kAtomicFetchOrExplicit:
475 return replaceExplicitAtomics(F, spv::OpAtomicOr);
476 case Builtins::kAtomicFetchXor:
477 case Builtins::kAtomicFetchXorExplicit:
478 return replaceExplicitAtomics(F, spv::OpAtomicXor);
479 case Builtins::kAtomicFetchAnd:
480 case Builtins::kAtomicFetchAndExplicit:
481 return replaceExplicitAtomics(F, spv::OpAtomicAnd);
482 case Builtins::kAtomicFetchMin:
483 case Builtins::kAtomicFetchMinExplicit:
484 return replaceExplicitAtomics(F, FI.getParameter(1).is_signed
485 ? spv::OpAtomicSMin
486 : spv::OpAtomicUMin);
487 case Builtins::kAtomicFetchMax:
488 case Builtins::kAtomicFetchMaxExplicit:
489 return replaceExplicitAtomics(F, FI.getParameter(1).is_signed
490 ? spv::OpAtomicSMax
491 : spv::OpAtomicUMax);
492 // Weak compare exchange is generated as strong compare exchange.
493 case Builtins::kAtomicCompareExchangeWeak:
494 case Builtins::kAtomicCompareExchangeWeakExplicit:
495 case Builtins::kAtomicCompareExchangeStrong:
496 case Builtins::kAtomicCompareExchangeStrongExplicit:
497 return replaceAtomicCompareExchange(F);
498
499 // Legacy atomic functions.
SJW2c317da2020-03-23 07:39:13 -0500500 case Builtins::kAtomicInc:
501 return replaceAtomics(F, spv::OpAtomicIIncrement);
502 case Builtins::kAtomicDec:
503 return replaceAtomics(F, spv::OpAtomicIDecrement);
504 case Builtins::kAtomicCmpxchg:
505 return replaceAtomics(F, spv::OpAtomicCompareExchange);
506 case Builtins::kAtomicAdd:
507 return replaceAtomics(F, llvm::AtomicRMWInst::Add);
508 case Builtins::kAtomicSub:
509 return replaceAtomics(F, llvm::AtomicRMWInst::Sub);
510 case Builtins::kAtomicXchg:
511 return replaceAtomics(F, llvm::AtomicRMWInst::Xchg);
512 case Builtins::kAtomicMin:
513 return replaceAtomics(F, FI.getParameter(0).is_signed
514 ? llvm::AtomicRMWInst::Min
515 : llvm::AtomicRMWInst::UMin);
516 case Builtins::kAtomicMax:
517 return replaceAtomics(F, FI.getParameter(0).is_signed
518 ? llvm::AtomicRMWInst::Max
519 : llvm::AtomicRMWInst::UMax);
520 case Builtins::kAtomicAnd:
521 return replaceAtomics(F, llvm::AtomicRMWInst::And);
522 case Builtins::kAtomicOr:
523 return replaceAtomics(F, llvm::AtomicRMWInst::Or);
524 case Builtins::kAtomicXor:
525 return replaceAtomics(F, llvm::AtomicRMWInst::Xor);
526
527 case Builtins::kCross:
528 if (FI.getParameter(0).vector_size == 4) {
529 return replaceCross(F);
530 }
531 break;
532
533 case Builtins::kFract:
534 if (FI.getParameterCount()) {
535 return replaceFract(F, FI.getParameter(0).vector_size);
536 }
537 break;
538
539 case Builtins::kMadHi:
540 return replaceMulHi(F, FI.getParameter(0).is_signed, true);
541 case Builtins::kMulHi:
542 return replaceMulHi(F, FI.getParameter(0).is_signed, false);
543
alan-baker6b9d1ee2020-11-03 23:11:32 -0500544 case Builtins::kMadSat:
545 return replaceMadSat(F, FI.getParameter(0).is_signed);
546
SJW2c317da2020-03-23 07:39:13 -0500547 case Builtins::kMad:
548 case Builtins::kMad24:
549 return replaceMul(F, FI.getParameter(0).type_id == llvm::Type::FloatTyID,
550 true);
551 case Builtins::kMul24:
552 return replaceMul(F, FI.getParameter(0).type_id == llvm::Type::FloatTyID,
553 false);
554
555 case Builtins::kSelect:
556 return replaceSelect(F);
557
558 case Builtins::kBitselect:
559 return replaceBitSelect(F);
560
561 case Builtins::kVload:
562 return replaceVload(F);
563
564 case Builtins::kVloadaHalf:
565 case Builtins::kVloadHalf:
566 return replaceVloadHalf(F, FI.getName(), FI.getParameter(0).vector_size);
567
568 case Builtins::kVstore:
569 return replaceVstore(F);
570
571 case Builtins::kVstoreHalf:
572 case Builtins::kVstoreaHalf:
573 return replaceVstoreHalf(F, FI.getParameter(0).vector_size);
574
575 case Builtins::kSmoothstep: {
576 int vec_size = FI.getLastParameter().vector_size;
577 if (FI.getParameter(0).vector_size == 0 && vec_size != 0) {
SJW61531372020-06-09 07:31:08 -0500578 return replaceStep(F, true);
SJW2c317da2020-03-23 07:39:13 -0500579 }
580 break;
581 }
582 case Builtins::kStep: {
583 int vec_size = FI.getLastParameter().vector_size;
584 if (FI.getParameter(0).vector_size == 0 && vec_size != 0) {
SJW61531372020-06-09 07:31:08 -0500585 return replaceStep(F, false);
SJW2c317da2020-03-23 07:39:13 -0500586 }
587 break;
588 }
589
590 case Builtins::kSignbit:
591 return replaceSignbit(F, FI.getParameter(0).vector_size != 0);
592
alan-baker3f1bf492020-11-05 09:07:36 -0500593 case Builtins::kSubSat:
594 return replaceAddSubSat(F, FI.getParameter(0).is_signed, false);
595
SJW2c317da2020-03-23 07:39:13 -0500596 case Builtins::kReadImageh:
597 return replaceHalfReadImage(F);
598 case Builtins::kReadImagef:
599 case Builtins::kReadImagei:
600 case Builtins::kReadImageui: {
601 if (FI.getParameter(1).isSampler() &&
602 FI.getParameter(2).type_id == llvm::Type::IntegerTyID) {
603 return replaceSampledReadImageWithIntCoords(F);
604 }
605 break;
606 }
607
608 case Builtins::kWriteImageh:
609 return replaceHalfWriteImage(F);
610
Kévin Petit1cb45112020-04-27 18:55:48 +0100611 case Builtins::kPrefetch:
612 return replacePrefetch(F);
613
SJW2c317da2020-03-23 07:39:13 -0500614 default:
615 break;
616 }
617
618 return false;
619}
620
alan-baker6b9d1ee2020-11-03 23:11:32 -0500621Type *ReplaceOpenCLBuiltinPass::GetPairStruct(Type *type) {
622 auto iter = PairStructMap.find(type);
623 if (iter != PairStructMap.end())
624 return iter->second;
625
626 auto new_struct = StructType::get(type->getContext(), {type, type});
627 PairStructMap[type] = new_struct;
628 return new_struct;
629}
630
SJW2c317da2020-03-23 07:39:13 -0500631bool ReplaceOpenCLBuiltinPass::replaceAbs(Function &F) {
632 return replaceCallsWithValue(F,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400633 [](CallInst *CI) { return CI->getOperand(0); });
Kévin Petite8edce32019-04-10 14:23:32 +0100634}
635
SJW2c317da2020-03-23 07:39:13 -0500636bool ReplaceOpenCLBuiltinPass::replaceAbsDiff(Function &F, bool is_signed) {
637 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100638 auto XValue = CI->getOperand(0);
639 auto YValue = CI->getOperand(1);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100640
Kévin Petite8edce32019-04-10 14:23:32 +0100641 IRBuilder<> Builder(CI);
642 auto XmY = Builder.CreateSub(XValue, YValue);
643 auto YmX = Builder.CreateSub(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100644
SJW2c317da2020-03-23 07:39:13 -0500645 Value *Cmp = nullptr;
646 if (is_signed) {
Kévin Petite8edce32019-04-10 14:23:32 +0100647 Cmp = Builder.CreateICmpSGT(YValue, XValue);
648 } else {
649 Cmp = Builder.CreateICmpUGT(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100650 }
Kévin Petit91bc72e2019-04-08 15:17:46 +0100651
Kévin Petite8edce32019-04-10 14:23:32 +0100652 return Builder.CreateSelect(Cmp, YmX, XmY);
653 });
Kévin Petit91bc72e2019-04-08 15:17:46 +0100654}
655
SJW2c317da2020-03-23 07:39:13 -0500656bool ReplaceOpenCLBuiltinPass::replaceCopysign(Function &F) {
657 return replaceCallsWithValue(F, [&F](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100658 auto XValue = CI->getOperand(0);
659 auto YValue = CI->getOperand(1);
Kévin Petit8c1be282019-04-02 19:34:25 +0100660
Kévin Petite8edce32019-04-10 14:23:32 +0100661 auto Ty = XValue->getType();
Kévin Petit8c1be282019-04-02 19:34:25 +0100662
SJW2c317da2020-03-23 07:39:13 -0500663 Type *IntTy = Type::getIntNTy(F.getContext(), Ty->getScalarSizeInBits());
James Pricecf53df42020-04-20 14:41:24 -0400664 if (auto vec_ty = dyn_cast<VectorType>(Ty)) {
alan-baker5a8c3be2020-09-09 13:44:26 -0400665 IntTy = FixedVectorType::get(
666 IntTy, vec_ty->getElementCount().getKnownMinValue());
Kévin Petit8c1be282019-04-02 19:34:25 +0100667 }
Kévin Petit8c1be282019-04-02 19:34:25 +0100668
Kévin Petite8edce32019-04-10 14:23:32 +0100669 // Return X with the sign of Y
670
671 // Sign bit masks
672 auto SignBit = IntTy->getScalarSizeInBits() - 1;
673 auto SignBitMask = 1 << SignBit;
674 auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
675 auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);
676
677 IRBuilder<> Builder(CI);
678
679 // Extract sign of Y
680 auto YInt = Builder.CreateBitCast(YValue, IntTy);
681 auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);
682
683 // Clear sign bit in X
684 auto XInt = Builder.CreateBitCast(XValue, IntTy);
685 XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);
686
687 // Insert sign bit of Y into X
688 auto NewXInt = Builder.CreateOr(XInt, YSign);
689
690 // And cast back to floating-point
691 return Builder.CreateBitCast(NewXInt, Ty);
692 });
Kévin Petit8c1be282019-04-02 19:34:25 +0100693}
694
SJW2c317da2020-03-23 07:39:13 -0500695bool ReplaceOpenCLBuiltinPass::replaceRecip(Function &F) {
696 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100697 // Recip has one arg.
698 auto Arg = CI->getOperand(0);
699 auto Cst1 = ConstantFP::get(Arg->getType(), 1.0);
700 return BinaryOperator::Create(Instruction::FDiv, Cst1, Arg, "", CI);
701 });
David Neto22f144c2017-06-12 14:26:21 -0400702}
703
SJW2c317da2020-03-23 07:39:13 -0500704bool ReplaceOpenCLBuiltinPass::replaceDivide(Function &F) {
705 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100706 auto Op0 = CI->getOperand(0);
707 auto Op1 = CI->getOperand(1);
708 return BinaryOperator::Create(Instruction::FDiv, Op0, Op1, "", CI);
709 });
David Neto22f144c2017-06-12 14:26:21 -0400710}
711
SJW2c317da2020-03-23 07:39:13 -0500712bool ReplaceOpenCLBuiltinPass::replaceDot(Function &F) {
713 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petit1329a002019-06-15 05:54:05 +0100714 auto Op0 = CI->getOperand(0);
715 auto Op1 = CI->getOperand(1);
716
SJW2c317da2020-03-23 07:39:13 -0500717 Value *V = nullptr;
Kévin Petit1329a002019-06-15 05:54:05 +0100718 if (Op0->getType()->isVectorTy()) {
719 V = clspv::InsertSPIRVOp(CI, spv::OpDot, {Attribute::ReadNone},
720 CI->getType(), {Op0, Op1});
721 } else {
722 V = BinaryOperator::Create(Instruction::FMul, Op0, Op1, "", CI);
723 }
724
725 return V;
726 });
727}
728
SJW2c317da2020-03-23 07:39:13 -0500729bool ReplaceOpenCLBuiltinPass::replaceExp10(Function &F,
SJW61531372020-06-09 07:31:08 -0500730 const std::string &basename) {
SJW2c317da2020-03-23 07:39:13 -0500731 // convert to natural
732 auto slen = basename.length() - 2;
SJW61531372020-06-09 07:31:08 -0500733 std::string NewFName = basename.substr(0, slen);
734 NewFName =
735 Builtins::GetMangledFunctionName(NewFName.c_str(), F.getFunctionType());
David Neto22f144c2017-06-12 14:26:21 -0400736
SJW2c317da2020-03-23 07:39:13 -0500737 Module &M = *F.getParent();
738 return replaceCallsWithValue(F, [&](CallInst *CI) {
739 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
740
741 auto Arg = CI->getOperand(0);
742
743 // Constant of the natural log of 10 (ln(10)).
744 const double Ln10 =
745 2.302585092994045684017991454684364207601101488628772976033;
746
747 auto Mul = BinaryOperator::Create(
748 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "", CI);
749
750 return CallInst::Create(NewF, Mul, "", CI);
751 });
David Neto22f144c2017-06-12 14:26:21 -0400752}
753
SJW2c317da2020-03-23 07:39:13 -0500754bool ReplaceOpenCLBuiltinPass::replaceFmod(Function &F) {
Kévin Petit0644a9c2019-06-20 21:08:46 +0100755 // OpenCL fmod(x,y) is x - y * trunc(x/y)
756 // The sign for a non-zero result is taken from x.
757 // (Try an example.)
758 // So translate to FRem
SJW2c317da2020-03-23 07:39:13 -0500759 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petit0644a9c2019-06-20 21:08:46 +0100760 auto Op0 = CI->getOperand(0);
761 auto Op1 = CI->getOperand(1);
762 return BinaryOperator::Create(Instruction::FRem, Op0, Op1, "", CI);
763 });
764}
765
SJW2c317da2020-03-23 07:39:13 -0500766bool ReplaceOpenCLBuiltinPass::replaceLog10(Function &F,
SJW61531372020-06-09 07:31:08 -0500767 const std::string &basename) {
SJW2c317da2020-03-23 07:39:13 -0500768 // convert to natural
769 auto slen = basename.length() - 2;
SJW61531372020-06-09 07:31:08 -0500770 std::string NewFName = basename.substr(0, slen);
771 NewFName =
772 Builtins::GetMangledFunctionName(NewFName.c_str(), F.getFunctionType());
David Neto22f144c2017-06-12 14:26:21 -0400773
SJW2c317da2020-03-23 07:39:13 -0500774 Module &M = *F.getParent();
775 return replaceCallsWithValue(F, [&](CallInst *CI) {
776 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
777
778 auto Arg = CI->getOperand(0);
779
780 // Constant of the reciprocal of the natural log of 10 (ln(10)).
781 const double Ln10 =
782 0.434294481903251827651128918916605082294397005803666566114;
783
784 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
785
786 return BinaryOperator::Create(Instruction::FMul,
787 ConstantFP::get(Arg->getType(), Ln10), NewCI,
788 "", CI);
789 });
David Neto22f144c2017-06-12 14:26:21 -0400790}
791
gnl21636e7992020-09-09 16:08:16 +0100792bool ReplaceOpenCLBuiltinPass::replaceLog1p(Function &F) {
793 // convert to natural
794 std::string NewFName =
795 Builtins::GetMangledFunctionName("log", F.getFunctionType());
796
797 Module &M = *F.getParent();
798 return replaceCallsWithValue(F, [&](CallInst *CI) {
799 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
800
801 auto Arg = CI->getOperand(0);
802
803 auto ArgP1 = BinaryOperator::Create(
804 Instruction::FAdd, ConstantFP::get(Arg->getType(), 1.0), Arg, "", CI);
805
806 return CallInst::Create(NewF, ArgP1, "", CI);
807 });
808}
809
alan-baker12d2c182020-07-20 08:22:42 -0400810bool ReplaceOpenCLBuiltinPass::replaceBarrier(Function &F, bool subgroup) {
David Neto22f144c2017-06-12 14:26:21 -0400811
alan-bakerf6bc8252020-09-23 14:58:55 -0400812 enum {
813 CLK_LOCAL_MEM_FENCE = 0x01,
814 CLK_GLOBAL_MEM_FENCE = 0x02,
815 CLK_IMAGE_MEM_FENCE = 0x04
816 };
David Neto22f144c2017-06-12 14:26:21 -0400817
alan-baker12d2c182020-07-20 08:22:42 -0400818 return replaceCallsWithValue(F, [subgroup](CallInst *CI) {
Kévin Petitc4643922019-06-17 19:32:05 +0100819 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400820
Kévin Petitc4643922019-06-17 19:32:05 +0100821 // We need to map the OpenCL constants to the SPIR-V equivalents.
822 const auto LocalMemFence =
823 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
824 const auto GlobalMemFence =
825 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
alan-bakerf6bc8252020-09-23 14:58:55 -0400826 const auto ImageMemFence =
827 ConstantInt::get(Arg->getType(), CLK_IMAGE_MEM_FENCE);
alan-baker12d2c182020-07-20 08:22:42 -0400828 const auto ConstantAcquireRelease = ConstantInt::get(
829 Arg->getType(), spv::MemorySemanticsAcquireReleaseMask);
Kévin Petitc4643922019-06-17 19:32:05 +0100830 const auto ConstantScopeDevice =
831 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
832 const auto ConstantScopeWorkgroup =
833 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
alan-baker12d2c182020-07-20 08:22:42 -0400834 const auto ConstantScopeSubgroup =
835 ConstantInt::get(Arg->getType(), spv::ScopeSubgroup);
David Neto22f144c2017-06-12 14:26:21 -0400836
Kévin Petitc4643922019-06-17 19:32:05 +0100837 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
838 const auto LocalMemFenceMask =
839 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
840 const auto WorkgroupShiftAmount =
841 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
842 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
843 Instruction::Shl, LocalMemFenceMask,
844 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400845
Kévin Petitc4643922019-06-17 19:32:05 +0100846 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
847 const auto GlobalMemFenceMask =
848 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
849 const auto UniformShiftAmount =
850 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
851 const auto MemorySemanticsUniform = BinaryOperator::Create(
852 Instruction::Shl, GlobalMemFenceMask,
853 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400854
alan-bakerf6bc8252020-09-23 14:58:55 -0400855 // OpenCL 2.0
856 // Map CLK_IMAGE_MEM_FENCE to MemorySemanticsImageMemoryMask.
857 const auto ImageMemFenceMask =
858 BinaryOperator::Create(Instruction::And, ImageMemFence, Arg, "", CI);
859 const auto ImageShiftAmount =
860 clz(spv::MemorySemanticsImageMemoryMask) - clz(CLK_IMAGE_MEM_FENCE);
861 const auto MemorySemanticsImage = BinaryOperator::Create(
862 Instruction::Shl, ImageMemFenceMask,
863 ConstantInt::get(Arg->getType(), ImageShiftAmount), "", CI);
864
Kévin Petitc4643922019-06-17 19:32:05 +0100865 // And combine the above together, also adding in
alan-bakerf6bc8252020-09-23 14:58:55 -0400866 // MemorySemanticsSequentiallyConsistentMask.
867 auto MemorySemantics1 =
Kévin Petitc4643922019-06-17 19:32:05 +0100868 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
alan-baker12d2c182020-07-20 08:22:42 -0400869 ConstantAcquireRelease, "", CI);
alan-bakerf6bc8252020-09-23 14:58:55 -0400870 auto MemorySemantics2 = BinaryOperator::Create(
871 Instruction::Or, MemorySemanticsUniform, MemorySemanticsImage, "", CI);
872 auto MemorySemantics = BinaryOperator::Create(
873 Instruction::Or, MemorySemantics1, MemorySemantics2, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400874
alan-baker12d2c182020-07-20 08:22:42 -0400875 // If the memory scope is not specified explicitly, it is either Subgroup
876 // or Workgroup depending on the type of barrier.
877 Value *MemoryScope =
878 subgroup ? ConstantScopeSubgroup : ConstantScopeWorkgroup;
879 if (CI->data_operands_size() > 1) {
880 enum {
881 CL_MEMORY_SCOPE_WORKGROUP = 0x1,
882 CL_MEMORY_SCOPE_DEVICE = 0x2,
883 CL_MEMORY_SCOPE_SUBGROUP = 0x4
884 };
885 // The call was given an explicit memory scope.
886 const auto MemoryScopeSubgroup =
887 ConstantInt::get(Arg->getType(), CL_MEMORY_SCOPE_SUBGROUP);
888 const auto MemoryScopeDevice =
889 ConstantInt::get(Arg->getType(), CL_MEMORY_SCOPE_DEVICE);
David Neto22f144c2017-06-12 14:26:21 -0400890
alan-baker12d2c182020-07-20 08:22:42 -0400891 auto Cmp =
892 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
893 MemoryScopeSubgroup, CI->getOperand(1), "", CI);
894 MemoryScope = SelectInst::Create(Cmp, ConstantScopeSubgroup,
895 ConstantScopeWorkgroup, "", CI);
896 Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
897 MemoryScopeDevice, CI->getOperand(1), "", CI);
898 MemoryScope =
899 SelectInst::Create(Cmp, ConstantScopeDevice, MemoryScope, "", CI);
900 }
901
902 // Lastly, the Execution Scope is either Workgroup or Subgroup depending on
903 // the type of barrier;
904 const auto ExecutionScope =
905 subgroup ? ConstantScopeSubgroup : ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400906
Kévin Petitc4643922019-06-17 19:32:05 +0100907 return clspv::InsertSPIRVOp(CI, spv::OpControlBarrier,
alan-baker3d905692020-10-28 14:02:37 -0400908 {Attribute::NoDuplicate, Attribute::Convergent},
909 CI->getType(),
Kévin Petitc4643922019-06-17 19:32:05 +0100910 {ExecutionScope, MemoryScope, MemorySemantics});
911 });
David Neto22f144c2017-06-12 14:26:21 -0400912}
913
SJW2c317da2020-03-23 07:39:13 -0500914bool ReplaceOpenCLBuiltinPass::replaceMemFence(Function &F,
915 uint32_t semantics) {
David Neto22f144c2017-06-12 14:26:21 -0400916
SJW2c317da2020-03-23 07:39:13 -0500917 return replaceCallsWithValue(F, [&](CallInst *CI) {
alan-bakerf6bc8252020-09-23 14:58:55 -0400918 enum {
919 CLK_LOCAL_MEM_FENCE = 0x01,
920 CLK_GLOBAL_MEM_FENCE = 0x02,
921 CLK_IMAGE_MEM_FENCE = 0x04,
922 };
David Neto22f144c2017-06-12 14:26:21 -0400923
SJW2c317da2020-03-23 07:39:13 -0500924 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400925
SJW2c317da2020-03-23 07:39:13 -0500926 // We need to map the OpenCL constants to the SPIR-V equivalents.
927 const auto LocalMemFence =
928 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
929 const auto GlobalMemFence =
930 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
alan-bakerf6bc8252020-09-23 14:58:55 -0400931 const auto ImageMemFence =
932 ConstantInt::get(Arg->getType(), CLK_IMAGE_MEM_FENCE);
SJW2c317da2020-03-23 07:39:13 -0500933 const auto ConstantMemorySemantics =
934 ConstantInt::get(Arg->getType(), semantics);
alan-baker12d2c182020-07-20 08:22:42 -0400935 const auto ConstantScopeWorkgroup =
936 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
David Neto22f144c2017-06-12 14:26:21 -0400937
SJW2c317da2020-03-23 07:39:13 -0500938 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
939 const auto LocalMemFenceMask =
940 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
941 const auto WorkgroupShiftAmount =
942 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
943 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
944 Instruction::Shl, LocalMemFenceMask,
945 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400946
SJW2c317da2020-03-23 07:39:13 -0500947 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
948 const auto GlobalMemFenceMask =
949 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
950 const auto UniformShiftAmount =
951 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
952 const auto MemorySemanticsUniform = BinaryOperator::Create(
953 Instruction::Shl, GlobalMemFenceMask,
954 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400955
alan-bakerf6bc8252020-09-23 14:58:55 -0400956 // OpenCL 2.0
957 // Map CLK_IMAGE_MEM_FENCE to MemorySemanticsImageMemoryMask.
958 const auto ImageMemFenceMask =
959 BinaryOperator::Create(Instruction::And, ImageMemFence, Arg, "", CI);
960 const auto ImageShiftAmount =
961 clz(spv::MemorySemanticsImageMemoryMask) - clz(CLK_IMAGE_MEM_FENCE);
962 const auto MemorySemanticsImage = BinaryOperator::Create(
963 Instruction::Shl, ImageMemFenceMask,
964 ConstantInt::get(Arg->getType(), ImageShiftAmount), "", CI);
965
SJW2c317da2020-03-23 07:39:13 -0500966 // And combine the above together, also adding in
alan-bakerf6bc8252020-09-23 14:58:55 -0400967 // |semantics|.
968 auto MemorySemantics1 =
SJW2c317da2020-03-23 07:39:13 -0500969 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
970 ConstantMemorySemantics, "", CI);
alan-bakerf6bc8252020-09-23 14:58:55 -0400971 auto MemorySemantics2 = BinaryOperator::Create(
972 Instruction::Or, MemorySemanticsUniform, MemorySemanticsImage, "", CI);
973 auto MemorySemantics = BinaryOperator::Create(
974 Instruction::Or, MemorySemantics1, MemorySemantics2, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400975
alan-baker12d2c182020-07-20 08:22:42 -0400976 // Memory Scope is always workgroup.
977 const auto MemoryScope = ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400978
alan-baker3d905692020-10-28 14:02:37 -0400979 return clspv::InsertSPIRVOp(CI, spv::OpMemoryBarrier,
980 {Attribute::Convergent}, CI->getType(),
SJW2c317da2020-03-23 07:39:13 -0500981 {MemoryScope, MemorySemantics});
982 });
David Neto22f144c2017-06-12 14:26:21 -0400983}
984
Kévin Petit1cb45112020-04-27 18:55:48 +0100985bool ReplaceOpenCLBuiltinPass::replacePrefetch(Function &F) {
986 bool Changed = false;
987
988 SmallVector<Instruction *, 4> ToRemoves;
989
990 // Find all calls to the function
991 for (auto &U : F.uses()) {
992 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
993 ToRemoves.push_back(CI);
994 }
995 }
996
997 Changed = !ToRemoves.empty();
998
999 // Delete them
1000 for (auto V : ToRemoves) {
1001 V->eraseFromParent();
1002 }
1003
1004 return Changed;
1005}
1006
SJW2c317da2020-03-23 07:39:13 -05001007bool ReplaceOpenCLBuiltinPass::replaceRelational(Function &F,
alan-baker3e217772020-11-07 17:29:40 -05001008 CmpInst::Predicate P) {
SJW2c317da2020-03-23 07:39:13 -05001009 return replaceCallsWithValue(F, [&](CallInst *CI) {
1010 // The predicate to use in the CmpInst.
1011 auto Predicate = P;
David Neto22f144c2017-06-12 14:26:21 -04001012
SJW2c317da2020-03-23 07:39:13 -05001013 auto Arg1 = CI->getOperand(0);
1014 auto Arg2 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001015
SJW2c317da2020-03-23 07:39:13 -05001016 const auto Cmp =
1017 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
alan-baker3e217772020-11-07 17:29:40 -05001018 if (isa<VectorType>(F.getReturnType()))
1019 return CastInst::Create(Instruction::SExt, Cmp, CI->getType(), "", CI);
1020 return CastInst::Create(Instruction::ZExt, Cmp, CI->getType(), "", CI);
SJW2c317da2020-03-23 07:39:13 -05001021 });
David Neto22f144c2017-06-12 14:26:21 -04001022}
1023
SJW2c317da2020-03-23 07:39:13 -05001024bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Function &F,
1025 spv::Op SPIRVOp,
1026 int32_t C) {
1027 Module &M = *F.getParent();
1028 return replaceCallsWithValue(F, [&](CallInst *CI) {
1029 const auto CITy = CI->getType();
David Neto22f144c2017-06-12 14:26:21 -04001030
SJW2c317da2020-03-23 07:39:13 -05001031 // The value to return for true.
1032 auto TrueValue = ConstantInt::getSigned(CITy, C);
David Neto22f144c2017-06-12 14:26:21 -04001033
SJW2c317da2020-03-23 07:39:13 -05001034 // The value to return for false.
1035 auto FalseValue = Constant::getNullValue(CITy);
David Neto22f144c2017-06-12 14:26:21 -04001036
SJW2c317da2020-03-23 07:39:13 -05001037 Type *CorrespondingBoolTy = Type::getInt1Ty(M.getContext());
James Pricecf53df42020-04-20 14:41:24 -04001038 if (auto CIVecTy = dyn_cast<VectorType>(CITy)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001039 CorrespondingBoolTy =
1040 FixedVectorType::get(Type::getInt1Ty(M.getContext()),
1041 CIVecTy->getElementCount().getKnownMinValue());
David Neto22f144c2017-06-12 14:26:21 -04001042 }
David Neto22f144c2017-06-12 14:26:21 -04001043
SJW2c317da2020-03-23 07:39:13 -05001044 auto NewCI = clspv::InsertSPIRVOp(CI, SPIRVOp, {Attribute::ReadNone},
1045 CorrespondingBoolTy, {CI->getOperand(0)});
1046
1047 return SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
1048 });
David Neto22f144c2017-06-12 14:26:21 -04001049}
1050
SJW2c317da2020-03-23 07:39:13 -05001051bool ReplaceOpenCLBuiltinPass::replaceIsFinite(Function &F) {
1052 Module &M = *F.getParent();
1053 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001054 auto &C = M.getContext();
1055 auto Val = CI->getOperand(0);
1056 auto ValTy = Val->getType();
1057 auto RetTy = CI->getType();
1058
1059 // Get a suitable integer type to represent the number
1060 auto IntTy = getIntOrIntVectorTyForCast(C, ValTy);
1061
1062 // Create Mask
1063 auto ScalarSize = ValTy->getScalarSizeInBits();
SJW2c317da2020-03-23 07:39:13 -05001064 Value *InfMask = nullptr;
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001065 switch (ScalarSize) {
1066 case 16:
1067 InfMask = ConstantInt::get(IntTy, 0x7C00U);
1068 break;
1069 case 32:
1070 InfMask = ConstantInt::get(IntTy, 0x7F800000U);
1071 break;
1072 case 64:
1073 InfMask = ConstantInt::get(IntTy, 0x7FF0000000000000ULL);
1074 break;
1075 default:
1076 llvm_unreachable("Unsupported floating-point type");
1077 }
1078
1079 IRBuilder<> Builder(CI);
1080
1081 // Bitcast to int
1082 auto ValInt = Builder.CreateBitCast(Val, IntTy);
1083
1084 // Mask and compare
1085 auto InfBits = Builder.CreateAnd(InfMask, ValInt);
1086 auto Cmp = Builder.CreateICmp(CmpInst::ICMP_EQ, InfBits, InfMask);
1087
1088 auto RetFalse = ConstantInt::get(RetTy, 0);
SJW2c317da2020-03-23 07:39:13 -05001089 Value *RetTrue = nullptr;
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001090 if (ValTy->isVectorTy()) {
1091 RetTrue = ConstantInt::getSigned(RetTy, -1);
1092 } else {
1093 RetTrue = ConstantInt::get(RetTy, 1);
1094 }
1095 return Builder.CreateSelect(Cmp, RetFalse, RetTrue);
1096 });
1097}
1098
SJW2c317da2020-03-23 07:39:13 -05001099bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Function &F, spv::Op SPIRVOp) {
1100 Module &M = *F.getParent();
1101 return replaceCallsWithValue(F, [&](CallInst *CI) {
1102 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001103
SJW2c317da2020-03-23 07:39:13 -05001104 Value *V = nullptr;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001105
SJW2c317da2020-03-23 07:39:13 -05001106 // If the argument is a 32-bit int, just use a shift
1107 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
1108 V = BinaryOperator::Create(Instruction::LShr, Arg,
1109 ConstantInt::get(Arg->getType(), 31), "", CI);
1110 } else {
1111 // The value for zero to compare against.
1112 const auto ZeroValue = Constant::getNullValue(Arg->getType());
David Neto22f144c2017-06-12 14:26:21 -04001113
SJW2c317da2020-03-23 07:39:13 -05001114 // The value to return for true.
1115 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
David Neto22f144c2017-06-12 14:26:21 -04001116
SJW2c317da2020-03-23 07:39:13 -05001117 // The value to return for false.
1118 const auto FalseValue = Constant::getNullValue(CI->getType());
David Neto22f144c2017-06-12 14:26:21 -04001119
SJW2c317da2020-03-23 07:39:13 -05001120 const auto Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT,
1121 Arg, ZeroValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001122
SJW2c317da2020-03-23 07:39:13 -05001123 Value *SelectSource = nullptr;
David Neto22f144c2017-06-12 14:26:21 -04001124
SJW2c317da2020-03-23 07:39:13 -05001125 // If we have a function to call, call it!
1126 if (SPIRVOp != spv::OpNop) {
David Neto22f144c2017-06-12 14:26:21 -04001127
SJW2c317da2020-03-23 07:39:13 -05001128 const auto BoolTy = Type::getInt1Ty(M.getContext());
David Neto22f144c2017-06-12 14:26:21 -04001129
SJW2c317da2020-03-23 07:39:13 -05001130 const auto NewCI = clspv::InsertSPIRVOp(
1131 CI, SPIRVOp, {Attribute::ReadNone}, BoolTy, {Cmp});
1132 SelectSource = NewCI;
David Neto22f144c2017-06-12 14:26:21 -04001133
SJW2c317da2020-03-23 07:39:13 -05001134 } else {
1135 SelectSource = Cmp;
David Neto22f144c2017-06-12 14:26:21 -04001136 }
1137
SJW2c317da2020-03-23 07:39:13 -05001138 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001139 }
SJW2c317da2020-03-23 07:39:13 -05001140 return V;
1141 });
David Neto22f144c2017-06-12 14:26:21 -04001142}
1143
SJW2c317da2020-03-23 07:39:13 -05001144bool ReplaceOpenCLBuiltinPass::replaceUpsample(Function &F) {
1145 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1146 // Get arguments
1147 auto HiValue = CI->getOperand(0);
1148 auto LoValue = CI->getOperand(1);
Kévin Petitbf0036c2019-03-06 13:57:10 +00001149
SJW2c317da2020-03-23 07:39:13 -05001150 // Don't touch overloads that aren't in OpenCL C
1151 auto HiType = HiValue->getType();
1152 auto LoType = LoValue->getType();
1153
1154 if (HiType != LoType) {
1155 return nullptr;
Kévin Petitbf0036c2019-03-06 13:57:10 +00001156 }
Kévin Petitbf0036c2019-03-06 13:57:10 +00001157
SJW2c317da2020-03-23 07:39:13 -05001158 if (!HiType->isIntOrIntVectorTy()) {
1159 return nullptr;
Kévin Petitbf0036c2019-03-06 13:57:10 +00001160 }
Kévin Petitbf0036c2019-03-06 13:57:10 +00001161
SJW2c317da2020-03-23 07:39:13 -05001162 if (HiType->getScalarSizeInBits() * 2 !=
1163 CI->getType()->getScalarSizeInBits()) {
1164 return nullptr;
1165 }
1166
1167 if ((HiType->getScalarSizeInBits() != 8) &&
1168 (HiType->getScalarSizeInBits() != 16) &&
1169 (HiType->getScalarSizeInBits() != 32)) {
1170 return nullptr;
1171 }
1172
James Pricecf53df42020-04-20 14:41:24 -04001173 if (auto HiVecType = dyn_cast<VectorType>(HiType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001174 unsigned NumElements = HiVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001175 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1176 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001177 return nullptr;
1178 }
1179 }
1180
1181 // Convert both operands to the result type
1182 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(), "", CI);
1183 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(), "", CI);
1184
1185 // Shift high operand
1186 auto ShiftAmount =
1187 ConstantInt::get(CI->getType(), HiType->getScalarSizeInBits());
1188 auto HiShifted =
1189 BinaryOperator::Create(Instruction::Shl, HiCast, ShiftAmount, "", CI);
1190
1191 // OR both results
1192 return BinaryOperator::Create(Instruction::Or, HiShifted, LoCast, "", CI);
1193 });
Kévin Petitbf0036c2019-03-06 13:57:10 +00001194}
1195
SJW2c317da2020-03-23 07:39:13 -05001196bool ReplaceOpenCLBuiltinPass::replaceRotate(Function &F) {
1197 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1198 // Get arguments
1199 auto SrcValue = CI->getOperand(0);
1200 auto RotAmount = CI->getOperand(1);
Kévin Petitd44eef52019-03-08 13:22:14 +00001201
SJW2c317da2020-03-23 07:39:13 -05001202 // Don't touch overloads that aren't in OpenCL C
1203 auto SrcType = SrcValue->getType();
1204 auto RotType = RotAmount->getType();
1205
1206 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1207 return nullptr;
Kévin Petitd44eef52019-03-08 13:22:14 +00001208 }
Kévin Petitd44eef52019-03-08 13:22:14 +00001209
SJW2c317da2020-03-23 07:39:13 -05001210 if (!SrcType->isIntOrIntVectorTy()) {
1211 return nullptr;
Kévin Petitd44eef52019-03-08 13:22:14 +00001212 }
Kévin Petitd44eef52019-03-08 13:22:14 +00001213
SJW2c317da2020-03-23 07:39:13 -05001214 if ((SrcType->getScalarSizeInBits() != 8) &&
1215 (SrcType->getScalarSizeInBits() != 16) &&
1216 (SrcType->getScalarSizeInBits() != 32) &&
1217 (SrcType->getScalarSizeInBits() != 64)) {
1218 return nullptr;
1219 }
1220
James Pricecf53df42020-04-20 14:41:24 -04001221 if (auto SrcVecType = dyn_cast<VectorType>(SrcType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001222 unsigned NumElements = SrcVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001223 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1224 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001225 return nullptr;
1226 }
1227 }
1228
alan-bakerfd22ae12020-10-29 15:59:22 -04001229 // Replace with LLVM's funnel shift left intrinsic because it is more
1230 // generic than rotate.
1231 Function *intrinsic =
1232 Intrinsic::getDeclaration(F.getParent(), Intrinsic::fshl, SrcType);
1233 return CallInst::Create(intrinsic->getFunctionType(), intrinsic,
1234 {SrcValue, SrcValue, RotAmount}, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001235 });
Kévin Petitd44eef52019-03-08 13:22:14 +00001236}
1237
SJW2c317da2020-03-23 07:39:13 -05001238bool ReplaceOpenCLBuiltinPass::replaceConvert(Function &F, bool SrcIsSigned,
1239 bool DstIsSigned) {
1240 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1241 Value *V = nullptr;
1242 // Get arguments
1243 auto SrcValue = CI->getOperand(0);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001244
SJW2c317da2020-03-23 07:39:13 -05001245 // Don't touch overloads that aren't in OpenCL C
1246 auto SrcType = SrcValue->getType();
1247 auto DstType = CI->getType();
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001248
SJW2c317da2020-03-23 07:39:13 -05001249 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1250 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1251 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001252 }
1253
James Pricecf53df42020-04-20 14:41:24 -04001254 if (auto SrcVecType = dyn_cast<VectorType>(SrcType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001255 unsigned SrcNumElements =
1256 SrcVecType->getElementCount().getKnownMinValue();
1257 unsigned DstNumElements =
1258 cast<VectorType>(DstType)->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001259 if (SrcNumElements != DstNumElements) {
SJW2c317da2020-03-23 07:39:13 -05001260 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001261 }
1262
James Pricecf53df42020-04-20 14:41:24 -04001263 if ((SrcNumElements != 2) && (SrcNumElements != 3) &&
1264 (SrcNumElements != 4) && (SrcNumElements != 8) &&
1265 (SrcNumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001266 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001267 }
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001268 }
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001269
SJW2c317da2020-03-23 07:39:13 -05001270 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1271 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1272
1273 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1274 bool DstIsInt = DstType->isIntOrIntVectorTy();
1275
1276 if (SrcType == DstType && DstIsSigned == SrcIsSigned) {
1277 // Unnecessary cast operation.
1278 V = SrcValue;
1279 } else if (SrcIsFloat && DstIsFloat) {
1280 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1281 } else if (SrcIsFloat && DstIsInt) {
1282 if (DstIsSigned) {
1283 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "", CI);
1284 } else {
1285 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "", CI);
1286 }
1287 } else if (SrcIsInt && DstIsFloat) {
1288 if (SrcIsSigned) {
1289 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "", CI);
1290 } else {
1291 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "", CI);
1292 }
1293 } else if (SrcIsInt && DstIsInt) {
1294 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "", CI);
1295 } else {
1296 // Not something we're supposed to handle, just move on
1297 }
1298
1299 return V;
1300 });
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001301}
1302
SJW2c317da2020-03-23 07:39:13 -05001303bool ReplaceOpenCLBuiltinPass::replaceMulHi(Function &F, bool is_signed,
1304 bool is_mad) {
1305 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1306 Value *V = nullptr;
1307 // Get arguments
1308 auto AValue = CI->getOperand(0);
1309 auto BValue = CI->getOperand(1);
1310 auto CValue = CI->getOperand(2);
Kévin Petit8a560882019-03-21 15:24:34 +00001311
SJW2c317da2020-03-23 07:39:13 -05001312 // Don't touch overloads that aren't in OpenCL C
1313 auto AType = AValue->getType();
1314 auto BType = BValue->getType();
1315 auto CType = CValue->getType();
Kévin Petit8a560882019-03-21 15:24:34 +00001316
SJW2c317da2020-03-23 07:39:13 -05001317 if ((AType != BType) || (CI->getType() != AType) ||
1318 (is_mad && (AType != CType))) {
1319 return V;
Kévin Petit8a560882019-03-21 15:24:34 +00001320 }
1321
SJW2c317da2020-03-23 07:39:13 -05001322 if (!AType->isIntOrIntVectorTy()) {
1323 return V;
Kévin Petit8a560882019-03-21 15:24:34 +00001324 }
Kévin Petit8a560882019-03-21 15:24:34 +00001325
SJW2c317da2020-03-23 07:39:13 -05001326 if ((AType->getScalarSizeInBits() != 8) &&
1327 (AType->getScalarSizeInBits() != 16) &&
1328 (AType->getScalarSizeInBits() != 32) &&
1329 (AType->getScalarSizeInBits() != 64)) {
1330 return V;
1331 }
Kévin Petit617a76d2019-04-04 13:54:16 +01001332
James Pricecf53df42020-04-20 14:41:24 -04001333 if (auto AVecType = dyn_cast<VectorType>(AType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001334 unsigned NumElements = AVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001335 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1336 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001337 return V;
Kévin Petit617a76d2019-04-04 13:54:16 +01001338 }
1339 }
1340
SJW2c317da2020-03-23 07:39:13 -05001341 // Our SPIR-V op returns a struct, create a type for it
alan-baker6b9d1ee2020-11-03 23:11:32 -05001342 auto ExMulRetType = GetPairStruct(AType);
Kévin Petit617a76d2019-04-04 13:54:16 +01001343
SJW2c317da2020-03-23 07:39:13 -05001344 // Select the appropriate signed/unsigned SPIR-V op
1345 spv::Op opcode = is_signed ? spv::OpSMulExtended : spv::OpUMulExtended;
1346
1347 // Call the SPIR-V op
1348 auto Call = clspv::InsertSPIRVOp(CI, opcode, {Attribute::ReadNone},
1349 ExMulRetType, {AValue, BValue});
1350
1351 // Get the high part of the result
1352 unsigned Idxs[] = {1};
1353 V = ExtractValueInst::Create(Call, Idxs, "", CI);
1354
1355 // If we're handling a mad_hi, add the third argument to the result
1356 if (is_mad) {
1357 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
Kévin Petit617a76d2019-04-04 13:54:16 +01001358 }
1359
SJW2c317da2020-03-23 07:39:13 -05001360 return V;
1361 });
Kévin Petit8a560882019-03-21 15:24:34 +00001362}
1363
SJW2c317da2020-03-23 07:39:13 -05001364bool ReplaceOpenCLBuiltinPass::replaceSelect(Function &F) {
1365 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1366 // Get arguments
1367 auto FalseValue = CI->getOperand(0);
1368 auto TrueValue = CI->getOperand(1);
1369 auto PredicateValue = CI->getOperand(2);
Kévin Petitf5b78a22018-10-25 14:32:17 +00001370
SJW2c317da2020-03-23 07:39:13 -05001371 // Don't touch overloads that aren't in OpenCL C
1372 auto FalseType = FalseValue->getType();
1373 auto TrueType = TrueValue->getType();
1374 auto PredicateType = PredicateValue->getType();
1375
1376 if (FalseType != TrueType) {
1377 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001378 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001379
SJW2c317da2020-03-23 07:39:13 -05001380 if (!PredicateType->isIntOrIntVectorTy()) {
1381 return nullptr;
1382 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001383
SJW2c317da2020-03-23 07:39:13 -05001384 if (!FalseType->isIntOrIntVectorTy() &&
1385 !FalseType->getScalarType()->isFloatingPointTy()) {
1386 return nullptr;
1387 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001388
SJW2c317da2020-03-23 07:39:13 -05001389 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1390 return nullptr;
1391 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001392
SJW2c317da2020-03-23 07:39:13 -05001393 if (FalseType->getScalarSizeInBits() !=
1394 PredicateType->getScalarSizeInBits()) {
1395 return nullptr;
1396 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001397
James Pricecf53df42020-04-20 14:41:24 -04001398 if (auto FalseVecType = dyn_cast<VectorType>(FalseType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001399 unsigned NumElements = FalseVecType->getElementCount().getKnownMinValue();
1400 if (NumElements != cast<VectorType>(PredicateType)
1401 ->getElementCount()
1402 .getKnownMinValue()) {
SJW2c317da2020-03-23 07:39:13 -05001403 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001404 }
1405
James Pricecf53df42020-04-20 14:41:24 -04001406 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1407 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001408 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001409 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001410 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001411
SJW2c317da2020-03-23 07:39:13 -05001412 // Create constant
1413 const auto ZeroValue = Constant::getNullValue(PredicateType);
1414
1415 // Scalar and vector are to be treated differently
1416 CmpInst::Predicate Pred;
1417 if (PredicateType->isVectorTy()) {
1418 Pred = CmpInst::ICMP_SLT;
1419 } else {
1420 Pred = CmpInst::ICMP_NE;
1421 }
1422
1423 // Create comparison instruction
1424 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1425 ZeroValue, "", CI);
1426
1427 // Create select
1428 return SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1429 });
Kévin Petitf5b78a22018-10-25 14:32:17 +00001430}
1431
SJW2c317da2020-03-23 07:39:13 -05001432bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Function &F) {
1433 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1434 Value *V = nullptr;
1435 if (CI->getNumOperands() != 4) {
1436 return V;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001437 }
Kévin Petite7d0cce2018-10-31 12:38:56 +00001438
SJW2c317da2020-03-23 07:39:13 -05001439 // Get arguments
1440 auto FalseValue = CI->getOperand(0);
1441 auto TrueValue = CI->getOperand(1);
1442 auto PredicateValue = CI->getOperand(2);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001443
SJW2c317da2020-03-23 07:39:13 -05001444 // Don't touch overloads that aren't in OpenCL C
1445 auto FalseType = FalseValue->getType();
1446 auto TrueType = TrueValue->getType();
1447 auto PredicateType = PredicateValue->getType();
Kévin Petite7d0cce2018-10-31 12:38:56 +00001448
SJW2c317da2020-03-23 07:39:13 -05001449 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1450 return V;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001451 }
Kévin Petite7d0cce2018-10-31 12:38:56 +00001452
James Pricecf53df42020-04-20 14:41:24 -04001453 if (auto TrueVecType = dyn_cast<VectorType>(TrueType)) {
SJW2c317da2020-03-23 07:39:13 -05001454 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1455 !TrueType->getScalarType()->isIntegerTy()) {
1456 return V;
1457 }
alan-baker5a8c3be2020-09-09 13:44:26 -04001458 unsigned NumElements = TrueVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001459 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1460 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001461 return V;
1462 }
1463 }
1464
1465 // Remember the type of the operands
1466 auto OpType = TrueType;
1467
1468 // The actual bit selection will always be done on an integer type,
1469 // declare it here
1470 Type *BitType;
1471
1472 // If the operands are float, then bitcast them to int
1473 if (OpType->getScalarType()->isFloatingPointTy()) {
1474
1475 // First create the new type
1476 BitType = getIntOrIntVectorTyForCast(F.getContext(), OpType);
1477
1478 // Then bitcast all operands
1479 PredicateValue =
1480 CastInst::CreateZExtOrBitCast(PredicateValue, BitType, "", CI);
1481 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue, BitType, "", CI);
1482 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1483
1484 } else {
1485 // The operands have an integer type, use it directly
1486 BitType = OpType;
1487 }
1488
1489 // All the operands are now always integers
1490 // implement as (c & b) | (~c & a)
1491
1492 // Create our negated predicate value
1493 auto AllOnes = Constant::getAllOnesValue(BitType);
1494 auto NotPredicateValue = BinaryOperator::Create(
1495 Instruction::Xor, PredicateValue, AllOnes, "", CI);
1496
1497 // Then put everything together
1498 auto BitsFalse = BinaryOperator::Create(Instruction::And, NotPredicateValue,
1499 FalseValue, "", CI);
1500 auto BitsTrue = BinaryOperator::Create(Instruction::And, PredicateValue,
1501 TrueValue, "", CI);
1502
1503 V = BinaryOperator::Create(Instruction::Or, BitsFalse, BitsTrue, "", CI);
1504
1505 // If we were dealing with a floating point type, we must bitcast
1506 // the result back to that
1507 if (OpType->getScalarType()->isFloatingPointTy()) {
1508 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1509 }
1510
1511 return V;
1512 });
Kévin Petite7d0cce2018-10-31 12:38:56 +00001513}
1514
SJW61531372020-06-09 07:31:08 -05001515bool ReplaceOpenCLBuiltinPass::replaceStep(Function &F, bool is_smooth) {
SJW2c317da2020-03-23 07:39:13 -05001516 // convert to vector versions
1517 Module &M = *F.getParent();
1518 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1519 SmallVector<Value *, 2> ArgsToSplat = {CI->getOperand(0)};
1520 Value *VectorArg = nullptr;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001521
SJW2c317da2020-03-23 07:39:13 -05001522 // First figure out which function we're dealing with
1523 if (is_smooth) {
1524 ArgsToSplat.push_back(CI->getOperand(1));
1525 VectorArg = CI->getOperand(2);
1526 } else {
1527 VectorArg = CI->getOperand(1);
1528 }
1529
1530 // Splat arguments that need to be
1531 SmallVector<Value *, 2> SplatArgs;
James Pricecf53df42020-04-20 14:41:24 -04001532 auto VecType = cast<VectorType>(VectorArg->getType());
SJW2c317da2020-03-23 07:39:13 -05001533
1534 for (auto arg : ArgsToSplat) {
1535 Value *NewVectorArg = UndefValue::get(VecType);
alan-baker5a8c3be2020-09-09 13:44:26 -04001536 for (auto i = 0; i < VecType->getElementCount().getKnownMinValue(); i++) {
SJW2c317da2020-03-23 07:39:13 -05001537 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1538 NewVectorArg =
1539 InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1540 }
1541 SplatArgs.push_back(NewVectorArg);
1542 }
1543
1544 // Replace the call with the vector/vector flavour
1545 SmallVector<Type *, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1546 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1547
SJW61531372020-06-09 07:31:08 -05001548 std::string NewFName = Builtins::GetMangledFunctionName(
1549 is_smooth ? "smoothstep" : "step", NewFType);
1550
SJW2c317da2020-03-23 07:39:13 -05001551 const auto NewF = M.getOrInsertFunction(NewFName, NewFType);
1552
1553 SmallVector<Value *, 3> NewArgs;
1554 for (auto arg : SplatArgs) {
1555 NewArgs.push_back(arg);
1556 }
1557 NewArgs.push_back(VectorArg);
1558
1559 return CallInst::Create(NewF, NewArgs, "", CI);
1560 });
Kévin Petit6b0a9532018-10-30 20:00:39 +00001561}
1562
SJW2c317da2020-03-23 07:39:13 -05001563bool ReplaceOpenCLBuiltinPass::replaceSignbit(Function &F, bool is_vec) {
SJW2c317da2020-03-23 07:39:13 -05001564 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1565 auto Arg = CI->getOperand(0);
1566 auto Op = is_vec ? Instruction::AShr : Instruction::LShr;
David Neto22f144c2017-06-12 14:26:21 -04001567
SJW2c317da2020-03-23 07:39:13 -05001568 auto Bitcast = CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001569
SJW2c317da2020-03-23 07:39:13 -05001570 return BinaryOperator::Create(Op, Bitcast,
1571 ConstantInt::get(CI->getType(), 31), "", CI);
1572 });
David Neto22f144c2017-06-12 14:26:21 -04001573}
1574
SJW2c317da2020-03-23 07:39:13 -05001575bool ReplaceOpenCLBuiltinPass::replaceMul(Function &F, bool is_float,
1576 bool is_mad) {
SJW2c317da2020-03-23 07:39:13 -05001577 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1578 // The multiply instruction to use.
1579 auto MulInst = is_float ? Instruction::FMul : Instruction::Mul;
David Neto22f144c2017-06-12 14:26:21 -04001580
SJW2c317da2020-03-23 07:39:13 -05001581 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
David Neto22f144c2017-06-12 14:26:21 -04001582
SJW2c317da2020-03-23 07:39:13 -05001583 Value *V = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1584 CI->getArgOperand(1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001585
SJW2c317da2020-03-23 07:39:13 -05001586 if (is_mad) {
1587 // The add instruction to use.
1588 auto AddInst = is_float ? Instruction::FAdd : Instruction::Add;
David Neto22f144c2017-06-12 14:26:21 -04001589
SJW2c317da2020-03-23 07:39:13 -05001590 V = BinaryOperator::Create(AddInst, V, CI->getArgOperand(2), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001591 }
David Neto22f144c2017-06-12 14:26:21 -04001592
SJW2c317da2020-03-23 07:39:13 -05001593 return V;
1594 });
David Neto22f144c2017-06-12 14:26:21 -04001595}
1596
SJW2c317da2020-03-23 07:39:13 -05001597bool ReplaceOpenCLBuiltinPass::replaceVstore(Function &F) {
SJW2c317da2020-03-23 07:39:13 -05001598 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1599 Value *V = nullptr;
1600 auto data = CI->getOperand(0);
Derek Chowcfd368b2017-10-19 20:58:45 -07001601
SJW2c317da2020-03-23 07:39:13 -05001602 auto data_type = data->getType();
1603 if (!data_type->isVectorTy())
1604 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001605
James Pricecf53df42020-04-20 14:41:24 -04001606 auto vec_data_type = cast<VectorType>(data_type);
1607
alan-baker5a8c3be2020-09-09 13:44:26 -04001608 auto elems = vec_data_type->getElementCount().getKnownMinValue();
SJW2c317da2020-03-23 07:39:13 -05001609 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 && elems != 16)
1610 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001611
SJW2c317da2020-03-23 07:39:13 -05001612 auto offset = CI->getOperand(1);
1613 auto ptr = CI->getOperand(2);
1614 auto ptr_type = ptr->getType();
1615 auto pointee_type = ptr_type->getPointerElementType();
James Pricecf53df42020-04-20 14:41:24 -04001616 if (pointee_type != vec_data_type->getElementType())
SJW2c317da2020-03-23 07:39:13 -05001617 return V;
alan-bakerf795f392019-06-11 18:24:34 -04001618
SJW2c317da2020-03-23 07:39:13 -05001619 // Avoid pointer casts. Instead generate the correct number of stores
1620 // and rely on drivers to coalesce appropriately.
1621 IRBuilder<> builder(CI);
1622 auto elems_const = builder.getInt32(elems);
1623 auto adjust = builder.CreateMul(offset, elems_const);
1624 for (auto i = 0; i < elems; ++i) {
1625 auto idx = builder.getInt32(i);
1626 auto add = builder.CreateAdd(adjust, idx);
1627 auto gep = builder.CreateGEP(ptr, add);
1628 auto extract = builder.CreateExtractElement(data, i);
1629 V = builder.CreateStore(extract, gep);
Derek Chowcfd368b2017-10-19 20:58:45 -07001630 }
SJW2c317da2020-03-23 07:39:13 -05001631 return V;
1632 });
Derek Chowcfd368b2017-10-19 20:58:45 -07001633}
1634
SJW2c317da2020-03-23 07:39:13 -05001635bool ReplaceOpenCLBuiltinPass::replaceVload(Function &F) {
SJW2c317da2020-03-23 07:39:13 -05001636 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1637 Value *V = nullptr;
1638 auto ret_type = F.getReturnType();
1639 if (!ret_type->isVectorTy())
1640 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001641
James Pricecf53df42020-04-20 14:41:24 -04001642 auto vec_ret_type = cast<VectorType>(ret_type);
1643
alan-baker5a8c3be2020-09-09 13:44:26 -04001644 auto elems = vec_ret_type->getElementCount().getKnownMinValue();
SJW2c317da2020-03-23 07:39:13 -05001645 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 && elems != 16)
1646 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001647
SJW2c317da2020-03-23 07:39:13 -05001648 auto offset = CI->getOperand(0);
1649 auto ptr = CI->getOperand(1);
1650 auto ptr_type = ptr->getType();
1651 auto pointee_type = ptr_type->getPointerElementType();
James Pricecf53df42020-04-20 14:41:24 -04001652 if (pointee_type != vec_ret_type->getElementType())
SJW2c317da2020-03-23 07:39:13 -05001653 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001654
SJW2c317da2020-03-23 07:39:13 -05001655 // Avoid pointer casts. Instead generate the correct number of loads
1656 // and rely on drivers to coalesce appropriately.
1657 IRBuilder<> builder(CI);
1658 auto elems_const = builder.getInt32(elems);
1659 V = UndefValue::get(ret_type);
1660 auto adjust = builder.CreateMul(offset, elems_const);
1661 for (auto i = 0; i < elems; ++i) {
1662 auto idx = builder.getInt32(i);
1663 auto add = builder.CreateAdd(adjust, idx);
1664 auto gep = builder.CreateGEP(ptr, add);
1665 auto load = builder.CreateLoad(gep);
1666 V = builder.CreateInsertElement(V, load, i);
Derek Chowcfd368b2017-10-19 20:58:45 -07001667 }
SJW2c317da2020-03-23 07:39:13 -05001668 return V;
1669 });
Derek Chowcfd368b2017-10-19 20:58:45 -07001670}
1671
SJW2c317da2020-03-23 07:39:13 -05001672bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Function &F,
1673 const std::string &name,
1674 int vec_size) {
1675 bool is_clspv_version = !name.compare(0, 8, "__clspv_");
1676 if (!vec_size) {
1677 // deduce vec_size from last character of name (e.g. vload_half4)
1678 vec_size = std::atoi(&name.back());
David Neto22f144c2017-06-12 14:26:21 -04001679 }
SJW2c317da2020-03-23 07:39:13 -05001680 switch (vec_size) {
1681 case 2:
1682 return is_clspv_version ? replaceClspvVloadaHalf2(F) : replaceVloadHalf2(F);
1683 case 4:
1684 return is_clspv_version ? replaceClspvVloadaHalf4(F) : replaceVloadHalf4(F);
1685 case 0:
1686 if (!is_clspv_version) {
1687 return replaceVloadHalf(F);
1688 }
1689 default:
1690 llvm_unreachable("Unsupported vload_half vector size");
1691 break;
1692 }
1693 return false;
David Neto22f144c2017-06-12 14:26:21 -04001694}
1695
SJW2c317da2020-03-23 07:39:13 -05001696bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Function &F) {
1697 Module &M = *F.getParent();
1698 return replaceCallsWithValue(F, [&](CallInst *CI) {
1699 // The index argument from vload_half.
1700 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001701
SJW2c317da2020-03-23 07:39:13 -05001702 // The pointer argument from vload_half.
1703 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001704
SJW2c317da2020-03-23 07:39:13 -05001705 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001706 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
SJW2c317da2020-03-23 07:39:13 -05001707 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1708
1709 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001710 auto SPIRVIntrinsic = clspv::UnpackFunction();
SJW2c317da2020-03-23 07:39:13 -05001711
1712 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1713
1714 Value *V = nullptr;
1715
alan-baker7efcaaa2020-05-06 19:33:27 -04001716 bool supports_16bit_storage = true;
1717 switch (Arg1->getType()->getPointerAddressSpace()) {
1718 case clspv::AddressSpace::Global:
1719 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1720 clspv::Option::StorageClass::kSSBO);
1721 break;
1722 case clspv::AddressSpace::Constant:
1723 if (clspv::Option::ConstantArgsInUniformBuffer())
1724 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1725 clspv::Option::StorageClass::kUBO);
1726 else
1727 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1728 clspv::Option::StorageClass::kSSBO);
1729 break;
1730 default:
1731 // Clspv will emit the Float16 capability if the half type is
1732 // encountered. That capability covers private and local addressspaces.
1733 break;
1734 }
1735
1736 if (supports_16bit_storage) {
SJW2c317da2020-03-23 07:39:13 -05001737 auto ShortTy = Type::getInt16Ty(M.getContext());
1738 auto ShortPointerTy =
1739 PointerType::get(ShortTy, Arg1->getType()->getPointerAddressSpace());
1740
1741 // Cast the half* pointer to short*.
1742 auto Cast = CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
1743
1744 // Index into the correct address of the casted pointer.
1745 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1746
1747 // Load from the short* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001748 auto Load = new LoadInst(ShortTy, Index, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001749
1750 // ZExt the short -> int.
1751 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1752
1753 // Get our float2.
1754 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1755
1756 // Extract out the bottom element which is our float result.
1757 V = ExtractElementInst::Create(Call, ConstantInt::get(IntTy, 0), "", CI);
1758 } else {
1759 // Assume the pointer argument points to storage aligned to 32bits
1760 // or more.
1761 // TODO(dneto): Do more analysis to make sure this is true?
1762 //
1763 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1764 // with:
1765 //
1766 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1767 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1768 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1769 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1770 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1771 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1772 // x float> %converted, %index_is_odd32
1773
1774 auto IntPointerTy =
1775 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
1776
1777 // Cast the base pointer to int*.
1778 // In a valid call (according to assumptions), this should get
1779 // optimized away in the simplify GEP pass.
1780 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1781
1782 auto One = ConstantInt::get(IntTy, 1);
1783 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1784 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1785
1786 // Index into the correct address of the casted pointer.
1787 auto Ptr = GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1788
1789 // Load from the int* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001790 auto Load = new LoadInst(IntTy, Ptr, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001791
1792 // Get our float2.
1793 auto Call = CallInst::Create(NewF, Load, "", CI);
1794
1795 // Extract out the float result, where the element number is
1796 // determined by whether the original index was even or odd.
1797 V = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1798 }
1799 return V;
1800 });
1801}
1802
1803bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Function &F) {
1804 Module &M = *F.getParent();
1805 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001806 // The index argument from vload_half.
1807 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001808
Kévin Petite8edce32019-04-10 14:23:32 +01001809 // The pointer argument from vload_half.
1810 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001811
Kévin Petite8edce32019-04-10 14:23:32 +01001812 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001813 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001814 auto NewPointerTy =
1815 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01001816 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04001817
Kévin Petite8edce32019-04-10 14:23:32 +01001818 // Cast the half* pointer to int*.
1819 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001820
Kévin Petite8edce32019-04-10 14:23:32 +01001821 // Index into the correct address of the casted pointer.
1822 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001823
Kévin Petite8edce32019-04-10 14:23:32 +01001824 // Load from the int* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001825 auto Load = new LoadInst(IntTy, Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001826
Kévin Petite8edce32019-04-10 14:23:32 +01001827 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001828 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto22f144c2017-06-12 14:26:21 -04001829
Kévin Petite8edce32019-04-10 14:23:32 +01001830 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04001831
Kévin Petite8edce32019-04-10 14:23:32 +01001832 // Get our float2.
1833 return CallInst::Create(NewF, Load, "", CI);
1834 });
David Neto22f144c2017-06-12 14:26:21 -04001835}
1836
SJW2c317da2020-03-23 07:39:13 -05001837bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Function &F) {
1838 Module &M = *F.getParent();
1839 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001840 // The index argument from vload_half.
1841 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001842
Kévin Petite8edce32019-04-10 14:23:32 +01001843 // The pointer argument from vload_half.
1844 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001845
Kévin Petite8edce32019-04-10 14:23:32 +01001846 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001847 auto Int2Ty = FixedVectorType::get(IntTy, 2);
1848 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001849 auto NewPointerTy =
1850 PointerType::get(Int2Ty, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01001851 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04001852
Kévin Petite8edce32019-04-10 14:23:32 +01001853 // Cast the half* pointer to int2*.
1854 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001855
Kévin Petite8edce32019-04-10 14:23:32 +01001856 // Index into the correct address of the casted pointer.
1857 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001858
Kévin Petite8edce32019-04-10 14:23:32 +01001859 // Load from the int2* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001860 auto Load = new LoadInst(Int2Ty, Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001861
Kévin Petite8edce32019-04-10 14:23:32 +01001862 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001863 auto X =
1864 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
1865 auto Y =
1866 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001867
Kévin Petite8edce32019-04-10 14:23:32 +01001868 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001869 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto22f144c2017-06-12 14:26:21 -04001870
Kévin Petite8edce32019-04-10 14:23:32 +01001871 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04001872
Kévin Petite8edce32019-04-10 14:23:32 +01001873 // Get the lower (x & y) components of our final float4.
1874 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001875
Kévin Petite8edce32019-04-10 14:23:32 +01001876 // Get the higher (z & w) components of our final float4.
1877 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001878
Kévin Petite8edce32019-04-10 14:23:32 +01001879 Constant *ShuffleMask[4] = {
1880 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1881 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04001882
Kévin Petite8edce32019-04-10 14:23:32 +01001883 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001884 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
1885 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01001886 });
David Neto22f144c2017-06-12 14:26:21 -04001887}
1888
SJW2c317da2020-03-23 07:39:13 -05001889bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Function &F) {
David Neto6ad93232018-06-07 15:42:58 -07001890
1891 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1892 //
1893 // %u = load i32 %ptr
1894 // %fxy = call <2 x float> Unpack2xHalf(u)
1895 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
SJW2c317da2020-03-23 07:39:13 -05001896 Module &M = *F.getParent();
1897 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001898 auto Index = CI->getOperand(0);
1899 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07001900
Kévin Petite8edce32019-04-10 14:23:32 +01001901 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001902 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001903 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07001904
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001905 auto IndexedPtr = GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04001906 auto Load = new LoadInst(IntTy, IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001907
Kévin Petite8edce32019-04-10 14:23:32 +01001908 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001909 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto6ad93232018-06-07 15:42:58 -07001910
Kévin Petite8edce32019-04-10 14:23:32 +01001911 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07001912
Kévin Petite8edce32019-04-10 14:23:32 +01001913 // Get our final float2.
1914 return CallInst::Create(NewF, Load, "", CI);
1915 });
David Neto6ad93232018-06-07 15:42:58 -07001916}
1917
SJW2c317da2020-03-23 07:39:13 -05001918bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Function &F) {
David Neto6ad93232018-06-07 15:42:58 -07001919
1920 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1921 //
1922 // %u2 = load <2 x i32> %ptr
1923 // %u2xy = extractelement %u2, 0
1924 // %u2zw = extractelement %u2, 1
1925 // %fxy = call <2 x float> Unpack2xHalf(uint)
1926 // %fzw = call <2 x float> Unpack2xHalf(uint)
1927 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
SJW2c317da2020-03-23 07:39:13 -05001928 Module &M = *F.getParent();
1929 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001930 auto Index = CI->getOperand(0);
1931 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07001932
Kévin Petite8edce32019-04-10 14:23:32 +01001933 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001934 auto Int2Ty = FixedVectorType::get(IntTy, 2);
1935 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001936 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07001937
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001938 auto IndexedPtr = GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04001939 auto Load = new LoadInst(Int2Ty, IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001940
Kévin Petite8edce32019-04-10 14:23:32 +01001941 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001942 auto X =
1943 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
1944 auto Y =
1945 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001946
Kévin Petite8edce32019-04-10 14:23:32 +01001947 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001948 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto6ad93232018-06-07 15:42:58 -07001949
Kévin Petite8edce32019-04-10 14:23:32 +01001950 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07001951
Kévin Petite8edce32019-04-10 14:23:32 +01001952 // Get the lower (x & y) components of our final float4.
1953 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001954
Kévin Petite8edce32019-04-10 14:23:32 +01001955 // Get the higher (z & w) components of our final float4.
1956 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001957
Kévin Petite8edce32019-04-10 14:23:32 +01001958 Constant *ShuffleMask[4] = {
1959 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1960 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto6ad93232018-06-07 15:42:58 -07001961
Kévin Petite8edce32019-04-10 14:23:32 +01001962 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001963 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
1964 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01001965 });
David Neto6ad93232018-06-07 15:42:58 -07001966}
1967
SJW2c317da2020-03-23 07:39:13 -05001968bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Function &F, int vec_size) {
1969 switch (vec_size) {
1970 case 0:
1971 return replaceVstoreHalf(F);
1972 case 2:
1973 return replaceVstoreHalf2(F);
1974 case 4:
1975 return replaceVstoreHalf4(F);
1976 default:
1977 llvm_unreachable("Unsupported vstore_half vector size");
1978 break;
1979 }
1980 return false;
1981}
David Neto22f144c2017-06-12 14:26:21 -04001982
SJW2c317da2020-03-23 07:39:13 -05001983bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Function &F) {
1984 Module &M = *F.getParent();
1985 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001986 // The value to store.
1987 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001988
Kévin Petite8edce32019-04-10 14:23:32 +01001989 // The index argument from vstore_half.
1990 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001991
Kévin Petite8edce32019-04-10 14:23:32 +01001992 // The pointer argument from vstore_half.
1993 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04001994
Kévin Petite8edce32019-04-10 14:23:32 +01001995 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001996 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001997 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
1998 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04001999
Kévin Petite8edce32019-04-10 14:23:32 +01002000 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05002001 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04002002
Kévin Petite8edce32019-04-10 14:23:32 +01002003 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002004
Kévin Petite8edce32019-04-10 14:23:32 +01002005 // Insert our value into a float2 so that we can pack it.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002006 auto TempVec = InsertElementInst::Create(
2007 UndefValue::get(Float2Ty), Arg0, ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002008
Kévin Petite8edce32019-04-10 14:23:32 +01002009 // Pack the float2 -> half2 (in an int).
2010 auto X = CallInst::Create(NewF, TempVec, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002011
alan-baker7efcaaa2020-05-06 19:33:27 -04002012 bool supports_16bit_storage = true;
2013 switch (Arg2->getType()->getPointerAddressSpace()) {
2014 case clspv::AddressSpace::Global:
2015 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2016 clspv::Option::StorageClass::kSSBO);
2017 break;
2018 case clspv::AddressSpace::Constant:
2019 if (clspv::Option::ConstantArgsInUniformBuffer())
2020 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2021 clspv::Option::StorageClass::kUBO);
2022 else
2023 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2024 clspv::Option::StorageClass::kSSBO);
2025 break;
2026 default:
2027 // Clspv will emit the Float16 capability if the half type is
2028 // encountered. That capability covers private and local addressspaces.
2029 break;
2030 }
2031
SJW2c317da2020-03-23 07:39:13 -05002032 Value *V = nullptr;
alan-baker7efcaaa2020-05-06 19:33:27 -04002033 if (supports_16bit_storage) {
Kévin Petite8edce32019-04-10 14:23:32 +01002034 auto ShortTy = Type::getInt16Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002035 auto ShortPointerTy =
2036 PointerType::get(ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002037
Kévin Petite8edce32019-04-10 14:23:32 +01002038 // Truncate our i32 to an i16.
2039 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002040
Kévin Petite8edce32019-04-10 14:23:32 +01002041 // Cast the half* pointer to short*.
2042 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002043
Kévin Petite8edce32019-04-10 14:23:32 +01002044 // Index into the correct address of the casted pointer.
2045 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002046
Kévin Petite8edce32019-04-10 14:23:32 +01002047 // Store to the int* we casted to.
SJW2c317da2020-03-23 07:39:13 -05002048 V = new StoreInst(Trunc, Index, CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002049 } else {
2050 // We can only write to 32-bit aligned words.
2051 //
2052 // Assuming base is aligned to 32-bits, replace the equivalent of
2053 // vstore_half(value, index, base)
2054 // with:
2055 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2056 // uint32_t write_to_upper_half = index & 1u;
2057 // uint32_t shift = write_to_upper_half << 4;
2058 //
2059 // // Pack the float value as a half number in bottom 16 bits
2060 // // of an i32.
2061 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2062 //
2063 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2064 // ^ ((packed & 0xffff) << shift)
2065 // // We only need relaxed consistency, but OpenCL 1.2 only has
2066 // // sequentially consistent atomics.
2067 // // TODO(dneto): Use relaxed consistency.
2068 // atomic_xor(target_ptr, xor_value)
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002069 auto IntPointerTy =
2070 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002071
Kévin Petite8edce32019-04-10 14:23:32 +01002072 auto Four = ConstantInt::get(IntTy, 4);
2073 auto FFFF = ConstantInt::get(IntTy, 0xffff);
David Neto17852de2017-05-29 17:29:31 -04002074
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002075 auto IndexIsOdd =
2076 BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002077 // Compute index / 2
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002078 auto IndexIntoI32 =
2079 BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2080 auto BaseI32Ptr =
2081 CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2082 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32,
2083 "base_i32_ptr", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04002084 auto CurrentValue = new LoadInst(IntTy, OutPtr, "current_value", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002085 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002086 auto MaskBitsToWrite =
2087 BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2088 auto MaskedCurrent = BinaryOperator::CreateAnd(
2089 MaskBitsToWrite, CurrentValue, "masked_current", CI);
David Neto17852de2017-05-29 17:29:31 -04002090
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002091 auto XLowerBits =
2092 BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2093 auto NewBitsToWrite =
2094 BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2095 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite,
2096 "value_to_xor", CI);
David Neto17852de2017-05-29 17:29:31 -04002097
Kévin Petite8edce32019-04-10 14:23:32 +01002098 // Generate the call to atomi_xor.
2099 SmallVector<Type *, 5> ParamTypes;
2100 // The pointer type.
2101 ParamTypes.push_back(IntPointerTy);
2102 // The Types for memory scope, semantics, and value.
2103 ParamTypes.push_back(IntTy);
2104 ParamTypes.push_back(IntTy);
2105 ParamTypes.push_back(IntTy);
2106 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2107 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
David Neto17852de2017-05-29 17:29:31 -04002108
Kévin Petite8edce32019-04-10 14:23:32 +01002109 const auto ConstantScopeDevice =
2110 ConstantInt::get(IntTy, spv::ScopeDevice);
2111 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2112 // (SPIR-V Workgroup).
2113 const auto AddrSpaceSemanticsBits =
2114 IntPointerTy->getPointerAddressSpace() == 1
2115 ? spv::MemorySemanticsUniformMemoryMask
2116 : spv::MemorySemanticsWorkgroupMemoryMask;
David Neto17852de2017-05-29 17:29:31 -04002117
Kévin Petite8edce32019-04-10 14:23:32 +01002118 // We're using relaxed consistency here.
2119 const auto ConstantMemorySemantics =
2120 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2121 AddrSpaceSemanticsBits);
David Neto17852de2017-05-29 17:29:31 -04002122
Kévin Petite8edce32019-04-10 14:23:32 +01002123 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2124 ConstantMemorySemantics, ValueToXor};
2125 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
SJW2c317da2020-03-23 07:39:13 -05002126
2127 // Return a Nop so the old Call is removed
2128 Function *donothing = Intrinsic::getDeclaration(&M, Intrinsic::donothing);
2129 V = CallInst::Create(donothing, {}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002130 }
David Neto22f144c2017-06-12 14:26:21 -04002131
SJW2c317da2020-03-23 07:39:13 -05002132 return V;
Kévin Petite8edce32019-04-10 14:23:32 +01002133 });
David Neto22f144c2017-06-12 14:26:21 -04002134}
2135
SJW2c317da2020-03-23 07:39:13 -05002136bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Function &F) {
2137 Module &M = *F.getParent();
2138 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002139 // The value to store.
2140 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002141
Kévin Petite8edce32019-04-10 14:23:32 +01002142 // The index argument from vstore_half.
2143 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002144
Kévin Petite8edce32019-04-10 14:23:32 +01002145 // The pointer argument from vstore_half.
2146 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002147
Kévin Petite8edce32019-04-10 14:23:32 +01002148 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04002149 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002150 auto NewPointerTy =
2151 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002152 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002153
Kévin Petite8edce32019-04-10 14:23:32 +01002154 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05002155 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04002156
Kévin Petite8edce32019-04-10 14:23:32 +01002157 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002158
Kévin Petite8edce32019-04-10 14:23:32 +01002159 // Turn the packed x & y into the final packing.
2160 auto X = CallInst::Create(NewF, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002161
Kévin Petite8edce32019-04-10 14:23:32 +01002162 // Cast the half* pointer to int*.
2163 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002164
Kévin Petite8edce32019-04-10 14:23:32 +01002165 // Index into the correct address of the casted pointer.
2166 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002167
Kévin Petite8edce32019-04-10 14:23:32 +01002168 // Store to the int* we casted to.
2169 return new StoreInst(X, Index, CI);
2170 });
David Neto22f144c2017-06-12 14:26:21 -04002171}
2172
SJW2c317da2020-03-23 07:39:13 -05002173bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Function &F) {
2174 Module &M = *F.getParent();
2175 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002176 // The value to store.
2177 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002178
Kévin Petite8edce32019-04-10 14:23:32 +01002179 // The index argument from vstore_half.
2180 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002181
Kévin Petite8edce32019-04-10 14:23:32 +01002182 // The pointer argument from vstore_half.
2183 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002184
Kévin Petite8edce32019-04-10 14:23:32 +01002185 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04002186 auto Int2Ty = FixedVectorType::get(IntTy, 2);
2187 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002188 auto NewPointerTy =
2189 PointerType::get(Int2Ty, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002190 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002191
Kévin Petite8edce32019-04-10 14:23:32 +01002192 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2193 ConstantInt::get(IntTy, 1)};
David Neto22f144c2017-06-12 14:26:21 -04002194
Kévin Petite8edce32019-04-10 14:23:32 +01002195 // Extract out the x & y components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002196 auto Lo = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2197 ConstantVector::get(LoShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002198
Kévin Petite8edce32019-04-10 14:23:32 +01002199 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2200 ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002201
Kévin Petite8edce32019-04-10 14:23:32 +01002202 // Extract out the z & w components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002203 auto Hi = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2204 ConstantVector::get(HiShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002205
Kévin Petite8edce32019-04-10 14:23:32 +01002206 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05002207 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04002208
Kévin Petite8edce32019-04-10 14:23:32 +01002209 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002210
Kévin Petite8edce32019-04-10 14:23:32 +01002211 // Turn the packed x & y into the final component of our int2.
2212 auto X = CallInst::Create(NewF, Lo, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002213
Kévin Petite8edce32019-04-10 14:23:32 +01002214 // Turn the packed z & w into the final component of our int2.
2215 auto Y = CallInst::Create(NewF, Hi, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002216
Kévin Petite8edce32019-04-10 14:23:32 +01002217 auto Combine = InsertElementInst::Create(
2218 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002219 Combine = InsertElementInst::Create(Combine, Y, ConstantInt::get(IntTy, 1),
2220 "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002221
Kévin Petite8edce32019-04-10 14:23:32 +01002222 // Cast the half* pointer to int2*.
2223 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002224
Kévin Petite8edce32019-04-10 14:23:32 +01002225 // Index into the correct address of the casted pointer.
2226 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002227
Kévin Petite8edce32019-04-10 14:23:32 +01002228 // Store to the int2* we casted to.
2229 return new StoreInst(Combine, Index, CI);
2230 });
David Neto22f144c2017-06-12 14:26:21 -04002231}
2232
SJW2c317da2020-03-23 07:39:13 -05002233bool ReplaceOpenCLBuiltinPass::replaceHalfReadImage(Function &F) {
2234 // convert half to float
2235 Module &M = *F.getParent();
2236 return replaceCallsWithValue(F, [&](CallInst *CI) {
2237 SmallVector<Type *, 3> types;
2238 SmallVector<Value *, 3> args;
2239 for (auto i = 0; i < CI->getNumArgOperands(); ++i) {
2240 types.push_back(CI->getArgOperand(i)->getType());
2241 args.push_back(CI->getArgOperand(i));
alan-bakerf7e17cb2020-01-02 07:29:59 -05002242 }
alan-bakerf7e17cb2020-01-02 07:29:59 -05002243
alan-baker5a8c3be2020-09-09 13:44:26 -04002244 auto NewFType =
2245 FunctionType::get(FixedVectorType::get(Type::getFloatTy(M.getContext()),
2246 cast<VectorType>(CI->getType())
2247 ->getElementCount()
2248 .getKnownMinValue()),
2249 types, false);
SJW2c317da2020-03-23 07:39:13 -05002250
SJW61531372020-06-09 07:31:08 -05002251 std::string NewFName =
2252 Builtins::GetMangledFunctionName("read_imagef", NewFType);
SJW2c317da2020-03-23 07:39:13 -05002253
2254 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
2255
2256 auto NewCI = CallInst::Create(NewF, args, "", CI);
2257
2258 // Convert to the half type.
2259 return CastInst::CreateFPCast(NewCI, CI->getType(), "", CI);
2260 });
alan-bakerf7e17cb2020-01-02 07:29:59 -05002261}
2262
SJW2c317da2020-03-23 07:39:13 -05002263bool ReplaceOpenCLBuiltinPass::replaceHalfWriteImage(Function &F) {
2264 // convert half to float
2265 Module &M = *F.getParent();
2266 return replaceCallsWithValue(F, [&](CallInst *CI) {
2267 SmallVector<Type *, 3> types(3);
2268 SmallVector<Value *, 3> args(3);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002269
SJW2c317da2020-03-23 07:39:13 -05002270 // Image
2271 types[0] = CI->getArgOperand(0)->getType();
2272 args[0] = CI->getArgOperand(0);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002273
SJW2c317da2020-03-23 07:39:13 -05002274 // Coord
2275 types[1] = CI->getArgOperand(1)->getType();
2276 args[1] = CI->getArgOperand(1);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002277
SJW2c317da2020-03-23 07:39:13 -05002278 // Data
alan-baker5a8c3be2020-09-09 13:44:26 -04002279 types[2] =
2280 FixedVectorType::get(Type::getFloatTy(M.getContext()),
2281 cast<VectorType>(CI->getArgOperand(2)->getType())
2282 ->getElementCount()
2283 .getKnownMinValue());
alan-bakerf7e17cb2020-01-02 07:29:59 -05002284
SJW2c317da2020-03-23 07:39:13 -05002285 auto NewFType =
2286 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002287
SJW61531372020-06-09 07:31:08 -05002288 std::string NewFName =
2289 Builtins::GetMangledFunctionName("write_imagef", NewFType);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002290
SJW2c317da2020-03-23 07:39:13 -05002291 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002292
SJW2c317da2020-03-23 07:39:13 -05002293 // Convert data to the float type.
2294 auto Cast = CastInst::CreateFPCast(CI->getArgOperand(2), types[2], "", CI);
2295 args[2] = Cast;
alan-bakerf7e17cb2020-01-02 07:29:59 -05002296
SJW2c317da2020-03-23 07:39:13 -05002297 return CallInst::Create(NewF, args, "", CI);
2298 });
alan-bakerf7e17cb2020-01-02 07:29:59 -05002299}
2300
SJW2c317da2020-03-23 07:39:13 -05002301bool ReplaceOpenCLBuiltinPass::replaceSampledReadImageWithIntCoords(
2302 Function &F) {
2303 // convert read_image with int coords to float coords
2304 Module &M = *F.getParent();
2305 return replaceCallsWithValue(F, [&](CallInst *CI) {
2306 // The image.
2307 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002308
SJW2c317da2020-03-23 07:39:13 -05002309 // The sampler.
2310 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002311
SJW2c317da2020-03-23 07:39:13 -05002312 // The coordinate (integer type that we can't handle).
2313 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002314
SJW2c317da2020-03-23 07:39:13 -05002315 uint32_t dim = clspv::ImageDimensionality(Arg0->getType());
2316 uint32_t components =
2317 dim + (clspv::IsArrayImageType(Arg0->getType()) ? 1 : 0);
2318 Type *float_ty = nullptr;
2319 if (components == 1) {
2320 float_ty = Type::getFloatTy(M.getContext());
2321 } else {
alan-baker5a8c3be2020-09-09 13:44:26 -04002322 float_ty = FixedVectorType::get(Type::getFloatTy(M.getContext()),
2323 cast<VectorType>(Arg2->getType())
2324 ->getElementCount()
2325 .getKnownMinValue());
David Neto22f144c2017-06-12 14:26:21 -04002326 }
David Neto22f144c2017-06-12 14:26:21 -04002327
SJW2c317da2020-03-23 07:39:13 -05002328 auto NewFType = FunctionType::get(
2329 CI->getType(), {Arg0->getType(), Arg1->getType(), float_ty}, false);
2330
2331 std::string NewFName = F.getName().str();
2332 NewFName[NewFName.length() - 1] = 'f';
2333
2334 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
2335
2336 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, float_ty, "", CI);
2337
2338 return CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2339 });
David Neto22f144c2017-06-12 14:26:21 -04002340}
2341
SJW2c317da2020-03-23 07:39:13 -05002342bool ReplaceOpenCLBuiltinPass::replaceAtomics(Function &F, spv::Op Op) {
2343 return replaceCallsWithValue(F, [&](CallInst *CI) {
2344 auto IntTy = Type::getInt32Ty(F.getContext());
David Neto22f144c2017-06-12 14:26:21 -04002345
SJW2c317da2020-03-23 07:39:13 -05002346 // We need to map the OpenCL constants to the SPIR-V equivalents.
2347 const auto ConstantScopeDevice = ConstantInt::get(IntTy, spv::ScopeDevice);
2348 const auto ConstantMemorySemantics = ConstantInt::get(
2349 IntTy, spv::MemorySemanticsUniformMemoryMask |
2350 spv::MemorySemanticsSequentiallyConsistentMask);
David Neto22f144c2017-06-12 14:26:21 -04002351
SJW2c317da2020-03-23 07:39:13 -05002352 SmallVector<Value *, 5> Params;
David Neto22f144c2017-06-12 14:26:21 -04002353
SJW2c317da2020-03-23 07:39:13 -05002354 // The pointer.
2355 Params.push_back(CI->getArgOperand(0));
David Neto22f144c2017-06-12 14:26:21 -04002356
SJW2c317da2020-03-23 07:39:13 -05002357 // The memory scope.
2358 Params.push_back(ConstantScopeDevice);
David Neto22f144c2017-06-12 14:26:21 -04002359
SJW2c317da2020-03-23 07:39:13 -05002360 // The memory semantics.
2361 Params.push_back(ConstantMemorySemantics);
David Neto22f144c2017-06-12 14:26:21 -04002362
SJW2c317da2020-03-23 07:39:13 -05002363 if (2 < CI->getNumArgOperands()) {
2364 // The unequal memory semantics.
2365 Params.push_back(ConstantMemorySemantics);
David Neto22f144c2017-06-12 14:26:21 -04002366
SJW2c317da2020-03-23 07:39:13 -05002367 // The value.
2368 Params.push_back(CI->getArgOperand(2));
David Neto22f144c2017-06-12 14:26:21 -04002369
SJW2c317da2020-03-23 07:39:13 -05002370 // The comparator.
2371 Params.push_back(CI->getArgOperand(1));
2372 } else if (1 < CI->getNumArgOperands()) {
2373 // The value.
2374 Params.push_back(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -04002375 }
David Neto22f144c2017-06-12 14:26:21 -04002376
SJW2c317da2020-03-23 07:39:13 -05002377 return clspv::InsertSPIRVOp(CI, Op, {}, CI->getType(), Params);
2378 });
David Neto22f144c2017-06-12 14:26:21 -04002379}
2380
SJW2c317da2020-03-23 07:39:13 -05002381bool ReplaceOpenCLBuiltinPass::replaceAtomics(Function &F,
2382 llvm::AtomicRMWInst::BinOp Op) {
2383 return replaceCallsWithValue(F, [&](CallInst *CI) {
alan-bakerd0eb9052020-07-07 13:12:01 -04002384 auto align = F.getParent()->getDataLayout().getABITypeAlign(
2385 CI->getArgOperand(1)->getType());
SJW2c317da2020-03-23 07:39:13 -05002386 return new AtomicRMWInst(Op, CI->getArgOperand(0), CI->getArgOperand(1),
alan-bakerd0eb9052020-07-07 13:12:01 -04002387 align, AtomicOrdering::SequentiallyConsistent,
SJW2c317da2020-03-23 07:39:13 -05002388 SyncScope::System, CI);
2389 });
2390}
David Neto22f144c2017-06-12 14:26:21 -04002391
SJW2c317da2020-03-23 07:39:13 -05002392bool ReplaceOpenCLBuiltinPass::replaceCross(Function &F) {
2393 Module &M = *F.getParent();
2394 return replaceCallsWithValue(F, [&](CallInst *CI) {
David Neto22f144c2017-06-12 14:26:21 -04002395 auto IntTy = Type::getInt32Ty(M.getContext());
2396 auto FloatTy = Type::getFloatTy(M.getContext());
2397
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002398 Constant *DownShuffleMask[3] = {ConstantInt::get(IntTy, 0),
2399 ConstantInt::get(IntTy, 1),
2400 ConstantInt::get(IntTy, 2)};
David Neto22f144c2017-06-12 14:26:21 -04002401
2402 Constant *UpShuffleMask[4] = {
2403 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2404 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2405
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002406 Constant *FloatVec[3] = {ConstantFP::get(FloatTy, 0.0f),
2407 UndefValue::get(FloatTy),
2408 UndefValue::get(FloatTy)};
David Neto22f144c2017-06-12 14:26:21 -04002409
Kévin Petite8edce32019-04-10 14:23:32 +01002410 auto Vec4Ty = CI->getArgOperand(0)->getType();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002411 auto Arg0 =
2412 new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty),
2413 ConstantVector::get(DownShuffleMask), "", CI);
2414 auto Arg1 =
2415 new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty),
2416 ConstantVector::get(DownShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002417 auto Vec3Ty = Arg0->getType();
David Neto22f144c2017-06-12 14:26:21 -04002418
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002419 auto NewFType = FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
SJW61531372020-06-09 07:31:08 -05002420 auto NewFName = Builtins::GetMangledFunctionName("cross", NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002421
SJW61531372020-06-09 07:31:08 -05002422 auto Cross3Func = M.getOrInsertFunction(NewFName, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002423
Kévin Petite8edce32019-04-10 14:23:32 +01002424 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002425
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002426 return new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec),
2427 ConstantVector::get(UpShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002428 });
David Neto22f144c2017-06-12 14:26:21 -04002429}
David Neto62653202017-10-16 19:05:18 -04002430
SJW2c317da2020-03-23 07:39:13 -05002431bool ReplaceOpenCLBuiltinPass::replaceFract(Function &F, int vec_size) {
David Neto62653202017-10-16 19:05:18 -04002432 // OpenCL's float result = fract(float x, float* ptr)
2433 //
2434 // In the LLVM domain:
2435 //
2436 // %floor_result = call spir_func float @floor(float %x)
2437 // store float %floor_result, float * %ptr
2438 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2439 // %result = call spir_func float
2440 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2441 //
2442 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2443 // and clspv.fract occur in the SPIR-V generator pass:
2444 //
2445 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2446 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2447 // ...
2448 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2449 // OpStore %ptr %floor_result
2450 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2451 // %fract_result = OpExtInst %float
Marco Antognini55d51862020-07-21 17:50:07 +01002452 // %glsl_ext Nmin %fract_intermediate %just_under_1
David Neto62653202017-10-16 19:05:18 -04002453
David Neto62653202017-10-16 19:05:18 -04002454 using std::string;
2455
2456 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2457 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
David Neto62653202017-10-16 19:05:18 -04002458
SJW2c317da2020-03-23 07:39:13 -05002459 Module &M = *F.getParent();
2460 return replaceCallsWithValue(F, [&](CallInst *CI) {
David Neto62653202017-10-16 19:05:18 -04002461
SJW2c317da2020-03-23 07:39:13 -05002462 // This is either float or a float vector. All the float-like
2463 // types are this type.
2464 auto result_ty = F.getReturnType();
2465
SJW61531372020-06-09 07:31:08 -05002466 std::string fmin_name = Builtins::GetMangledFunctionName("fmin", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002467 Function *fmin_fn = M.getFunction(fmin_name);
2468 if (!fmin_fn) {
2469 // Make the fmin function.
2470 FunctionType *fn_ty =
2471 FunctionType::get(result_ty, {result_ty, result_ty}, false);
2472 fmin_fn =
2473 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
2474 fmin_fn->addFnAttr(Attribute::ReadNone);
2475 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2476 }
2477
SJW61531372020-06-09 07:31:08 -05002478 std::string floor_name =
2479 Builtins::GetMangledFunctionName("floor", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002480 Function *floor_fn = M.getFunction(floor_name);
2481 if (!floor_fn) {
2482 // Make the floor function.
2483 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2484 floor_fn =
2485 cast<Function>(M.getOrInsertFunction(floor_name, fn_ty).getCallee());
2486 floor_fn->addFnAttr(Attribute::ReadNone);
2487 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2488 }
2489
SJW61531372020-06-09 07:31:08 -05002490 std::string clspv_fract_name =
2491 Builtins::GetMangledFunctionName("clspv.fract", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002492 Function *clspv_fract_fn = M.getFunction(clspv_fract_name);
2493 if (!clspv_fract_fn) {
2494 // Make the clspv_fract function.
2495 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2496 clspv_fract_fn = cast<Function>(
2497 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
2498 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2499 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2500 }
2501
2502 // Number of significant significand bits, whether represented or not.
2503 unsigned num_significand_bits;
2504 switch (result_ty->getScalarType()->getTypeID()) {
2505 case Type::HalfTyID:
2506 num_significand_bits = 11;
2507 break;
2508 case Type::FloatTyID:
2509 num_significand_bits = 24;
2510 break;
2511 case Type::DoubleTyID:
2512 num_significand_bits = 53;
2513 break;
2514 default:
2515 llvm_unreachable("Unhandled float type when processing fract builtin");
2516 break;
2517 }
2518 // Beware that the disassembler displays this value as
2519 // OpConstant %float 1
2520 // which is not quite right.
2521 const double kJustUnderOneScalar =
2522 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2523
2524 Constant *just_under_one =
2525 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2526 if (result_ty->isVectorTy()) {
2527 just_under_one = ConstantVector::getSplat(
alan-baker931253b2020-08-20 17:15:38 -04002528 cast<VectorType>(result_ty)->getElementCount(), just_under_one);
SJW2c317da2020-03-23 07:39:13 -05002529 }
2530
2531 IRBuilder<> Builder(CI);
2532
2533 auto arg = CI->getArgOperand(0);
2534 auto ptr = CI->getArgOperand(1);
2535
2536 // Compute floor result and store it.
2537 auto floor = Builder.CreateCall(floor_fn, {arg});
2538 Builder.CreateStore(floor, ptr);
2539
2540 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2541 auto fract_result =
2542 Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2543
2544 return fract_result;
2545 });
David Neto62653202017-10-16 19:05:18 -04002546}
alan-bakera52b7312020-10-26 08:58:51 -04002547
Kévin Petit8576f682020-11-02 14:51:32 +00002548bool ReplaceOpenCLBuiltinPass::replaceHadd(Function &F, bool is_signed,
alan-bakerb6da5132020-10-29 15:59:06 -04002549 Instruction::BinaryOps join_opcode) {
Kévin Petit8576f682020-11-02 14:51:32 +00002550 return replaceCallsWithValue(F, [is_signed, join_opcode](CallInst *Call) {
alan-bakerb6da5132020-10-29 15:59:06 -04002551 // a_shr = a >> 1
2552 // b_shr = b >> 1
2553 // add1 = a_shr + b_shr
2554 // join = a |join_opcode| b
2555 // and = join & 1
2556 // add = add1 + and
2557 const auto a = Call->getArgOperand(0);
2558 const auto b = Call->getArgOperand(1);
2559 IRBuilder<> builder(Call);
Kévin Petit8576f682020-11-02 14:51:32 +00002560 Value *a_shift, *b_shift;
2561 if (is_signed) {
2562 a_shift = builder.CreateAShr(a, 1);
2563 b_shift = builder.CreateAShr(b, 1);
2564 } else {
2565 a_shift = builder.CreateLShr(a, 1);
2566 b_shift = builder.CreateLShr(b, 1);
2567 }
alan-bakerb6da5132020-10-29 15:59:06 -04002568 auto add = builder.CreateAdd(a_shift, b_shift);
2569 auto join = BinaryOperator::Create(join_opcode, a, b, "", Call);
2570 auto constant_one = ConstantInt::get(a->getType(), 1);
2571 auto and_bit = builder.CreateAnd(join, constant_one);
2572 return builder.CreateAdd(add, and_bit);
2573 });
2574}
2575
alan-baker3f1bf492020-11-05 09:07:36 -05002576bool ReplaceOpenCLBuiltinPass::replaceAddSubSat(Function &F, bool is_signed,
2577 bool is_add) {
2578 return replaceCallsWithValue(F, [&F, this, is_signed,
2579 is_add](CallInst *Call) {
2580 auto ty = Call->getType();
2581 auto a = Call->getArgOperand(0);
2582 auto b = Call->getArgOperand(1);
2583 IRBuilder<> builder(Call);
alan-bakera52b7312020-10-26 08:58:51 -04002584 if (is_signed) {
2585 unsigned bitwidth = ty->getScalarSizeInBits();
2586 if (bitwidth < 32) {
alan-baker3f1bf492020-11-05 09:07:36 -05002587 unsigned extended_width = bitwidth << 1;
2588 Type *extended_ty =
2589 IntegerType::get(Call->getContext(), extended_width);
2590 Constant *min = ConstantInt::get(
alan-bakera52b7312020-10-26 08:58:51 -04002591 Call->getContext(),
alan-baker3f1bf492020-11-05 09:07:36 -05002592 APInt::getSignedMinValue(bitwidth).sext(extended_width));
2593 Constant *max = ConstantInt::get(
alan-bakera52b7312020-10-26 08:58:51 -04002594 Call->getContext(),
alan-baker3f1bf492020-11-05 09:07:36 -05002595 APInt::getSignedMaxValue(bitwidth).sext(extended_width));
alan-bakera52b7312020-10-26 08:58:51 -04002596 // Don't use the type in GetMangledFunctionName to ensure we get
2597 // signed parameters.
2598 std::string sclamp_name = Builtins::GetMangledFunctionName("clamp");
alan-bakera52b7312020-10-26 08:58:51 -04002599 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
alan-baker3f1bf492020-11-05 09:07:36 -05002600 extended_ty = VectorType::get(extended_ty, vec_ty->getElementCount());
2601 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2602 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2603 unsigned vec_width = vec_ty->getElementCount().getKnownMinValue();
2604 if (extended_width == 32) {
alan-bakera52b7312020-10-26 08:58:51 -04002605 sclamp_name += "Dv" + std::to_string(vec_width) + "_iS_S_";
alan-bakera52b7312020-10-26 08:58:51 -04002606 } else {
2607 sclamp_name += "Dv" + std::to_string(vec_width) + "_sS_S_";
2608 }
alan-baker3f1bf492020-11-05 09:07:36 -05002609 } else {
2610 if (extended_width == 32) {
2611 sclamp_name += "iii";
2612 } else {
2613 sclamp_name += "sss";
2614 }
alan-bakera52b7312020-10-26 08:58:51 -04002615 }
alan-baker3f1bf492020-11-05 09:07:36 -05002616
2617 auto sext_a = builder.CreateSExt(a, extended_ty);
2618 auto sext_b = builder.CreateSExt(b, extended_ty);
2619 Value *op = nullptr;
2620 // Extended operations won't wrap.
2621 if (is_add)
2622 op = builder.CreateAdd(sext_a, sext_b, "", true, true);
2623 else
2624 op = builder.CreateSub(sext_a, sext_b, "", true, true);
2625 auto clamp_ty = FunctionType::get(
2626 extended_ty, {extended_ty, extended_ty, extended_ty}, false);
2627 auto callee = F.getParent()->getOrInsertFunction(sclamp_name, clamp_ty);
2628 auto clamp = builder.CreateCall(callee, {op, min, max});
2629 return builder.CreateTrunc(clamp, ty);
alan-bakera52b7312020-10-26 08:58:51 -04002630 } else {
alan-baker3f1bf492020-11-05 09:07:36 -05002631 // Add:
2632 // c = a + b
alan-bakera52b7312020-10-26 08:58:51 -04002633 // if (b < 0)
2634 // c = c > a ? min : c;
2635 // else
alan-baker3f1bf492020-11-05 09:07:36 -05002636 // c = c < a ? max : c;
alan-bakera52b7312020-10-26 08:58:51 -04002637 //
alan-baker3f1bf492020-11-05 09:07:36 -05002638 // Sub:
2639 // c = a - b;
2640 // if (b < 0)
2641 // c = c < a ? max : c;
2642 // else
2643 // c = c > a ? min : c;
2644 Constant *min = ConstantInt::get(Call->getContext(),
2645 APInt::getSignedMinValue(bitwidth));
2646 Constant *max = ConstantInt::get(Call->getContext(),
2647 APInt::getSignedMaxValue(bitwidth));
alan-bakera52b7312020-10-26 08:58:51 -04002648 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2649 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2650 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2651 }
alan-baker3f1bf492020-11-05 09:07:36 -05002652 Value *op = nullptr;
2653 if (is_add) {
2654 op = builder.CreateAdd(a, b);
2655 } else {
2656 op = builder.CreateSub(a, b);
2657 }
2658 auto b_lt_0 = builder.CreateICmpSLT(b, Constant::getNullValue(ty));
2659 auto op_gt_a = builder.CreateICmpSGT(op, a);
2660 auto op_lt_a = builder.CreateICmpSLT(op, a);
2661 auto neg_cmp = is_add ? op_gt_a : op_lt_a;
2662 auto pos_cmp = is_add ? op_lt_a : op_gt_a;
2663 auto neg_value = is_add ? min : max;
2664 auto pos_value = is_add ? max : min;
2665 auto neg_clamp = builder.CreateSelect(neg_cmp, neg_value, op);
2666 auto pos_clamp = builder.CreateSelect(pos_cmp, pos_value, op);
2667 return builder.CreateSelect(b_lt_0, neg_clamp, pos_clamp);
alan-bakera52b7312020-10-26 08:58:51 -04002668 }
2669 } else {
alan-baker3f1bf492020-11-05 09:07:36 -05002670 // Replace with OpIAddCarry/OpISubBorrow and clamp to max/0 on a
2671 // carr/borrow.
2672 spv::Op op = is_add ? spv::OpIAddCarry : spv::OpISubBorrow;
2673 auto clamp_value =
2674 is_add ? Constant::getAllOnesValue(ty) : Constant::getNullValue(ty);
2675 auto struct_ty = GetPairStruct(ty);
2676 auto call =
2677 InsertSPIRVOp(Call, op, {Attribute::ReadNone}, struct_ty, {a, b});
2678 auto add_sub = builder.CreateExtractValue(call, {0});
2679 auto carry_borrow = builder.CreateExtractValue(call, {1});
2680 auto cmp = builder.CreateICmpEQ(carry_borrow, Constant::getNullValue(ty));
2681 return builder.CreateSelect(cmp, add_sub, clamp_value);
alan-bakera52b7312020-10-26 08:58:51 -04002682 }
alan-bakera52b7312020-10-26 08:58:51 -04002683 });
2684}
alan-baker4986eff2020-10-29 13:38:00 -04002685
2686bool ReplaceOpenCLBuiltinPass::replaceAtomicLoad(Function &F) {
2687 return replaceCallsWithValue(F, [](CallInst *Call) {
2688 auto pointer = Call->getArgOperand(0);
2689 // Clang emits an address space cast to the generic address space. Skip the
2690 // cast and use the input directly.
2691 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2692 pointer = cast->getPointerOperand();
2693 }
2694 Value *order_arg =
2695 Call->getNumArgOperands() > 1 ? Call->getArgOperand(1) : nullptr;
2696 Value *scope_arg =
2697 Call->getNumArgOperands() > 2 ? Call->getArgOperand(2) : nullptr;
2698 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2699 clspv::AddressSpace::Global;
2700 auto order = MemoryOrderSemantics(order_arg, is_global, Call,
2701 spv::MemorySemanticsAcquireMask);
2702 auto scope = MemoryScope(scope_arg, is_global, Call);
2703 return InsertSPIRVOp(Call, spv::OpAtomicLoad, {Attribute::Convergent},
2704 Call->getType(), {pointer, scope, order});
2705 });
2706}
2707
2708bool ReplaceOpenCLBuiltinPass::replaceExplicitAtomics(
2709 Function &F, spv::Op Op, spv::MemorySemanticsMask semantics) {
2710 return replaceCallsWithValue(F, [Op, semantics](CallInst *Call) {
2711 auto pointer = Call->getArgOperand(0);
2712 // Clang emits an address space cast to the generic address space. Skip the
2713 // cast and use the input directly.
2714 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2715 pointer = cast->getPointerOperand();
2716 }
2717 Value *value = Call->getArgOperand(1);
2718 Value *order_arg =
2719 Call->getNumArgOperands() > 2 ? Call->getArgOperand(2) : nullptr;
2720 Value *scope_arg =
2721 Call->getNumArgOperands() > 3 ? Call->getArgOperand(3) : nullptr;
2722 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2723 clspv::AddressSpace::Global;
2724 auto scope = MemoryScope(scope_arg, is_global, Call);
2725 auto order = MemoryOrderSemantics(order_arg, is_global, Call, semantics);
2726 return InsertSPIRVOp(Call, Op, {Attribute::Convergent}, Call->getType(),
2727 {pointer, scope, order, value});
2728 });
2729}
2730
2731bool ReplaceOpenCLBuiltinPass::replaceAtomicCompareExchange(Function &F) {
2732 return replaceCallsWithValue(F, [](CallInst *Call) {
2733 auto pointer = Call->getArgOperand(0);
2734 // Clang emits an address space cast to the generic address space. Skip the
2735 // cast and use the input directly.
2736 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2737 pointer = cast->getPointerOperand();
2738 }
2739 auto expected = Call->getArgOperand(1);
2740 if (auto cast = dyn_cast<AddrSpaceCastOperator>(expected)) {
2741 expected = cast->getPointerOperand();
2742 }
2743 auto value = Call->getArgOperand(2);
2744 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2745 clspv::AddressSpace::Global;
2746 Value *success_arg =
2747 Call->getNumArgOperands() > 3 ? Call->getArgOperand(3) : nullptr;
2748 Value *failure_arg =
2749 Call->getNumArgOperands() > 4 ? Call->getArgOperand(4) : nullptr;
2750 Value *scope_arg =
2751 Call->getNumArgOperands() > 5 ? Call->getArgOperand(5) : nullptr;
2752 auto scope = MemoryScope(scope_arg, is_global, Call);
2753 auto success = MemoryOrderSemantics(success_arg, is_global, Call,
2754 spv::MemorySemanticsAcquireReleaseMask);
2755 auto failure = MemoryOrderSemantics(failure_arg, is_global, Call,
2756 spv::MemorySemanticsAcquireMask);
2757
2758 // If the value pointed to by |expected| equals the value pointed to by
2759 // |pointer|, |value| is written into |pointer|, otherwise the value in
2760 // |pointer| is written into |expected|. In order to avoid extra stores,
2761 // the basic block with the original atomic is split and the store is
2762 // performed in the |then| block. The condition is the inversion of the
2763 // comparison result.
2764 IRBuilder<> builder(Call);
2765 auto load = builder.CreateLoad(expected);
2766 auto cmp_xchg = InsertSPIRVOp(
2767 Call, spv::OpAtomicCompareExchange, {Attribute::Convergent},
2768 value->getType(), {pointer, scope, success, failure, value, load});
2769 auto cmp = builder.CreateICmpEQ(cmp_xchg, load);
2770 auto not_cmp = builder.CreateNot(cmp);
2771 auto then_branch = SplitBlockAndInsertIfThen(not_cmp, Call, false);
2772 builder.SetInsertPoint(then_branch);
2773 builder.CreateStore(cmp_xchg, expected);
2774 return cmp;
2775 });
2776}
alan-bakercc2bafb2020-11-02 08:30:18 -05002777
alan-baker2cecaa72020-11-05 14:05:20 -05002778bool ReplaceOpenCLBuiltinPass::replaceCountZeroes(Function &F, bool leading) {
alan-bakercc2bafb2020-11-02 08:30:18 -05002779 if (!isa<IntegerType>(F.getReturnType()->getScalarType()))
2780 return false;
2781
2782 auto bitwidth = F.getReturnType()->getScalarSizeInBits();
2783 if (bitwidth == 32 || bitwidth > 64)
2784 return false;
2785
alan-baker2cecaa72020-11-05 14:05:20 -05002786 return replaceCallsWithValue(F, [&F, bitwidth, leading](CallInst *Call) {
alan-bakercc2bafb2020-11-02 08:30:18 -05002787 auto in = Call->getArgOperand(0);
2788 IRBuilder<> builder(Call);
2789 auto int32_ty = builder.getInt32Ty();
2790 Type *ty = int32_ty;
alan-baker2cecaa72020-11-05 14:05:20 -05002791 Constant *c32 = builder.getInt32(32);
alan-bakercc2bafb2020-11-02 08:30:18 -05002792 if (auto vec_ty = dyn_cast<VectorType>(Call->getType())) {
2793 ty = VectorType::get(ty, vec_ty->getElementCount());
alan-baker2cecaa72020-11-05 14:05:20 -05002794 c32 = ConstantVector::getSplat(vec_ty->getElementCount(), c32);
alan-bakercc2bafb2020-11-02 08:30:18 -05002795 }
alan-baker2cecaa72020-11-05 14:05:20 -05002796 auto func_32bit_ty = FunctionType::get(ty, {ty}, false);
2797 std::string func_32bit_name =
2798 Builtins::GetMangledFunctionName((leading ? "clz" : "ctz"), ty);
2799 auto func_32bit =
2800 F.getParent()->getOrInsertFunction(func_32bit_name, func_32bit_ty);
alan-bakercc2bafb2020-11-02 08:30:18 -05002801 if (bitwidth < 32) {
alan-baker2cecaa72020-11-05 14:05:20 -05002802 // Extend the input to 32-bits and perform a clz/ctz.
alan-bakercc2bafb2020-11-02 08:30:18 -05002803 auto zext = builder.CreateZExt(in, ty);
alan-baker2cecaa72020-11-05 14:05:20 -05002804 Value *call_input = zext;
2805 if (!leading) {
2806 // Or the extended input value with a constant that caps the max to the
2807 // right bitwidth (e.g. 256 for i8 and 65536 for i16).
2808 Constant *mask = builder.getInt32(1 << bitwidth);
2809 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2810 mask = ConstantVector::getSplat(vec_ty->getElementCount(), mask);
2811 }
2812 call_input = builder.CreateOr(zext, mask);
alan-bakercc2bafb2020-11-02 08:30:18 -05002813 }
alan-baker2cecaa72020-11-05 14:05:20 -05002814 auto call = builder.CreateCall(func_32bit, {call_input});
2815 Value *tmp = call;
2816 if (leading) {
2817 // Clz is implemented as 31 - FindUMsb(|zext|), so adjust the result
2818 // the right bitwidth.
2819 Constant *sub_const = builder.getInt32(32 - bitwidth);
2820 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2821 sub_const =
2822 ConstantVector::getSplat(vec_ty->getElementCount(), sub_const);
2823 }
2824 tmp = builder.CreateSub(call, sub_const);
2825 }
2826 // Truncate the intermediate result to the right size.
2827 return builder.CreateTrunc(tmp, Call->getType());
alan-bakercc2bafb2020-11-02 08:30:18 -05002828 } else {
alan-baker2cecaa72020-11-05 14:05:20 -05002829 // Perform a 32-bit version of clz/ctz on each half of the 64-bit input.
alan-bakercc2bafb2020-11-02 08:30:18 -05002830 auto lshr = builder.CreateLShr(in, 32);
2831 auto top_bits = builder.CreateTrunc(lshr, ty);
2832 auto bot_bits = builder.CreateTrunc(in, ty);
alan-baker2cecaa72020-11-05 14:05:20 -05002833 auto top_func = builder.CreateCall(func_32bit, {top_bits});
2834 auto bot_func = builder.CreateCall(func_32bit, {bot_bits});
2835 Value *tmp = nullptr;
2836 if (leading) {
2837 // For clz, if clz(top) is 32, return 32 + clz(bot).
2838 auto cmp = builder.CreateICmpEQ(top_func, c32);
2839 auto adjust = builder.CreateAdd(bot_func, c32);
2840 tmp = builder.CreateSelect(cmp, adjust, top_func);
2841 } else {
2842 // For ctz, if clz(bot) is 32, return 32 + ctz(top)
2843 auto bot_cmp = builder.CreateICmpEQ(bot_func, c32);
2844 auto adjust = builder.CreateAdd(top_func, c32);
2845 tmp = builder.CreateSelect(bot_cmp, adjust, bot_func);
alan-bakercc2bafb2020-11-02 08:30:18 -05002846 }
alan-baker2cecaa72020-11-05 14:05:20 -05002847 // Extend the intermediate result to the correct size.
2848 return builder.CreateZExt(tmp, Call->getType());
alan-bakercc2bafb2020-11-02 08:30:18 -05002849 }
2850 });
2851}
alan-baker6b9d1ee2020-11-03 23:11:32 -05002852
2853bool ReplaceOpenCLBuiltinPass::replaceMadSat(Function &F, bool is_signed) {
2854 return replaceCallsWithValue(F, [&F, is_signed, this](CallInst *Call) {
2855 const auto ty = Call->getType();
2856 const auto a = Call->getArgOperand(0);
2857 const auto b = Call->getArgOperand(1);
2858 const auto c = Call->getArgOperand(2);
2859 IRBuilder<> builder(Call);
2860 if (is_signed) {
2861 unsigned bitwidth = Call->getType()->getScalarSizeInBits();
2862 if (bitwidth < 32) {
2863 // mul = sext(a) * sext(b)
2864 // add = mul + sext(c)
2865 // res = clamp(add, MIN, MAX)
2866 unsigned extended_width = bitwidth << 1;
2867 Type *extended_ty = IntegerType::get(F.getContext(), extended_width);
2868 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2869 extended_ty = VectorType::get(extended_ty, vec_ty->getElementCount());
2870 }
2871 auto a_sext = builder.CreateSExt(a, extended_ty);
2872 auto b_sext = builder.CreateSExt(b, extended_ty);
2873 auto c_sext = builder.CreateSExt(c, extended_ty);
2874 // Extended the size so no overflows occur.
2875 auto mul = builder.CreateMul(a_sext, b_sext, "", true, true);
2876 auto add = builder.CreateAdd(mul, c_sext, "", true, true);
2877 auto func_ty = FunctionType::get(
2878 extended_ty, {extended_ty, extended_ty, extended_ty}, false);
2879 // Don't use function type because we need signed parameters.
2880 std::string clamp_name = Builtins::GetMangledFunctionName("clamp");
2881 // The clamp values are the signed min and max of the original bitwidth
2882 // sign extended to the extended bitwidth.
2883 Constant *min = ConstantInt::get(
2884 Call->getContext(),
2885 APInt::getSignedMinValue(bitwidth).sext(extended_width));
2886 Constant *max = ConstantInt::get(
2887 Call->getContext(),
2888 APInt::getSignedMaxValue(bitwidth).sext(extended_width));
2889 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2890 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2891 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2892 unsigned vec_width = vec_ty->getElementCount().getKnownMinValue();
2893 if (extended_width == 32)
2894 clamp_name += "Dv" + std::to_string(vec_width) + "_iS_S_";
2895 else
2896 clamp_name += "Dv" + std::to_string(vec_width) + "_sS_S_";
2897 } else {
2898 if (extended_width == 32)
2899 clamp_name += "iii";
2900 else
2901 clamp_name += "sss";
2902 }
2903 auto callee = F.getParent()->getOrInsertFunction(clamp_name, func_ty);
2904 auto clamp = builder.CreateCall(callee, {add, min, max});
2905 return builder.CreateTrunc(clamp, ty);
2906 } else {
2907 auto struct_ty = GetPairStruct(ty);
2908 // Compute
2909 // {hi, lo} = smul_extended(a, b)
2910 // add = lo + c
2911 auto mul_ext = InsertSPIRVOp(Call, spv::OpSMulExtended,
2912 {Attribute::ReadNone}, struct_ty, {a, b});
2913 auto mul_lo = builder.CreateExtractValue(mul_ext, {0});
2914 auto mul_hi = builder.CreateExtractValue(mul_ext, {1});
2915 auto add = builder.CreateAdd(mul_lo, c);
2916
2917 // Constants for use in the calculation.
2918 Constant *min = ConstantInt::get(Call->getContext(),
2919 APInt::getSignedMinValue(bitwidth));
2920 Constant *max = ConstantInt::get(Call->getContext(),
2921 APInt::getSignedMaxValue(bitwidth));
2922 Constant *max_plus_1 = ConstantInt::get(
2923 Call->getContext(),
2924 APInt::getSignedMaxValue(bitwidth) + APInt(bitwidth, 1));
2925 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2926 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2927 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2928 max_plus_1 =
2929 ConstantVector::getSplat(vec_ty->getElementCount(), max_plus_1);
2930 }
2931
2932 auto a_xor_b = builder.CreateXor(a, b);
2933 auto same_sign =
2934 builder.CreateICmpSGT(a_xor_b, Constant::getAllOnesValue(ty));
2935 auto different_sign = builder.CreateNot(same_sign);
2936 auto hi_eq_0 = builder.CreateICmpEQ(mul_hi, Constant::getNullValue(ty));
2937 auto hi_ne_0 = builder.CreateNot(hi_eq_0);
2938 auto lo_ge_max = builder.CreateICmpUGE(mul_lo, max);
2939 auto c_gt_0 = builder.CreateICmpSGT(c, Constant::getNullValue(ty));
2940 auto c_lt_0 = builder.CreateICmpSLT(c, Constant::getNullValue(ty));
2941 auto add_gt_max = builder.CreateICmpUGT(add, max);
2942 auto hi_eq_m1 =
2943 builder.CreateICmpEQ(mul_hi, Constant::getAllOnesValue(ty));
2944 auto hi_ne_m1 = builder.CreateNot(hi_eq_m1);
2945 auto lo_le_max_plus_1 = builder.CreateICmpULE(mul_lo, max_plus_1);
2946 auto max_sub_lo = builder.CreateSub(max, mul_lo);
2947 auto c_lt_max_sub_lo = builder.CreateICmpULT(c, max_sub_lo);
2948
2949 // Equivalent to:
2950 // if (((x < 0) == (y < 0)) && mul_hi != 0)
2951 // return MAX
2952 // if (mul_hi == 0 && mul_lo >= MAX && (z > 0 || add > MAX))
2953 // return MAX
2954 // if (((x < 0) != (y < 0)) && mul_hi != -1)
2955 // return MIN
2956 // if (hi == -1 && mul_lo <= (MAX + 1) && (z < 0 || z < (MAX - mul_lo))
2957 // return MIN
2958 // return add
2959 auto max_clamp_1 = builder.CreateAnd(same_sign, hi_ne_0);
2960 auto max_clamp_2 = builder.CreateOr(c_gt_0, add_gt_max);
2961 auto tmp = builder.CreateAnd(hi_eq_0, lo_ge_max);
2962 max_clamp_2 = builder.CreateAnd(tmp, max_clamp_2);
2963 auto max_clamp = builder.CreateOr(max_clamp_1, max_clamp_2);
2964 auto min_clamp_1 = builder.CreateAnd(different_sign, hi_ne_m1);
2965 auto min_clamp_2 = builder.CreateOr(c_lt_0, c_lt_max_sub_lo);
2966 tmp = builder.CreateAnd(hi_eq_m1, lo_le_max_plus_1);
2967 min_clamp_2 = builder.CreateAnd(tmp, min_clamp_2);
2968 auto min_clamp = builder.CreateOr(min_clamp_1, min_clamp_2);
2969 auto sel = builder.CreateSelect(min_clamp, min, add);
2970 return builder.CreateSelect(max_clamp, max, sel);
2971 }
2972 } else {
2973 // {lo, hi} = mul_extended(a, b)
2974 // {add, carry} = add_carry(lo, c)
2975 // cmp = (mul_hi | carry) == 0
2976 // mad_sat = cmp ? add : MAX
2977 auto struct_ty = GetPairStruct(ty);
2978 auto mul_ext = InsertSPIRVOp(Call, spv::OpUMulExtended,
2979 {Attribute::ReadNone}, struct_ty, {a, b});
2980 auto mul_lo = builder.CreateExtractValue(mul_ext, {0});
2981 auto mul_hi = builder.CreateExtractValue(mul_ext, {1});
2982 auto add_carry =
2983 InsertSPIRVOp(Call, spv::OpIAddCarry, {Attribute::ReadNone},
2984 struct_ty, {mul_lo, c});
2985 auto add = builder.CreateExtractValue(add_carry, {0});
2986 auto carry = builder.CreateExtractValue(add_carry, {1});
2987 auto or_value = builder.CreateOr(mul_hi, carry);
2988 auto cmp = builder.CreateICmpEQ(or_value, Constant::getNullValue(ty));
2989 return builder.CreateSelect(cmp, add, Constant::getAllOnesValue(ty));
2990 }
2991 });
2992}
alan-baker15106572020-11-06 15:08:10 -05002993
2994bool ReplaceOpenCLBuiltinPass::replaceOrdered(Function &F, bool is_ordered) {
2995 if (!isa<IntegerType>(F.getReturnType()->getScalarType()))
2996 return false;
2997
2998 if (F.getFunctionType()->getNumParams() != 2)
2999 return false;
3000
3001 if (F.getFunctionType()->getParamType(0) !=
3002 F.getFunctionType()->getParamType(1)) {
3003 return false;
3004 }
3005
3006 switch (F.getFunctionType()->getParamType(0)->getScalarType()->getTypeID()) {
3007 case Type::FloatTyID:
3008 case Type::HalfTyID:
3009 case Type::DoubleTyID:
3010 break;
3011 default:
3012 return false;
3013 }
3014
3015 // Scalar versions all return an int, while vector versions return a vector
3016 // of an equally sized integer types (e.g. short, int or long).
3017 if (isa<VectorType>(F.getReturnType())) {
3018 if (F.getReturnType()->getScalarSizeInBits() !=
3019 F.getFunctionType()->getParamType(0)->getScalarSizeInBits()) {
3020 return false;
3021 }
3022 } else {
3023 if (F.getReturnType()->getScalarSizeInBits() != 32)
3024 return false;
3025 }
3026
3027 return replaceCallsWithValue(F, [is_ordered](CallInst *Call) {
3028 // Replace with a floating point [un]ordered comparison followed by an
3029 // extension.
3030 auto x = Call->getArgOperand(0);
3031 auto y = Call->getArgOperand(1);
3032 IRBuilder<> builder(Call);
3033 Value *tmp = nullptr;
3034 if (is_ordered) {
3035 // This leads to a slight inefficiency in the SPIR-V that is easy for
3036 // drivers to optimize where the SPIR-V for the comparison and the
3037 // extension could be fused to drop the inversion of the OpIsNan.
3038 tmp = builder.CreateFCmpORD(x, y);
3039 } else {
3040 tmp = builder.CreateFCmpUNO(x, y);
3041 }
3042 // OpenCL CTS requires that vector versions use sign extension, but scalar
3043 // versions use zero extension.
3044 if (isa<VectorType>(Call->getType()))
3045 return builder.CreateSExt(tmp, Call->getType());
3046 return builder.CreateZExt(tmp, Call->getType());
3047 });
3048}
alan-baker497920b2020-11-09 16:41:36 -05003049
3050bool ReplaceOpenCLBuiltinPass::replaceIsNormal(Function &F) {
3051 return replaceCallsWithValue(F, [this](CallInst *Call) {
3052 auto ty = Call->getType();
3053 auto x = Call->getArgOperand(0);
3054 unsigned width = x->getType()->getScalarSizeInBits();
3055 Type *int_ty = IntegerType::get(Call->getContext(), width);
3056 uint64_t abs_mask = 0x7fffffff;
3057 uint64_t exp_mask = 0x7f800000;
3058 uint64_t min_mask = 0x00800000;
3059 if (width == 16) {
3060 abs_mask = 0x7fff;
3061 exp_mask = 0x7c00;
3062 min_mask = 0x0400;
3063 } else if (width == 64) {
3064 abs_mask = 0x7fffffffffffffff;
3065 exp_mask = 0x7ff0000000000000;
3066 min_mask = 0x0010000000000000;
3067 }
3068 Constant *abs_const = ConstantInt::get(int_ty, APInt(width, abs_mask));
3069 Constant *exp_const = ConstantInt::get(int_ty, APInt(width, exp_mask));
3070 Constant *min_const = ConstantInt::get(int_ty, APInt(width, min_mask));
3071 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
3072 int_ty = VectorType::get(int_ty, vec_ty->getElementCount());
3073 abs_const =
3074 ConstantVector::getSplat(vec_ty->getElementCount(), abs_const);
3075 exp_const =
3076 ConstantVector::getSplat(vec_ty->getElementCount(), exp_const);
3077 min_const =
3078 ConstantVector::getSplat(vec_ty->getElementCount(), min_const);
3079 }
3080 // Drop the sign bit and then check that the number is between
3081 // (exclusive) the min and max exponent values for the bit width.
3082 IRBuilder<> builder(Call);
3083 auto bitcast = builder.CreateBitCast(x, int_ty);
3084 auto abs = builder.CreateAnd(bitcast, abs_const);
3085 auto lt = builder.CreateICmpULT(abs, exp_const);
3086 auto ge = builder.CreateICmpUGE(abs, min_const);
3087 auto tmp = builder.CreateAnd(lt, ge);
3088 // OpenCL CTS requires that vector versions use sign extension, but scalar
3089 // versions use zero extension.
3090 if (isa<VectorType>(ty))
3091 return builder.CreateSExt(tmp, ty);
3092 return builder.CreateZExt(tmp, ty);
3093 });
3094}
alan-bakere0406e72020-11-10 12:32:04 -05003095
3096bool ReplaceOpenCLBuiltinPass::replaceFDim(Function &F) {
3097 return replaceCallsWithValue(F, [](CallInst *Call) {
3098 const auto x = Call->getArgOperand(0);
3099 const auto y = Call->getArgOperand(1);
3100 IRBuilder<> builder(Call);
3101 auto sub = builder.CreateFSub(x, y);
3102 auto cmp = builder.CreateFCmpUGT(x, y);
3103 return builder.CreateSelect(cmp, sub,
3104 Constant::getNullValue(Call->getType()));
3105 });
3106}