blob: bf0fc36055789daacd1bb4d2a1c421a49085adff [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto62653202017-10-16 19:05:18 -040015#include <math.h>
16#include <string>
17#include <tuple>
18
Kévin Petit9d1a9d12019-03-25 15:23:46 +000019#include "llvm/ADT/StringSwitch.h"
David Neto118188e2018-08-24 11:27:54 -040020#include "llvm/IR/Constants.h"
David Neto118188e2018-08-24 11:27:54 -040021#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040022#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040023#include "llvm/IR/Module.h"
alan-baker4986eff2020-10-29 13:38:00 -040024#include "llvm/IR/Operator.h"
Kévin Petitf5b78a22018-10-25 14:32:17 +000025#include "llvm/IR/ValueSymbolTable.h"
David Neto118188e2018-08-24 11:27:54 -040026#include "llvm/Pass.h"
27#include "llvm/Support/CommandLine.h"
28#include "llvm/Support/raw_ostream.h"
alan-baker4986eff2020-10-29 13:38:00 -040029#include "llvm/Transforms/Utils/BasicBlockUtils.h"
David Neto118188e2018-08-24 11:27:54 -040030#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040031
alan-bakere0902602020-03-23 08:43:40 -040032#include "spirv/unified1/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040033
alan-baker931d18a2019-12-12 08:21:32 -050034#include "clspv/AddressSpace.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035#include "clspv/Option.h"
David Neto482550a2018-03-24 05:21:07 -070036
SJW2c317da2020-03-23 07:39:13 -050037#include "Builtins.h"
alan-baker931d18a2019-12-12 08:21:32 -050038#include "Constants.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040039#include "Passes.h"
40#include "SPIRVOp.h"
alan-bakerf906d2b2019-12-10 11:26:23 -050041#include "Types.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040042
SJW2c317da2020-03-23 07:39:13 -050043using namespace clspv;
David Neto22f144c2017-06-12 14:26:21 -040044using namespace llvm;
45
46#define DEBUG_TYPE "ReplaceOpenCLBuiltin"
47
48namespace {
Kévin Petit8a560882019-03-21 15:24:34 +000049
// Returns floor(log2(v)), i.e. the index of the highest set bit of |v|, and 0
// when |v| is 0. Note that, despite its name, this is not a count of leading
// zero bits.
uint32_t clz(uint32_t v) {
  uint32_t highest_bit = 0;

  // Binary search for the top set bit: if anything is set above the low
  // |width| bits (16, 8, 4, then 2), shift it down and record the distance.
  for (uint32_t width = 16; width >= 2; width >>= 1) {
    const uint32_t low_mask = (1u << width) - 1u;
    if (v > low_mask) {
      v >>= width;
      highest_bit |= width;
    }
  }

  // At most two bits remain; bit 1 supplies the last binary digit.
  highest_bit |= (v >> 1);
  return highest_bit;
}
69
Kévin Petitfdfa92e2019-09-25 14:20:58 +010070Type *getIntOrIntVectorTyForCast(LLVMContext &C, Type *Ty) {
71 Type *IntTy = Type::getIntNTy(C, Ty->getScalarSizeInBits());
James Pricecf53df42020-04-20 14:41:24 -040072 if (auto vec_ty = dyn_cast<VectorType>(Ty)) {
alan-baker5a8c3be2020-09-09 13:44:26 -040073 IntTy = FixedVectorType::get(IntTy,
74 vec_ty->getElementCount().getKnownMinValue());
Kévin Petitfdfa92e2019-09-25 14:20:58 +010075 }
76 return IntTy;
77}
78
alan-baker4986eff2020-10-29 13:38:00 -040079Value *MemoryOrderSemantics(Value *order, bool is_global,
80 Instruction *InsertBefore,
81 spv::MemorySemanticsMask base_semantics) {
82 enum AtomicMemoryOrder : uint32_t {
83 kMemoryOrderRelaxed = 0,
84 kMemoryOrderAcquire = 2,
85 kMemoryOrderRelease = 3,
86 kMemoryOrderAcqRel = 4,
87 kMemoryOrderSeqCst = 5
88 };
89
90 IRBuilder<> builder(InsertBefore);
91
92 // Constants for OpenCL C 2.0 memory_order.
93 const auto relaxed = builder.getInt32(AtomicMemoryOrder::kMemoryOrderRelaxed);
94 const auto acquire = builder.getInt32(AtomicMemoryOrder::kMemoryOrderAcquire);
95 const auto release = builder.getInt32(AtomicMemoryOrder::kMemoryOrderRelease);
96 const auto acq_rel = builder.getInt32(AtomicMemoryOrder::kMemoryOrderAcqRel);
97
98 // Constants for SPIR-V ordering memory semantics.
99 const auto RelaxedSemantics = builder.getInt32(spv::MemorySemanticsMaskNone);
100 const auto AcquireSemantics =
101 builder.getInt32(spv::MemorySemanticsAcquireMask);
102 const auto ReleaseSemantics =
103 builder.getInt32(spv::MemorySemanticsReleaseMask);
104 const auto AcqRelSemantics =
105 builder.getInt32(spv::MemorySemanticsAcquireReleaseMask);
106
107 // Constants for SPIR-V storage class semantics.
108 const auto UniformSemantics =
109 builder.getInt32(spv::MemorySemanticsUniformMemoryMask);
110 const auto WorkgroupSemantics =
111 builder.getInt32(spv::MemorySemanticsWorkgroupMemoryMask);
112
113 // Instead of sequentially consistent, use acquire, release or acquire
114 // release semantics.
115 Value *base_order = nullptr;
116 switch (base_semantics) {
117 case spv::MemorySemanticsAcquireMask:
118 base_order = AcquireSemantics;
119 break;
120 case spv::MemorySemanticsReleaseMask:
121 base_order = ReleaseSemantics;
122 break;
123 default:
124 base_order = AcqRelSemantics;
125 break;
126 }
127
128 Value *storage = is_global ? UniformSemantics : WorkgroupSemantics;
129 if (order == nullptr)
130 return builder.CreateOr({storage, base_order});
131
132 auto is_relaxed = builder.CreateICmpEQ(order, relaxed);
133 auto is_acquire = builder.CreateICmpEQ(order, acquire);
134 auto is_release = builder.CreateICmpEQ(order, release);
135 auto is_acq_rel = builder.CreateICmpEQ(order, acq_rel);
136 auto semantics =
137 builder.CreateSelect(is_relaxed, RelaxedSemantics, base_order);
138 semantics = builder.CreateSelect(is_acquire, AcquireSemantics, semantics);
139 semantics = builder.CreateSelect(is_release, ReleaseSemantics, semantics);
140 semantics = builder.CreateSelect(is_acq_rel, AcqRelSemantics, semantics);
141 return builder.CreateOr({storage, semantics});
142}
143
144Value *MemoryScope(Value *scope, bool is_global, Instruction *InsertBefore) {
145 enum AtomicMemoryScope : uint32_t {
146 kMemoryScopeWorkItem = 0,
147 kMemoryScopeWorkGroup = 1,
148 kMemoryScopeDevice = 2,
149 kMemoryScopeAllSVMDevices = 3, // not supported
150 kMemoryScopeSubGroup = 4
151 };
152
153 IRBuilder<> builder(InsertBefore);
154
155 // Constants for OpenCL C 2.0 memory_scope.
156 const auto work_item =
157 builder.getInt32(AtomicMemoryScope::kMemoryScopeWorkItem);
158 const auto work_group =
159 builder.getInt32(AtomicMemoryScope::kMemoryScopeWorkGroup);
160 const auto sub_group =
161 builder.getInt32(AtomicMemoryScope::kMemoryScopeSubGroup);
162 const auto device = builder.getInt32(AtomicMemoryScope::kMemoryScopeDevice);
163
164 // Constants for SPIR-V memory scopes.
165 const auto InvocationScope = builder.getInt32(spv::ScopeInvocation);
166 const auto WorkgroupScope = builder.getInt32(spv::ScopeWorkgroup);
167 const auto DeviceScope = builder.getInt32(spv::ScopeDevice);
168 const auto SubgroupScope = builder.getInt32(spv::ScopeSubgroup);
169
170 auto base_scope = is_global ? DeviceScope : WorkgroupScope;
171 if (scope == nullptr)
172 return base_scope;
173
174 auto is_work_item = builder.CreateICmpEQ(scope, work_item);
175 auto is_work_group = builder.CreateICmpEQ(scope, work_group);
176 auto is_sub_group = builder.CreateICmpEQ(scope, sub_group);
177 auto is_device = builder.CreateICmpEQ(scope, device);
178
179 scope = builder.CreateSelect(is_work_item, InvocationScope, base_scope);
180 scope = builder.CreateSelect(is_work_group, WorkgroupScope, scope);
181 scope = builder.CreateSelect(is_sub_group, SubgroupScope, scope);
182 scope = builder.CreateSelect(is_device, DeviceScope, scope);
183
184 return scope;
185}
186
SJW2c317da2020-03-23 07:39:13 -0500187bool replaceCallsWithValue(Function &F,
188 std::function<Value *(CallInst *)> Replacer) {
189
190 bool Changed = false;
191
192 SmallVector<Instruction *, 4> ToRemoves;
193
194 // Walk the users of the function.
195 for (auto &U : F.uses()) {
196 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
197
198 auto NewValue = Replacer(CI);
199
200 if (NewValue != nullptr) {
201 CI->replaceAllUsesWith(NewValue);
202
203 // Lastly, remember to remove the user.
204 ToRemoves.push_back(CI);
205 }
206 }
207 }
208
209 Changed = !ToRemoves.empty();
210
211 // And cleanup the calls we don't use anymore.
212 for (auto V : ToRemoves) {
213 V->eraseFromParent();
214 }
215
216 return Changed;
217}
218
David Neto22f144c2017-06-12 14:26:21 -0400219struct ReplaceOpenCLBuiltinPass final : public ModulePass {
220 static char ID;
221 ReplaceOpenCLBuiltinPass() : ModulePass(ID) {}
222
223 bool runOnModule(Module &M) override;
alan-baker6b9d1ee2020-11-03 23:11:32 -0500224
225private:
SJW2c317da2020-03-23 07:39:13 -0500226 bool runOnFunction(Function &F);
227 bool replaceAbs(Function &F);
228 bool replaceAbsDiff(Function &F, bool is_signed);
229 bool replaceCopysign(Function &F);
230 bool replaceRecip(Function &F);
231 bool replaceDivide(Function &F);
232 bool replaceDot(Function &F);
233 bool replaceFmod(Function &F);
SJW61531372020-06-09 07:31:08 -0500234 bool replaceExp10(Function &F, const std::string &basename);
235 bool replaceLog10(Function &F, const std::string &basename);
gnl21636e7992020-09-09 16:08:16 +0100236 bool replaceLog1p(Function &F);
alan-baker12d2c182020-07-20 08:22:42 -0400237 bool replaceBarrier(Function &F, bool subgroup = false);
SJW2c317da2020-03-23 07:39:13 -0500238 bool replaceMemFence(Function &F, uint32_t semantics);
Kévin Petit1cb45112020-04-27 18:55:48 +0100239 bool replacePrefetch(Function &F);
SJW2c317da2020-03-23 07:39:13 -0500240 bool replaceRelational(Function &F, CmpInst::Predicate P, int32_t C);
241 bool replaceIsInfAndIsNan(Function &F, spv::Op SPIRVOp, int32_t isvec);
242 bool replaceIsFinite(Function &F);
243 bool replaceAllAndAny(Function &F, spv::Op SPIRVOp);
244 bool replaceUpsample(Function &F);
245 bool replaceRotate(Function &F);
246 bool replaceConvert(Function &F, bool SrcIsSigned, bool DstIsSigned);
247 bool replaceMulHi(Function &F, bool is_signed, bool is_mad = false);
248 bool replaceSelect(Function &F);
249 bool replaceBitSelect(Function &F);
SJW61531372020-06-09 07:31:08 -0500250 bool replaceStep(Function &F, bool is_smooth);
SJW2c317da2020-03-23 07:39:13 -0500251 bool replaceSignbit(Function &F, bool is_vec);
252 bool replaceMul(Function &F, bool is_float, bool is_mad);
253 bool replaceVloadHalf(Function &F, const std::string &name, int vec_size);
254 bool replaceVloadHalf(Function &F);
255 bool replaceVloadHalf2(Function &F);
256 bool replaceVloadHalf4(Function &F);
257 bool replaceClspvVloadaHalf2(Function &F);
258 bool replaceClspvVloadaHalf4(Function &F);
259 bool replaceVstoreHalf(Function &F, int vec_size);
260 bool replaceVstoreHalf(Function &F);
261 bool replaceVstoreHalf2(Function &F);
262 bool replaceVstoreHalf4(Function &F);
263 bool replaceHalfReadImage(Function &F);
264 bool replaceHalfWriteImage(Function &F);
265 bool replaceSampledReadImageWithIntCoords(Function &F);
266 bool replaceAtomics(Function &F, spv::Op Op);
267 bool replaceAtomics(Function &F, llvm::AtomicRMWInst::BinOp Op);
alan-baker4986eff2020-10-29 13:38:00 -0400268 bool replaceAtomicLoad(Function &F);
269 bool replaceExplicitAtomics(Function &F, spv::Op Op,
270 spv::MemorySemanticsMask semantics =
271 spv::MemorySemanticsAcquireReleaseMask);
272 bool replaceAtomicCompareExchange(Function &);
SJW2c317da2020-03-23 07:39:13 -0500273 bool replaceCross(Function &F);
274 bool replaceFract(Function &F, int vec_size);
275 bool replaceVload(Function &F);
276 bool replaceVstore(Function &F);
alan-bakera52b7312020-10-26 08:58:51 -0400277 bool replaceAddSat(Function &F, bool is_signed);
Kévin Petit8576f682020-11-02 14:51:32 +0000278 bool replaceHadd(Function &F, bool is_signed,
279 Instruction::BinaryOps join_opcode);
alan-bakercc2bafb2020-11-02 08:30:18 -0500280 bool replaceClz(Function &F);
alan-baker6b9d1ee2020-11-03 23:11:32 -0500281 bool replaceMadSat(Function &F, bool is_signed);
282
283 // Caches struct types for { |type|, |type| }. This prevents
284 // getOrInsertFunction from introducing a bitcasts between structs with
285 // identical contents.
286 Type *GetPairStruct(Type *type);
287
288 DenseMap<Type *, Type *> PairStructMap;
David Neto22f144c2017-06-12 14:26:21 -0400289};
SJW2c317da2020-03-23 07:39:13 -0500290
Kévin Petit91bc72e2019-04-08 15:17:46 +0100291} // namespace
David Neto22f144c2017-06-12 14:26:21 -0400292
293char ReplaceOpenCLBuiltinPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400294INITIALIZE_PASS(ReplaceOpenCLBuiltinPass, "ReplaceOpenCLBuiltin",
295 "Replace OpenCL Builtins Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -0400296
297namespace clspv {
298ModulePass *createReplaceOpenCLBuiltinPass() {
299 return new ReplaceOpenCLBuiltinPass();
300}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400301} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -0400302
303bool ReplaceOpenCLBuiltinPass::runOnModule(Module &M) {
SJW2c317da2020-03-23 07:39:13 -0500304 std::list<Function *> func_list;
305 for (auto &F : M.getFunctionList()) {
306 // process only function declarations
307 if (F.isDeclaration() && runOnFunction(F)) {
308 func_list.push_front(&F);
Kévin Petit2444e9b2018-11-09 14:14:37 +0000309 }
310 }
SJW2c317da2020-03-23 07:39:13 -0500311 if (func_list.size() != 0) {
312 // recursively convert functions, but first remove dead
313 for (auto *F : func_list) {
314 if (F->use_empty()) {
315 F->eraseFromParent();
316 }
317 }
318 runOnModule(M);
319 return true;
320 }
321 return false;
Kévin Petit2444e9b2018-11-09 14:14:37 +0000322}
323
SJW2c317da2020-03-23 07:39:13 -0500324bool ReplaceOpenCLBuiltinPass::runOnFunction(Function &F) {
325 auto &FI = Builtins::Lookup(&F);
326 switch (FI.getType()) {
327 case Builtins::kAbs:
328 if (!FI.getParameter(0).is_signed) {
329 return replaceAbs(F);
330 }
331 break;
332 case Builtins::kAbsDiff:
333 return replaceAbsDiff(F, FI.getParameter(0).is_signed);
alan-bakera52b7312020-10-26 08:58:51 -0400334
335 case Builtins::kAddSat:
336 return replaceAddSat(F, FI.getParameter(0).is_signed);
337
alan-bakercc2bafb2020-11-02 08:30:18 -0500338 case Builtins::kClz:
339 return replaceClz(F);
340
alan-bakerb6da5132020-10-29 15:59:06 -0400341 case Builtins::kHadd:
Kévin Petit8576f682020-11-02 14:51:32 +0000342 return replaceHadd(F, FI.getParameter(0).is_signed, Instruction::And);
alan-bakerb6da5132020-10-29 15:59:06 -0400343 case Builtins::kRhadd:
Kévin Petit8576f682020-11-02 14:51:32 +0000344 return replaceHadd(F, FI.getParameter(0).is_signed, Instruction::Or);
alan-bakerb6da5132020-10-29 15:59:06 -0400345
SJW2c317da2020-03-23 07:39:13 -0500346 case Builtins::kCopysign:
347 return replaceCopysign(F);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100348
SJW2c317da2020-03-23 07:39:13 -0500349 case Builtins::kHalfRecip:
350 case Builtins::kNativeRecip:
351 return replaceRecip(F);
Kévin Petite8edce32019-04-10 14:23:32 +0100352
SJW2c317da2020-03-23 07:39:13 -0500353 case Builtins::kHalfDivide:
354 case Builtins::kNativeDivide:
355 return replaceDivide(F);
356
357 case Builtins::kDot:
358 return replaceDot(F);
359
360 case Builtins::kExp10:
361 case Builtins::kHalfExp10:
SJW61531372020-06-09 07:31:08 -0500362 case Builtins::kNativeExp10:
363 return replaceExp10(F, FI.getName());
SJW2c317da2020-03-23 07:39:13 -0500364
365 case Builtins::kLog10:
366 case Builtins::kHalfLog10:
SJW61531372020-06-09 07:31:08 -0500367 case Builtins::kNativeLog10:
368 return replaceLog10(F, FI.getName());
SJW2c317da2020-03-23 07:39:13 -0500369
gnl21636e7992020-09-09 16:08:16 +0100370 case Builtins::kLog1p:
371 return replaceLog1p(F);
372
SJW2c317da2020-03-23 07:39:13 -0500373 case Builtins::kFmod:
374 return replaceFmod(F);
375
376 case Builtins::kBarrier:
377 case Builtins::kWorkGroupBarrier:
378 return replaceBarrier(F);
379
alan-baker12d2c182020-07-20 08:22:42 -0400380 case Builtins::kSubGroupBarrier:
381 return replaceBarrier(F, true);
382
SJW2c317da2020-03-23 07:39:13 -0500383 case Builtins::kMemFence:
alan-baker12d2c182020-07-20 08:22:42 -0400384 return replaceMemFence(F, spv::MemorySemanticsAcquireReleaseMask);
SJW2c317da2020-03-23 07:39:13 -0500385 case Builtins::kReadMemFence:
386 return replaceMemFence(F, spv::MemorySemanticsAcquireMask);
387 case Builtins::kWriteMemFence:
388 return replaceMemFence(F, spv::MemorySemanticsReleaseMask);
389
390 // Relational
391 case Builtins::kIsequal:
392 return replaceRelational(F, CmpInst::FCMP_OEQ,
393 FI.getParameter(0).vector_size ? -1 : 1);
394 case Builtins::kIsgreater:
395 return replaceRelational(F, CmpInst::FCMP_OGT,
396 FI.getParameter(0).vector_size ? -1 : 1);
397 case Builtins::kIsgreaterequal:
398 return replaceRelational(F, CmpInst::FCMP_OGE,
399 FI.getParameter(0).vector_size ? -1 : 1);
400 case Builtins::kIsless:
401 return replaceRelational(F, CmpInst::FCMP_OLT,
402 FI.getParameter(0).vector_size ? -1 : 1);
403 case Builtins::kIslessequal:
404 return replaceRelational(F, CmpInst::FCMP_OLE,
405 FI.getParameter(0).vector_size ? -1 : 1);
406 case Builtins::kIsnotequal:
407 return replaceRelational(F, CmpInst::FCMP_ONE,
408 FI.getParameter(0).vector_size ? -1 : 1);
409
410 case Builtins::kIsinf: {
411 bool is_vec = FI.getParameter(0).vector_size != 0;
412 return replaceIsInfAndIsNan(F, spv::OpIsInf, is_vec ? -1 : 1);
413 }
414 case Builtins::kIsnan: {
415 bool is_vec = FI.getParameter(0).vector_size != 0;
416 return replaceIsInfAndIsNan(F, spv::OpIsNan, is_vec ? -1 : 1);
417 }
418
419 case Builtins::kIsfinite:
420 return replaceIsFinite(F);
421
422 case Builtins::kAll: {
423 bool is_vec = FI.getParameter(0).vector_size != 0;
424 return replaceAllAndAny(F, !is_vec ? spv::OpNop : spv::OpAll);
425 }
426 case Builtins::kAny: {
427 bool is_vec = FI.getParameter(0).vector_size != 0;
428 return replaceAllAndAny(F, !is_vec ? spv::OpNop : spv::OpAny);
429 }
430
431 case Builtins::kUpsample:
432 return replaceUpsample(F);
433
434 case Builtins::kRotate:
435 return replaceRotate(F);
436
437 case Builtins::kConvert:
438 return replaceConvert(F, FI.getParameter(0).is_signed,
439 FI.getReturnType().is_signed);
440
alan-baker4986eff2020-10-29 13:38:00 -0400441 // OpenCL 2.0 explicit atomics have different default scopes and semantics
442 // than legacy atomic functions.
443 case Builtins::kAtomicLoad:
444 case Builtins::kAtomicLoadExplicit:
445 return replaceAtomicLoad(F);
446 case Builtins::kAtomicStore:
447 case Builtins::kAtomicStoreExplicit:
448 return replaceExplicitAtomics(F, spv::OpAtomicStore,
449 spv::MemorySemanticsReleaseMask);
450 case Builtins::kAtomicExchange:
451 case Builtins::kAtomicExchangeExplicit:
452 return replaceExplicitAtomics(F, spv::OpAtomicExchange);
453 case Builtins::kAtomicFetchAdd:
454 case Builtins::kAtomicFetchAddExplicit:
455 return replaceExplicitAtomics(F, spv::OpAtomicIAdd);
456 case Builtins::kAtomicFetchSub:
457 case Builtins::kAtomicFetchSubExplicit:
458 return replaceExplicitAtomics(F, spv::OpAtomicISub);
459 case Builtins::kAtomicFetchOr:
460 case Builtins::kAtomicFetchOrExplicit:
461 return replaceExplicitAtomics(F, spv::OpAtomicOr);
462 case Builtins::kAtomicFetchXor:
463 case Builtins::kAtomicFetchXorExplicit:
464 return replaceExplicitAtomics(F, spv::OpAtomicXor);
465 case Builtins::kAtomicFetchAnd:
466 case Builtins::kAtomicFetchAndExplicit:
467 return replaceExplicitAtomics(F, spv::OpAtomicAnd);
468 case Builtins::kAtomicFetchMin:
469 case Builtins::kAtomicFetchMinExplicit:
470 return replaceExplicitAtomics(F, FI.getParameter(1).is_signed
471 ? spv::OpAtomicSMin
472 : spv::OpAtomicUMin);
473 case Builtins::kAtomicFetchMax:
474 case Builtins::kAtomicFetchMaxExplicit:
475 return replaceExplicitAtomics(F, FI.getParameter(1).is_signed
476 ? spv::OpAtomicSMax
477 : spv::OpAtomicUMax);
478 // Weak compare exchange is generated as strong compare exchange.
479 case Builtins::kAtomicCompareExchangeWeak:
480 case Builtins::kAtomicCompareExchangeWeakExplicit:
481 case Builtins::kAtomicCompareExchangeStrong:
482 case Builtins::kAtomicCompareExchangeStrongExplicit:
483 return replaceAtomicCompareExchange(F);
484
485 // Legacy atomic functions.
SJW2c317da2020-03-23 07:39:13 -0500486 case Builtins::kAtomicInc:
487 return replaceAtomics(F, spv::OpAtomicIIncrement);
488 case Builtins::kAtomicDec:
489 return replaceAtomics(F, spv::OpAtomicIDecrement);
490 case Builtins::kAtomicCmpxchg:
491 return replaceAtomics(F, spv::OpAtomicCompareExchange);
492 case Builtins::kAtomicAdd:
493 return replaceAtomics(F, llvm::AtomicRMWInst::Add);
494 case Builtins::kAtomicSub:
495 return replaceAtomics(F, llvm::AtomicRMWInst::Sub);
496 case Builtins::kAtomicXchg:
497 return replaceAtomics(F, llvm::AtomicRMWInst::Xchg);
498 case Builtins::kAtomicMin:
499 return replaceAtomics(F, FI.getParameter(0).is_signed
500 ? llvm::AtomicRMWInst::Min
501 : llvm::AtomicRMWInst::UMin);
502 case Builtins::kAtomicMax:
503 return replaceAtomics(F, FI.getParameter(0).is_signed
504 ? llvm::AtomicRMWInst::Max
505 : llvm::AtomicRMWInst::UMax);
506 case Builtins::kAtomicAnd:
507 return replaceAtomics(F, llvm::AtomicRMWInst::And);
508 case Builtins::kAtomicOr:
509 return replaceAtomics(F, llvm::AtomicRMWInst::Or);
510 case Builtins::kAtomicXor:
511 return replaceAtomics(F, llvm::AtomicRMWInst::Xor);
512
513 case Builtins::kCross:
514 if (FI.getParameter(0).vector_size == 4) {
515 return replaceCross(F);
516 }
517 break;
518
519 case Builtins::kFract:
520 if (FI.getParameterCount()) {
521 return replaceFract(F, FI.getParameter(0).vector_size);
522 }
523 break;
524
525 case Builtins::kMadHi:
526 return replaceMulHi(F, FI.getParameter(0).is_signed, true);
527 case Builtins::kMulHi:
528 return replaceMulHi(F, FI.getParameter(0).is_signed, false);
529
alan-baker6b9d1ee2020-11-03 23:11:32 -0500530 case Builtins::kMadSat:
531 return replaceMadSat(F, FI.getParameter(0).is_signed);
532
SJW2c317da2020-03-23 07:39:13 -0500533 case Builtins::kMad:
534 case Builtins::kMad24:
535 return replaceMul(F, FI.getParameter(0).type_id == llvm::Type::FloatTyID,
536 true);
537 case Builtins::kMul24:
538 return replaceMul(F, FI.getParameter(0).type_id == llvm::Type::FloatTyID,
539 false);
540
541 case Builtins::kSelect:
542 return replaceSelect(F);
543
544 case Builtins::kBitselect:
545 return replaceBitSelect(F);
546
547 case Builtins::kVload:
548 return replaceVload(F);
549
550 case Builtins::kVloadaHalf:
551 case Builtins::kVloadHalf:
552 return replaceVloadHalf(F, FI.getName(), FI.getParameter(0).vector_size);
553
554 case Builtins::kVstore:
555 return replaceVstore(F);
556
557 case Builtins::kVstoreHalf:
558 case Builtins::kVstoreaHalf:
559 return replaceVstoreHalf(F, FI.getParameter(0).vector_size);
560
561 case Builtins::kSmoothstep: {
562 int vec_size = FI.getLastParameter().vector_size;
563 if (FI.getParameter(0).vector_size == 0 && vec_size != 0) {
SJW61531372020-06-09 07:31:08 -0500564 return replaceStep(F, true);
SJW2c317da2020-03-23 07:39:13 -0500565 }
566 break;
567 }
568 case Builtins::kStep: {
569 int vec_size = FI.getLastParameter().vector_size;
570 if (FI.getParameter(0).vector_size == 0 && vec_size != 0) {
SJW61531372020-06-09 07:31:08 -0500571 return replaceStep(F, false);
SJW2c317da2020-03-23 07:39:13 -0500572 }
573 break;
574 }
575
576 case Builtins::kSignbit:
577 return replaceSignbit(F, FI.getParameter(0).vector_size != 0);
578
579 case Builtins::kReadImageh:
580 return replaceHalfReadImage(F);
581 case Builtins::kReadImagef:
582 case Builtins::kReadImagei:
583 case Builtins::kReadImageui: {
584 if (FI.getParameter(1).isSampler() &&
585 FI.getParameter(2).type_id == llvm::Type::IntegerTyID) {
586 return replaceSampledReadImageWithIntCoords(F);
587 }
588 break;
589 }
590
591 case Builtins::kWriteImageh:
592 return replaceHalfWriteImage(F);
593
Kévin Petit1cb45112020-04-27 18:55:48 +0100594 case Builtins::kPrefetch:
595 return replacePrefetch(F);
596
SJW2c317da2020-03-23 07:39:13 -0500597 default:
598 break;
599 }
600
601 return false;
602}
603
alan-baker6b9d1ee2020-11-03 23:11:32 -0500604Type *ReplaceOpenCLBuiltinPass::GetPairStruct(Type *type) {
605 auto iter = PairStructMap.find(type);
606 if (iter != PairStructMap.end())
607 return iter->second;
608
609 auto new_struct = StructType::get(type->getContext(), {type, type});
610 PairStructMap[type] = new_struct;
611 return new_struct;
612}
613
SJW2c317da2020-03-23 07:39:13 -0500614bool ReplaceOpenCLBuiltinPass::replaceAbs(Function &F) {
615 return replaceCallsWithValue(F,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400616 [](CallInst *CI) { return CI->getOperand(0); });
Kévin Petite8edce32019-04-10 14:23:32 +0100617}
618
SJW2c317da2020-03-23 07:39:13 -0500619bool ReplaceOpenCLBuiltinPass::replaceAbsDiff(Function &F, bool is_signed) {
620 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100621 auto XValue = CI->getOperand(0);
622 auto YValue = CI->getOperand(1);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100623
Kévin Petite8edce32019-04-10 14:23:32 +0100624 IRBuilder<> Builder(CI);
625 auto XmY = Builder.CreateSub(XValue, YValue);
626 auto YmX = Builder.CreateSub(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100627
SJW2c317da2020-03-23 07:39:13 -0500628 Value *Cmp = nullptr;
629 if (is_signed) {
Kévin Petite8edce32019-04-10 14:23:32 +0100630 Cmp = Builder.CreateICmpSGT(YValue, XValue);
631 } else {
632 Cmp = Builder.CreateICmpUGT(YValue, XValue);
Kévin Petit91bc72e2019-04-08 15:17:46 +0100633 }
Kévin Petit91bc72e2019-04-08 15:17:46 +0100634
Kévin Petite8edce32019-04-10 14:23:32 +0100635 return Builder.CreateSelect(Cmp, YmX, XmY);
636 });
Kévin Petit91bc72e2019-04-08 15:17:46 +0100637}
638
SJW2c317da2020-03-23 07:39:13 -0500639bool ReplaceOpenCLBuiltinPass::replaceCopysign(Function &F) {
640 return replaceCallsWithValue(F, [&F](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100641 auto XValue = CI->getOperand(0);
642 auto YValue = CI->getOperand(1);
Kévin Petit8c1be282019-04-02 19:34:25 +0100643
Kévin Petite8edce32019-04-10 14:23:32 +0100644 auto Ty = XValue->getType();
Kévin Petit8c1be282019-04-02 19:34:25 +0100645
SJW2c317da2020-03-23 07:39:13 -0500646 Type *IntTy = Type::getIntNTy(F.getContext(), Ty->getScalarSizeInBits());
James Pricecf53df42020-04-20 14:41:24 -0400647 if (auto vec_ty = dyn_cast<VectorType>(Ty)) {
alan-baker5a8c3be2020-09-09 13:44:26 -0400648 IntTy = FixedVectorType::get(
649 IntTy, vec_ty->getElementCount().getKnownMinValue());
Kévin Petit8c1be282019-04-02 19:34:25 +0100650 }
Kévin Petit8c1be282019-04-02 19:34:25 +0100651
Kévin Petite8edce32019-04-10 14:23:32 +0100652 // Return X with the sign of Y
653
654 // Sign bit masks
655 auto SignBit = IntTy->getScalarSizeInBits() - 1;
656 auto SignBitMask = 1 << SignBit;
657 auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
658 auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);
659
660 IRBuilder<> Builder(CI);
661
662 // Extract sign of Y
663 auto YInt = Builder.CreateBitCast(YValue, IntTy);
664 auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);
665
666 // Clear sign bit in X
667 auto XInt = Builder.CreateBitCast(XValue, IntTy);
668 XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);
669
670 // Insert sign bit of Y into X
671 auto NewXInt = Builder.CreateOr(XInt, YSign);
672
673 // And cast back to floating-point
674 return Builder.CreateBitCast(NewXInt, Ty);
675 });
Kévin Petit8c1be282019-04-02 19:34:25 +0100676}
677
SJW2c317da2020-03-23 07:39:13 -0500678bool ReplaceOpenCLBuiltinPass::replaceRecip(Function &F) {
679 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100680 // Recip has one arg.
681 auto Arg = CI->getOperand(0);
682 auto Cst1 = ConstantFP::get(Arg->getType(), 1.0);
683 return BinaryOperator::Create(Instruction::FDiv, Cst1, Arg, "", CI);
684 });
David Neto22f144c2017-06-12 14:26:21 -0400685}
686
SJW2c317da2020-03-23 07:39:13 -0500687bool ReplaceOpenCLBuiltinPass::replaceDivide(Function &F) {
688 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +0100689 auto Op0 = CI->getOperand(0);
690 auto Op1 = CI->getOperand(1);
691 return BinaryOperator::Create(Instruction::FDiv, Op0, Op1, "", CI);
692 });
David Neto22f144c2017-06-12 14:26:21 -0400693}
694
SJW2c317da2020-03-23 07:39:13 -0500695bool ReplaceOpenCLBuiltinPass::replaceDot(Function &F) {
696 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petit1329a002019-06-15 05:54:05 +0100697 auto Op0 = CI->getOperand(0);
698 auto Op1 = CI->getOperand(1);
699
SJW2c317da2020-03-23 07:39:13 -0500700 Value *V = nullptr;
Kévin Petit1329a002019-06-15 05:54:05 +0100701 if (Op0->getType()->isVectorTy()) {
702 V = clspv::InsertSPIRVOp(CI, spv::OpDot, {Attribute::ReadNone},
703 CI->getType(), {Op0, Op1});
704 } else {
705 V = BinaryOperator::Create(Instruction::FMul, Op0, Op1, "", CI);
706 }
707
708 return V;
709 });
710}
711
SJW2c317da2020-03-23 07:39:13 -0500712bool ReplaceOpenCLBuiltinPass::replaceExp10(Function &F,
SJW61531372020-06-09 07:31:08 -0500713 const std::string &basename) {
SJW2c317da2020-03-23 07:39:13 -0500714 // convert to natural
715 auto slen = basename.length() - 2;
SJW61531372020-06-09 07:31:08 -0500716 std::string NewFName = basename.substr(0, slen);
717 NewFName =
718 Builtins::GetMangledFunctionName(NewFName.c_str(), F.getFunctionType());
David Neto22f144c2017-06-12 14:26:21 -0400719
SJW2c317da2020-03-23 07:39:13 -0500720 Module &M = *F.getParent();
721 return replaceCallsWithValue(F, [&](CallInst *CI) {
722 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
723
724 auto Arg = CI->getOperand(0);
725
726 // Constant of the natural log of 10 (ln(10)).
727 const double Ln10 =
728 2.302585092994045684017991454684364207601101488628772976033;
729
730 auto Mul = BinaryOperator::Create(
731 Instruction::FMul, ConstantFP::get(Arg->getType(), Ln10), Arg, "", CI);
732
733 return CallInst::Create(NewF, Mul, "", CI);
734 });
David Neto22f144c2017-06-12 14:26:21 -0400735}
736
SJW2c317da2020-03-23 07:39:13 -0500737bool ReplaceOpenCLBuiltinPass::replaceFmod(Function &F) {
Kévin Petit0644a9c2019-06-20 21:08:46 +0100738 // OpenCL fmod(x,y) is x - y * trunc(x/y)
739 // The sign for a non-zero result is taken from x.
740 // (Try an example.)
741 // So translate to FRem
SJW2c317da2020-03-23 07:39:13 -0500742 return replaceCallsWithValue(F, [](CallInst *CI) {
Kévin Petit0644a9c2019-06-20 21:08:46 +0100743 auto Op0 = CI->getOperand(0);
744 auto Op1 = CI->getOperand(1);
745 return BinaryOperator::Create(Instruction::FRem, Op0, Op1, "", CI);
746 });
747}
748
SJW2c317da2020-03-23 07:39:13 -0500749bool ReplaceOpenCLBuiltinPass::replaceLog10(Function &F,
SJW61531372020-06-09 07:31:08 -0500750 const std::string &basename) {
SJW2c317da2020-03-23 07:39:13 -0500751 // convert to natural
752 auto slen = basename.length() - 2;
SJW61531372020-06-09 07:31:08 -0500753 std::string NewFName = basename.substr(0, slen);
754 NewFName =
755 Builtins::GetMangledFunctionName(NewFName.c_str(), F.getFunctionType());
David Neto22f144c2017-06-12 14:26:21 -0400756
SJW2c317da2020-03-23 07:39:13 -0500757 Module &M = *F.getParent();
758 return replaceCallsWithValue(F, [&](CallInst *CI) {
759 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
760
761 auto Arg = CI->getOperand(0);
762
763 // Constant of the reciprocal of the natural log of 10 (ln(10)).
764 const double Ln10 =
765 0.434294481903251827651128918916605082294397005803666566114;
766
767 auto NewCI = CallInst::Create(NewF, Arg, "", CI);
768
769 return BinaryOperator::Create(Instruction::FMul,
770 ConstantFP::get(Arg->getType(), Ln10), NewCI,
771 "", CI);
772 });
David Neto22f144c2017-06-12 14:26:21 -0400773}
774
gnl21636e7992020-09-09 16:08:16 +0100775bool ReplaceOpenCLBuiltinPass::replaceLog1p(Function &F) {
776 // convert to natural
777 std::string NewFName =
778 Builtins::GetMangledFunctionName("log", F.getFunctionType());
779
780 Module &M = *F.getParent();
781 return replaceCallsWithValue(F, [&](CallInst *CI) {
782 auto NewF = M.getOrInsertFunction(NewFName, F.getFunctionType());
783
784 auto Arg = CI->getOperand(0);
785
786 auto ArgP1 = BinaryOperator::Create(
787 Instruction::FAdd, ConstantFP::get(Arg->getType(), 1.0), Arg, "", CI);
788
789 return CallInst::Create(NewF, ArgP1, "", CI);
790 });
791}
792
alan-baker12d2c182020-07-20 08:22:42 -0400793bool ReplaceOpenCLBuiltinPass::replaceBarrier(Function &F, bool subgroup) {
David Neto22f144c2017-06-12 14:26:21 -0400794
alan-bakerf6bc8252020-09-23 14:58:55 -0400795 enum {
796 CLK_LOCAL_MEM_FENCE = 0x01,
797 CLK_GLOBAL_MEM_FENCE = 0x02,
798 CLK_IMAGE_MEM_FENCE = 0x04
799 };
David Neto22f144c2017-06-12 14:26:21 -0400800
alan-baker12d2c182020-07-20 08:22:42 -0400801 return replaceCallsWithValue(F, [subgroup](CallInst *CI) {
Kévin Petitc4643922019-06-17 19:32:05 +0100802 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400803
Kévin Petitc4643922019-06-17 19:32:05 +0100804 // We need to map the OpenCL constants to the SPIR-V equivalents.
805 const auto LocalMemFence =
806 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
807 const auto GlobalMemFence =
808 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
alan-bakerf6bc8252020-09-23 14:58:55 -0400809 const auto ImageMemFence =
810 ConstantInt::get(Arg->getType(), CLK_IMAGE_MEM_FENCE);
alan-baker12d2c182020-07-20 08:22:42 -0400811 const auto ConstantAcquireRelease = ConstantInt::get(
812 Arg->getType(), spv::MemorySemanticsAcquireReleaseMask);
Kévin Petitc4643922019-06-17 19:32:05 +0100813 const auto ConstantScopeDevice =
814 ConstantInt::get(Arg->getType(), spv::ScopeDevice);
815 const auto ConstantScopeWorkgroup =
816 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
alan-baker12d2c182020-07-20 08:22:42 -0400817 const auto ConstantScopeSubgroup =
818 ConstantInt::get(Arg->getType(), spv::ScopeSubgroup);
David Neto22f144c2017-06-12 14:26:21 -0400819
Kévin Petitc4643922019-06-17 19:32:05 +0100820 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
821 const auto LocalMemFenceMask =
822 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
823 const auto WorkgroupShiftAmount =
824 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
825 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
826 Instruction::Shl, LocalMemFenceMask,
827 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400828
Kévin Petitc4643922019-06-17 19:32:05 +0100829 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
830 const auto GlobalMemFenceMask =
831 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
832 const auto UniformShiftAmount =
833 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
834 const auto MemorySemanticsUniform = BinaryOperator::Create(
835 Instruction::Shl, GlobalMemFenceMask,
836 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400837
alan-bakerf6bc8252020-09-23 14:58:55 -0400838 // OpenCL 2.0
839 // Map CLK_IMAGE_MEM_FENCE to MemorySemanticsImageMemoryMask.
840 const auto ImageMemFenceMask =
841 BinaryOperator::Create(Instruction::And, ImageMemFence, Arg, "", CI);
842 const auto ImageShiftAmount =
843 clz(spv::MemorySemanticsImageMemoryMask) - clz(CLK_IMAGE_MEM_FENCE);
844 const auto MemorySemanticsImage = BinaryOperator::Create(
845 Instruction::Shl, ImageMemFenceMask,
846 ConstantInt::get(Arg->getType(), ImageShiftAmount), "", CI);
847
Kévin Petitc4643922019-06-17 19:32:05 +0100848 // And combine the above together, also adding in
alan-bakerf6bc8252020-09-23 14:58:55 -0400849 // MemorySemanticsSequentiallyConsistentMask.
850 auto MemorySemantics1 =
Kévin Petitc4643922019-06-17 19:32:05 +0100851 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
alan-baker12d2c182020-07-20 08:22:42 -0400852 ConstantAcquireRelease, "", CI);
alan-bakerf6bc8252020-09-23 14:58:55 -0400853 auto MemorySemantics2 = BinaryOperator::Create(
854 Instruction::Or, MemorySemanticsUniform, MemorySemanticsImage, "", CI);
855 auto MemorySemantics = BinaryOperator::Create(
856 Instruction::Or, MemorySemantics1, MemorySemantics2, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400857
alan-baker12d2c182020-07-20 08:22:42 -0400858 // If the memory scope is not specified explicitly, it is either Subgroup
859 // or Workgroup depending on the type of barrier.
860 Value *MemoryScope =
861 subgroup ? ConstantScopeSubgroup : ConstantScopeWorkgroup;
862 if (CI->data_operands_size() > 1) {
863 enum {
864 CL_MEMORY_SCOPE_WORKGROUP = 0x1,
865 CL_MEMORY_SCOPE_DEVICE = 0x2,
866 CL_MEMORY_SCOPE_SUBGROUP = 0x4
867 };
868 // The call was given an explicit memory scope.
869 const auto MemoryScopeSubgroup =
870 ConstantInt::get(Arg->getType(), CL_MEMORY_SCOPE_SUBGROUP);
871 const auto MemoryScopeDevice =
872 ConstantInt::get(Arg->getType(), CL_MEMORY_SCOPE_DEVICE);
David Neto22f144c2017-06-12 14:26:21 -0400873
alan-baker12d2c182020-07-20 08:22:42 -0400874 auto Cmp =
875 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
876 MemoryScopeSubgroup, CI->getOperand(1), "", CI);
877 MemoryScope = SelectInst::Create(Cmp, ConstantScopeSubgroup,
878 ConstantScopeWorkgroup, "", CI);
879 Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
880 MemoryScopeDevice, CI->getOperand(1), "", CI);
881 MemoryScope =
882 SelectInst::Create(Cmp, ConstantScopeDevice, MemoryScope, "", CI);
883 }
884
885 // Lastly, the Execution Scope is either Workgroup or Subgroup depending on
886 // the type of barrier;
887 const auto ExecutionScope =
888 subgroup ? ConstantScopeSubgroup : ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400889
Kévin Petitc4643922019-06-17 19:32:05 +0100890 return clspv::InsertSPIRVOp(CI, spv::OpControlBarrier,
alan-baker3d905692020-10-28 14:02:37 -0400891 {Attribute::NoDuplicate, Attribute::Convergent},
892 CI->getType(),
Kévin Petitc4643922019-06-17 19:32:05 +0100893 {ExecutionScope, MemoryScope, MemorySemantics});
894 });
David Neto22f144c2017-06-12 14:26:21 -0400895}
896
SJW2c317da2020-03-23 07:39:13 -0500897bool ReplaceOpenCLBuiltinPass::replaceMemFence(Function &F,
898 uint32_t semantics) {
David Neto22f144c2017-06-12 14:26:21 -0400899
SJW2c317da2020-03-23 07:39:13 -0500900 return replaceCallsWithValue(F, [&](CallInst *CI) {
alan-bakerf6bc8252020-09-23 14:58:55 -0400901 enum {
902 CLK_LOCAL_MEM_FENCE = 0x01,
903 CLK_GLOBAL_MEM_FENCE = 0x02,
904 CLK_IMAGE_MEM_FENCE = 0x04,
905 };
David Neto22f144c2017-06-12 14:26:21 -0400906
SJW2c317da2020-03-23 07:39:13 -0500907 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400908
SJW2c317da2020-03-23 07:39:13 -0500909 // We need to map the OpenCL constants to the SPIR-V equivalents.
910 const auto LocalMemFence =
911 ConstantInt::get(Arg->getType(), CLK_LOCAL_MEM_FENCE);
912 const auto GlobalMemFence =
913 ConstantInt::get(Arg->getType(), CLK_GLOBAL_MEM_FENCE);
alan-bakerf6bc8252020-09-23 14:58:55 -0400914 const auto ImageMemFence =
915 ConstantInt::get(Arg->getType(), CLK_IMAGE_MEM_FENCE);
SJW2c317da2020-03-23 07:39:13 -0500916 const auto ConstantMemorySemantics =
917 ConstantInt::get(Arg->getType(), semantics);
alan-baker12d2c182020-07-20 08:22:42 -0400918 const auto ConstantScopeWorkgroup =
919 ConstantInt::get(Arg->getType(), spv::ScopeWorkgroup);
David Neto22f144c2017-06-12 14:26:21 -0400920
SJW2c317da2020-03-23 07:39:13 -0500921 // Map CLK_LOCAL_MEM_FENCE to MemorySemanticsWorkgroupMemoryMask.
922 const auto LocalMemFenceMask =
923 BinaryOperator::Create(Instruction::And, LocalMemFence, Arg, "", CI);
924 const auto WorkgroupShiftAmount =
925 clz(spv::MemorySemanticsWorkgroupMemoryMask) - clz(CLK_LOCAL_MEM_FENCE);
926 const auto MemorySemanticsWorkgroup = BinaryOperator::Create(
927 Instruction::Shl, LocalMemFenceMask,
928 ConstantInt::get(Arg->getType(), WorkgroupShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400929
SJW2c317da2020-03-23 07:39:13 -0500930 // Map CLK_GLOBAL_MEM_FENCE to MemorySemanticsUniformMemoryMask.
931 const auto GlobalMemFenceMask =
932 BinaryOperator::Create(Instruction::And, GlobalMemFence, Arg, "", CI);
933 const auto UniformShiftAmount =
934 clz(spv::MemorySemanticsUniformMemoryMask) - clz(CLK_GLOBAL_MEM_FENCE);
935 const auto MemorySemanticsUniform = BinaryOperator::Create(
936 Instruction::Shl, GlobalMemFenceMask,
937 ConstantInt::get(Arg->getType(), UniformShiftAmount), "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400938
alan-bakerf6bc8252020-09-23 14:58:55 -0400939 // OpenCL 2.0
940 // Map CLK_IMAGE_MEM_FENCE to MemorySemanticsImageMemoryMask.
941 const auto ImageMemFenceMask =
942 BinaryOperator::Create(Instruction::And, ImageMemFence, Arg, "", CI);
943 const auto ImageShiftAmount =
944 clz(spv::MemorySemanticsImageMemoryMask) - clz(CLK_IMAGE_MEM_FENCE);
945 const auto MemorySemanticsImage = BinaryOperator::Create(
946 Instruction::Shl, ImageMemFenceMask,
947 ConstantInt::get(Arg->getType(), ImageShiftAmount), "", CI);
948
SJW2c317da2020-03-23 07:39:13 -0500949 // And combine the above together, also adding in
alan-bakerf6bc8252020-09-23 14:58:55 -0400950 // |semantics|.
951 auto MemorySemantics1 =
SJW2c317da2020-03-23 07:39:13 -0500952 BinaryOperator::Create(Instruction::Or, MemorySemanticsWorkgroup,
953 ConstantMemorySemantics, "", CI);
alan-bakerf6bc8252020-09-23 14:58:55 -0400954 auto MemorySemantics2 = BinaryOperator::Create(
955 Instruction::Or, MemorySemanticsUniform, MemorySemanticsImage, "", CI);
956 auto MemorySemantics = BinaryOperator::Create(
957 Instruction::Or, MemorySemantics1, MemorySemantics2, "", CI);
David Neto22f144c2017-06-12 14:26:21 -0400958
alan-baker12d2c182020-07-20 08:22:42 -0400959 // Memory Scope is always workgroup.
960 const auto MemoryScope = ConstantScopeWorkgroup;
David Neto22f144c2017-06-12 14:26:21 -0400961
alan-baker3d905692020-10-28 14:02:37 -0400962 return clspv::InsertSPIRVOp(CI, spv::OpMemoryBarrier,
963 {Attribute::Convergent}, CI->getType(),
SJW2c317da2020-03-23 07:39:13 -0500964 {MemoryScope, MemorySemantics});
965 });
David Neto22f144c2017-06-12 14:26:21 -0400966}
967
Kévin Petit1cb45112020-04-27 18:55:48 +0100968bool ReplaceOpenCLBuiltinPass::replacePrefetch(Function &F) {
969 bool Changed = false;
970
971 SmallVector<Instruction *, 4> ToRemoves;
972
973 // Find all calls to the function
974 for (auto &U : F.uses()) {
975 if (auto CI = dyn_cast<CallInst>(U.getUser())) {
976 ToRemoves.push_back(CI);
977 }
978 }
979
980 Changed = !ToRemoves.empty();
981
982 // Delete them
983 for (auto V : ToRemoves) {
984 V->eraseFromParent();
985 }
986
987 return Changed;
988}
989
SJW2c317da2020-03-23 07:39:13 -0500990bool ReplaceOpenCLBuiltinPass::replaceRelational(Function &F,
991 CmpInst::Predicate P,
992 int32_t C) {
993 return replaceCallsWithValue(F, [&](CallInst *CI) {
994 // The predicate to use in the CmpInst.
995 auto Predicate = P;
David Neto22f144c2017-06-12 14:26:21 -0400996
SJW2c317da2020-03-23 07:39:13 -0500997 // The value to return for true.
998 auto TrueValue = ConstantInt::getSigned(CI->getType(), C);
David Neto22f144c2017-06-12 14:26:21 -0400999
SJW2c317da2020-03-23 07:39:13 -05001000 // The value to return for false.
1001 auto FalseValue = Constant::getNullValue(CI->getType());
David Neto22f144c2017-06-12 14:26:21 -04001002
SJW2c317da2020-03-23 07:39:13 -05001003 auto Arg1 = CI->getOperand(0);
1004 auto Arg2 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001005
SJW2c317da2020-03-23 07:39:13 -05001006 const auto Cmp =
1007 CmpInst::Create(Instruction::FCmp, Predicate, Arg1, Arg2, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001008
SJW2c317da2020-03-23 07:39:13 -05001009 return SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1010 });
David Neto22f144c2017-06-12 14:26:21 -04001011}
1012
SJW2c317da2020-03-23 07:39:13 -05001013bool ReplaceOpenCLBuiltinPass::replaceIsInfAndIsNan(Function &F,
1014 spv::Op SPIRVOp,
1015 int32_t C) {
1016 Module &M = *F.getParent();
1017 return replaceCallsWithValue(F, [&](CallInst *CI) {
1018 const auto CITy = CI->getType();
David Neto22f144c2017-06-12 14:26:21 -04001019
SJW2c317da2020-03-23 07:39:13 -05001020 // The value to return for true.
1021 auto TrueValue = ConstantInt::getSigned(CITy, C);
David Neto22f144c2017-06-12 14:26:21 -04001022
SJW2c317da2020-03-23 07:39:13 -05001023 // The value to return for false.
1024 auto FalseValue = Constant::getNullValue(CITy);
David Neto22f144c2017-06-12 14:26:21 -04001025
SJW2c317da2020-03-23 07:39:13 -05001026 Type *CorrespondingBoolTy = Type::getInt1Ty(M.getContext());
James Pricecf53df42020-04-20 14:41:24 -04001027 if (auto CIVecTy = dyn_cast<VectorType>(CITy)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001028 CorrespondingBoolTy =
1029 FixedVectorType::get(Type::getInt1Ty(M.getContext()),
1030 CIVecTy->getElementCount().getKnownMinValue());
David Neto22f144c2017-06-12 14:26:21 -04001031 }
David Neto22f144c2017-06-12 14:26:21 -04001032
SJW2c317da2020-03-23 07:39:13 -05001033 auto NewCI = clspv::InsertSPIRVOp(CI, SPIRVOp, {Attribute::ReadNone},
1034 CorrespondingBoolTy, {CI->getOperand(0)});
1035
1036 return SelectInst::Create(NewCI, TrueValue, FalseValue, "", CI);
1037 });
David Neto22f144c2017-06-12 14:26:21 -04001038}
1039
SJW2c317da2020-03-23 07:39:13 -05001040bool ReplaceOpenCLBuiltinPass::replaceIsFinite(Function &F) {
1041 Module &M = *F.getParent();
1042 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001043 auto &C = M.getContext();
1044 auto Val = CI->getOperand(0);
1045 auto ValTy = Val->getType();
1046 auto RetTy = CI->getType();
1047
1048 // Get a suitable integer type to represent the number
1049 auto IntTy = getIntOrIntVectorTyForCast(C, ValTy);
1050
1051 // Create Mask
1052 auto ScalarSize = ValTy->getScalarSizeInBits();
SJW2c317da2020-03-23 07:39:13 -05001053 Value *InfMask = nullptr;
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001054 switch (ScalarSize) {
1055 case 16:
1056 InfMask = ConstantInt::get(IntTy, 0x7C00U);
1057 break;
1058 case 32:
1059 InfMask = ConstantInt::get(IntTy, 0x7F800000U);
1060 break;
1061 case 64:
1062 InfMask = ConstantInt::get(IntTy, 0x7FF0000000000000ULL);
1063 break;
1064 default:
1065 llvm_unreachable("Unsupported floating-point type");
1066 }
1067
1068 IRBuilder<> Builder(CI);
1069
1070 // Bitcast to int
1071 auto ValInt = Builder.CreateBitCast(Val, IntTy);
1072
1073 // Mask and compare
1074 auto InfBits = Builder.CreateAnd(InfMask, ValInt);
1075 auto Cmp = Builder.CreateICmp(CmpInst::ICMP_EQ, InfBits, InfMask);
1076
1077 auto RetFalse = ConstantInt::get(RetTy, 0);
SJW2c317da2020-03-23 07:39:13 -05001078 Value *RetTrue = nullptr;
Kévin Petitfdfa92e2019-09-25 14:20:58 +01001079 if (ValTy->isVectorTy()) {
1080 RetTrue = ConstantInt::getSigned(RetTy, -1);
1081 } else {
1082 RetTrue = ConstantInt::get(RetTy, 1);
1083 }
1084 return Builder.CreateSelect(Cmp, RetFalse, RetTrue);
1085 });
1086}
1087
SJW2c317da2020-03-23 07:39:13 -05001088bool ReplaceOpenCLBuiltinPass::replaceAllAndAny(Function &F, spv::Op SPIRVOp) {
1089 Module &M = *F.getParent();
1090 return replaceCallsWithValue(F, [&](CallInst *CI) {
1091 auto Arg = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001092
SJW2c317da2020-03-23 07:39:13 -05001093 Value *V = nullptr;
Kévin Petitfd27cca2018-10-31 13:00:17 +00001094
SJW2c317da2020-03-23 07:39:13 -05001095 // If the argument is a 32-bit int, just use a shift
1096 if (Arg->getType() == Type::getInt32Ty(M.getContext())) {
1097 V = BinaryOperator::Create(Instruction::LShr, Arg,
1098 ConstantInt::get(Arg->getType(), 31), "", CI);
1099 } else {
1100 // The value for zero to compare against.
1101 const auto ZeroValue = Constant::getNullValue(Arg->getType());
David Neto22f144c2017-06-12 14:26:21 -04001102
SJW2c317da2020-03-23 07:39:13 -05001103 // The value to return for true.
1104 const auto TrueValue = ConstantInt::get(CI->getType(), 1);
David Neto22f144c2017-06-12 14:26:21 -04001105
SJW2c317da2020-03-23 07:39:13 -05001106 // The value to return for false.
1107 const auto FalseValue = Constant::getNullValue(CI->getType());
David Neto22f144c2017-06-12 14:26:21 -04001108
SJW2c317da2020-03-23 07:39:13 -05001109 const auto Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT,
1110 Arg, ZeroValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001111
SJW2c317da2020-03-23 07:39:13 -05001112 Value *SelectSource = nullptr;
David Neto22f144c2017-06-12 14:26:21 -04001113
SJW2c317da2020-03-23 07:39:13 -05001114 // If we have a function to call, call it!
1115 if (SPIRVOp != spv::OpNop) {
David Neto22f144c2017-06-12 14:26:21 -04001116
SJW2c317da2020-03-23 07:39:13 -05001117 const auto BoolTy = Type::getInt1Ty(M.getContext());
David Neto22f144c2017-06-12 14:26:21 -04001118
SJW2c317da2020-03-23 07:39:13 -05001119 const auto NewCI = clspv::InsertSPIRVOp(
1120 CI, SPIRVOp, {Attribute::ReadNone}, BoolTy, {Cmp});
1121 SelectSource = NewCI;
David Neto22f144c2017-06-12 14:26:21 -04001122
SJW2c317da2020-03-23 07:39:13 -05001123 } else {
1124 SelectSource = Cmp;
David Neto22f144c2017-06-12 14:26:21 -04001125 }
1126
SJW2c317da2020-03-23 07:39:13 -05001127 V = SelectInst::Create(SelectSource, TrueValue, FalseValue, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001128 }
SJW2c317da2020-03-23 07:39:13 -05001129 return V;
1130 });
David Neto22f144c2017-06-12 14:26:21 -04001131}
1132
SJW2c317da2020-03-23 07:39:13 -05001133bool ReplaceOpenCLBuiltinPass::replaceUpsample(Function &F) {
1134 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1135 // Get arguments
1136 auto HiValue = CI->getOperand(0);
1137 auto LoValue = CI->getOperand(1);
Kévin Petitbf0036c2019-03-06 13:57:10 +00001138
SJW2c317da2020-03-23 07:39:13 -05001139 // Don't touch overloads that aren't in OpenCL C
1140 auto HiType = HiValue->getType();
1141 auto LoType = LoValue->getType();
1142
1143 if (HiType != LoType) {
1144 return nullptr;
Kévin Petitbf0036c2019-03-06 13:57:10 +00001145 }
Kévin Petitbf0036c2019-03-06 13:57:10 +00001146
SJW2c317da2020-03-23 07:39:13 -05001147 if (!HiType->isIntOrIntVectorTy()) {
1148 return nullptr;
Kévin Petitbf0036c2019-03-06 13:57:10 +00001149 }
Kévin Petitbf0036c2019-03-06 13:57:10 +00001150
SJW2c317da2020-03-23 07:39:13 -05001151 if (HiType->getScalarSizeInBits() * 2 !=
1152 CI->getType()->getScalarSizeInBits()) {
1153 return nullptr;
1154 }
1155
1156 if ((HiType->getScalarSizeInBits() != 8) &&
1157 (HiType->getScalarSizeInBits() != 16) &&
1158 (HiType->getScalarSizeInBits() != 32)) {
1159 return nullptr;
1160 }
1161
James Pricecf53df42020-04-20 14:41:24 -04001162 if (auto HiVecType = dyn_cast<VectorType>(HiType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001163 unsigned NumElements = HiVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001164 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1165 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001166 return nullptr;
1167 }
1168 }
1169
1170 // Convert both operands to the result type
1171 auto HiCast = CastInst::CreateZExtOrBitCast(HiValue, CI->getType(), "", CI);
1172 auto LoCast = CastInst::CreateZExtOrBitCast(LoValue, CI->getType(), "", CI);
1173
1174 // Shift high operand
1175 auto ShiftAmount =
1176 ConstantInt::get(CI->getType(), HiType->getScalarSizeInBits());
1177 auto HiShifted =
1178 BinaryOperator::Create(Instruction::Shl, HiCast, ShiftAmount, "", CI);
1179
1180 // OR both results
1181 return BinaryOperator::Create(Instruction::Or, HiShifted, LoCast, "", CI);
1182 });
Kévin Petitbf0036c2019-03-06 13:57:10 +00001183}
1184
SJW2c317da2020-03-23 07:39:13 -05001185bool ReplaceOpenCLBuiltinPass::replaceRotate(Function &F) {
1186 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1187 // Get arguments
1188 auto SrcValue = CI->getOperand(0);
1189 auto RotAmount = CI->getOperand(1);
Kévin Petitd44eef52019-03-08 13:22:14 +00001190
SJW2c317da2020-03-23 07:39:13 -05001191 // Don't touch overloads that aren't in OpenCL C
1192 auto SrcType = SrcValue->getType();
1193 auto RotType = RotAmount->getType();
1194
1195 if ((SrcType != RotType) || (CI->getType() != SrcType)) {
1196 return nullptr;
Kévin Petitd44eef52019-03-08 13:22:14 +00001197 }
Kévin Petitd44eef52019-03-08 13:22:14 +00001198
SJW2c317da2020-03-23 07:39:13 -05001199 if (!SrcType->isIntOrIntVectorTy()) {
1200 return nullptr;
Kévin Petitd44eef52019-03-08 13:22:14 +00001201 }
Kévin Petitd44eef52019-03-08 13:22:14 +00001202
SJW2c317da2020-03-23 07:39:13 -05001203 if ((SrcType->getScalarSizeInBits() != 8) &&
1204 (SrcType->getScalarSizeInBits() != 16) &&
1205 (SrcType->getScalarSizeInBits() != 32) &&
1206 (SrcType->getScalarSizeInBits() != 64)) {
1207 return nullptr;
1208 }
1209
James Pricecf53df42020-04-20 14:41:24 -04001210 if (auto SrcVecType = dyn_cast<VectorType>(SrcType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001211 unsigned NumElements = SrcVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001212 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1213 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001214 return nullptr;
1215 }
1216 }
1217
alan-bakerfd22ae12020-10-29 15:59:22 -04001218 // Replace with LLVM's funnel shift left intrinsic because it is more
1219 // generic than rotate.
1220 Function *intrinsic =
1221 Intrinsic::getDeclaration(F.getParent(), Intrinsic::fshl, SrcType);
1222 return CallInst::Create(intrinsic->getFunctionType(), intrinsic,
1223 {SrcValue, SrcValue, RotAmount}, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001224 });
Kévin Petitd44eef52019-03-08 13:22:14 +00001225}
1226
SJW2c317da2020-03-23 07:39:13 -05001227bool ReplaceOpenCLBuiltinPass::replaceConvert(Function &F, bool SrcIsSigned,
1228 bool DstIsSigned) {
1229 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1230 Value *V = nullptr;
1231 // Get arguments
1232 auto SrcValue = CI->getOperand(0);
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001233
SJW2c317da2020-03-23 07:39:13 -05001234 // Don't touch overloads that aren't in OpenCL C
1235 auto SrcType = SrcValue->getType();
1236 auto DstType = CI->getType();
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001237
SJW2c317da2020-03-23 07:39:13 -05001238 if ((SrcType->isVectorTy() && !DstType->isVectorTy()) ||
1239 (!SrcType->isVectorTy() && DstType->isVectorTy())) {
1240 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001241 }
1242
James Pricecf53df42020-04-20 14:41:24 -04001243 if (auto SrcVecType = dyn_cast<VectorType>(SrcType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001244 unsigned SrcNumElements =
1245 SrcVecType->getElementCount().getKnownMinValue();
1246 unsigned DstNumElements =
1247 cast<VectorType>(DstType)->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001248 if (SrcNumElements != DstNumElements) {
SJW2c317da2020-03-23 07:39:13 -05001249 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001250 }
1251
James Pricecf53df42020-04-20 14:41:24 -04001252 if ((SrcNumElements != 2) && (SrcNumElements != 3) &&
1253 (SrcNumElements != 4) && (SrcNumElements != 8) &&
1254 (SrcNumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001255 return V;
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001256 }
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001257 }
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001258
SJW2c317da2020-03-23 07:39:13 -05001259 bool SrcIsFloat = SrcType->getScalarType()->isFloatingPointTy();
1260 bool DstIsFloat = DstType->getScalarType()->isFloatingPointTy();
1261
1262 bool SrcIsInt = SrcType->isIntOrIntVectorTy();
1263 bool DstIsInt = DstType->isIntOrIntVectorTy();
1264
1265 if (SrcType == DstType && DstIsSigned == SrcIsSigned) {
1266 // Unnecessary cast operation.
1267 V = SrcValue;
1268 } else if (SrcIsFloat && DstIsFloat) {
1269 V = CastInst::CreateFPCast(SrcValue, DstType, "", CI);
1270 } else if (SrcIsFloat && DstIsInt) {
1271 if (DstIsSigned) {
1272 V = CastInst::Create(Instruction::FPToSI, SrcValue, DstType, "", CI);
1273 } else {
1274 V = CastInst::Create(Instruction::FPToUI, SrcValue, DstType, "", CI);
1275 }
1276 } else if (SrcIsInt && DstIsFloat) {
1277 if (SrcIsSigned) {
1278 V = CastInst::Create(Instruction::SIToFP, SrcValue, DstType, "", CI);
1279 } else {
1280 V = CastInst::Create(Instruction::UIToFP, SrcValue, DstType, "", CI);
1281 }
1282 } else if (SrcIsInt && DstIsInt) {
1283 V = CastInst::CreateIntegerCast(SrcValue, DstType, SrcIsSigned, "", CI);
1284 } else {
1285 // Not something we're supposed to handle, just move on
1286 }
1287
1288 return V;
1289 });
Kévin Petit9d1a9d12019-03-25 15:23:46 +00001290}
1291
SJW2c317da2020-03-23 07:39:13 -05001292bool ReplaceOpenCLBuiltinPass::replaceMulHi(Function &F, bool is_signed,
1293 bool is_mad) {
1294 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1295 Value *V = nullptr;
1296 // Get arguments
1297 auto AValue = CI->getOperand(0);
1298 auto BValue = CI->getOperand(1);
1299 auto CValue = CI->getOperand(2);
Kévin Petit8a560882019-03-21 15:24:34 +00001300
SJW2c317da2020-03-23 07:39:13 -05001301 // Don't touch overloads that aren't in OpenCL C
1302 auto AType = AValue->getType();
1303 auto BType = BValue->getType();
1304 auto CType = CValue->getType();
Kévin Petit8a560882019-03-21 15:24:34 +00001305
SJW2c317da2020-03-23 07:39:13 -05001306 if ((AType != BType) || (CI->getType() != AType) ||
1307 (is_mad && (AType != CType))) {
1308 return V;
Kévin Petit8a560882019-03-21 15:24:34 +00001309 }
1310
SJW2c317da2020-03-23 07:39:13 -05001311 if (!AType->isIntOrIntVectorTy()) {
1312 return V;
Kévin Petit8a560882019-03-21 15:24:34 +00001313 }
Kévin Petit8a560882019-03-21 15:24:34 +00001314
SJW2c317da2020-03-23 07:39:13 -05001315 if ((AType->getScalarSizeInBits() != 8) &&
1316 (AType->getScalarSizeInBits() != 16) &&
1317 (AType->getScalarSizeInBits() != 32) &&
1318 (AType->getScalarSizeInBits() != 64)) {
1319 return V;
1320 }
Kévin Petit617a76d2019-04-04 13:54:16 +01001321
James Pricecf53df42020-04-20 14:41:24 -04001322 if (auto AVecType = dyn_cast<VectorType>(AType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001323 unsigned NumElements = AVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001324 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1325 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001326 return V;
Kévin Petit617a76d2019-04-04 13:54:16 +01001327 }
1328 }
1329
SJW2c317da2020-03-23 07:39:13 -05001330 // Our SPIR-V op returns a struct, create a type for it
alan-baker6b9d1ee2020-11-03 23:11:32 -05001331 auto ExMulRetType = GetPairStruct(AType);
Kévin Petit617a76d2019-04-04 13:54:16 +01001332
SJW2c317da2020-03-23 07:39:13 -05001333 // Select the appropriate signed/unsigned SPIR-V op
1334 spv::Op opcode = is_signed ? spv::OpSMulExtended : spv::OpUMulExtended;
1335
1336 // Call the SPIR-V op
1337 auto Call = clspv::InsertSPIRVOp(CI, opcode, {Attribute::ReadNone},
1338 ExMulRetType, {AValue, BValue});
1339
1340 // Get the high part of the result
1341 unsigned Idxs[] = {1};
1342 V = ExtractValueInst::Create(Call, Idxs, "", CI);
1343
1344 // If we're handling a mad_hi, add the third argument to the result
1345 if (is_mad) {
1346 V = BinaryOperator::Create(Instruction::Add, V, CValue, "", CI);
Kévin Petit617a76d2019-04-04 13:54:16 +01001347 }
1348
SJW2c317da2020-03-23 07:39:13 -05001349 return V;
1350 });
Kévin Petit8a560882019-03-21 15:24:34 +00001351}
1352
SJW2c317da2020-03-23 07:39:13 -05001353bool ReplaceOpenCLBuiltinPass::replaceSelect(Function &F) {
1354 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1355 // Get arguments
1356 auto FalseValue = CI->getOperand(0);
1357 auto TrueValue = CI->getOperand(1);
1358 auto PredicateValue = CI->getOperand(2);
Kévin Petitf5b78a22018-10-25 14:32:17 +00001359
SJW2c317da2020-03-23 07:39:13 -05001360 // Don't touch overloads that aren't in OpenCL C
1361 auto FalseType = FalseValue->getType();
1362 auto TrueType = TrueValue->getType();
1363 auto PredicateType = PredicateValue->getType();
1364
1365 if (FalseType != TrueType) {
1366 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001367 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001368
SJW2c317da2020-03-23 07:39:13 -05001369 if (!PredicateType->isIntOrIntVectorTy()) {
1370 return nullptr;
1371 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001372
SJW2c317da2020-03-23 07:39:13 -05001373 if (!FalseType->isIntOrIntVectorTy() &&
1374 !FalseType->getScalarType()->isFloatingPointTy()) {
1375 return nullptr;
1376 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001377
SJW2c317da2020-03-23 07:39:13 -05001378 if (FalseType->isVectorTy() && !PredicateType->isVectorTy()) {
1379 return nullptr;
1380 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001381
SJW2c317da2020-03-23 07:39:13 -05001382 if (FalseType->getScalarSizeInBits() !=
1383 PredicateType->getScalarSizeInBits()) {
1384 return nullptr;
1385 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001386
James Pricecf53df42020-04-20 14:41:24 -04001387 if (auto FalseVecType = dyn_cast<VectorType>(FalseType)) {
alan-baker5a8c3be2020-09-09 13:44:26 -04001388 unsigned NumElements = FalseVecType->getElementCount().getKnownMinValue();
1389 if (NumElements != cast<VectorType>(PredicateType)
1390 ->getElementCount()
1391 .getKnownMinValue()) {
SJW2c317da2020-03-23 07:39:13 -05001392 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001393 }
1394
James Pricecf53df42020-04-20 14:41:24 -04001395 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1396 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001397 return nullptr;
Kévin Petitf5b78a22018-10-25 14:32:17 +00001398 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001399 }
Kévin Petitf5b78a22018-10-25 14:32:17 +00001400
SJW2c317da2020-03-23 07:39:13 -05001401 // Create constant
1402 const auto ZeroValue = Constant::getNullValue(PredicateType);
1403
1404 // Scalar and vector are to be treated differently
1405 CmpInst::Predicate Pred;
1406 if (PredicateType->isVectorTy()) {
1407 Pred = CmpInst::ICMP_SLT;
1408 } else {
1409 Pred = CmpInst::ICMP_NE;
1410 }
1411
1412 // Create comparison instruction
1413 auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
1414 ZeroValue, "", CI);
1415
1416 // Create select
1417 return SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
1418 });
Kévin Petitf5b78a22018-10-25 14:32:17 +00001419}
1420
SJW2c317da2020-03-23 07:39:13 -05001421bool ReplaceOpenCLBuiltinPass::replaceBitSelect(Function &F) {
1422 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1423 Value *V = nullptr;
1424 if (CI->getNumOperands() != 4) {
1425 return V;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001426 }
Kévin Petite7d0cce2018-10-31 12:38:56 +00001427
SJW2c317da2020-03-23 07:39:13 -05001428 // Get arguments
1429 auto FalseValue = CI->getOperand(0);
1430 auto TrueValue = CI->getOperand(1);
1431 auto PredicateValue = CI->getOperand(2);
Kévin Petite7d0cce2018-10-31 12:38:56 +00001432
SJW2c317da2020-03-23 07:39:13 -05001433 // Don't touch overloads that aren't in OpenCL C
1434 auto FalseType = FalseValue->getType();
1435 auto TrueType = TrueValue->getType();
1436 auto PredicateType = PredicateValue->getType();
Kévin Petite7d0cce2018-10-31 12:38:56 +00001437
SJW2c317da2020-03-23 07:39:13 -05001438 if ((FalseType != TrueType) || (PredicateType != TrueType)) {
1439 return V;
Kévin Petite7d0cce2018-10-31 12:38:56 +00001440 }
Kévin Petite7d0cce2018-10-31 12:38:56 +00001441
James Pricecf53df42020-04-20 14:41:24 -04001442 if (auto TrueVecType = dyn_cast<VectorType>(TrueType)) {
SJW2c317da2020-03-23 07:39:13 -05001443 if (!TrueType->getScalarType()->isFloatingPointTy() &&
1444 !TrueType->getScalarType()->isIntegerTy()) {
1445 return V;
1446 }
alan-baker5a8c3be2020-09-09 13:44:26 -04001447 unsigned NumElements = TrueVecType->getElementCount().getKnownMinValue();
James Pricecf53df42020-04-20 14:41:24 -04001448 if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
1449 (NumElements != 8) && (NumElements != 16)) {
SJW2c317da2020-03-23 07:39:13 -05001450 return V;
1451 }
1452 }
1453
1454 // Remember the type of the operands
1455 auto OpType = TrueType;
1456
1457 // The actual bit selection will always be done on an integer type,
1458 // declare it here
1459 Type *BitType;
1460
1461 // If the operands are float, then bitcast them to int
1462 if (OpType->getScalarType()->isFloatingPointTy()) {
1463
1464 // First create the new type
1465 BitType = getIntOrIntVectorTyForCast(F.getContext(), OpType);
1466
1467 // Then bitcast all operands
1468 PredicateValue =
1469 CastInst::CreateZExtOrBitCast(PredicateValue, BitType, "", CI);
1470 FalseValue = CastInst::CreateZExtOrBitCast(FalseValue, BitType, "", CI);
1471 TrueValue = CastInst::CreateZExtOrBitCast(TrueValue, BitType, "", CI);
1472
1473 } else {
1474 // The operands have an integer type, use it directly
1475 BitType = OpType;
1476 }
1477
1478 // All the operands are now always integers
1479 // implement as (c & b) | (~c & a)
1480
1481 // Create our negated predicate value
1482 auto AllOnes = Constant::getAllOnesValue(BitType);
1483 auto NotPredicateValue = BinaryOperator::Create(
1484 Instruction::Xor, PredicateValue, AllOnes, "", CI);
1485
1486 // Then put everything together
1487 auto BitsFalse = BinaryOperator::Create(Instruction::And, NotPredicateValue,
1488 FalseValue, "", CI);
1489 auto BitsTrue = BinaryOperator::Create(Instruction::And, PredicateValue,
1490 TrueValue, "", CI);
1491
1492 V = BinaryOperator::Create(Instruction::Or, BitsFalse, BitsTrue, "", CI);
1493
1494 // If we were dealing with a floating point type, we must bitcast
1495 // the result back to that
1496 if (OpType->getScalarType()->isFloatingPointTy()) {
1497 V = CastInst::CreateZExtOrBitCast(V, OpType, "", CI);
1498 }
1499
1500 return V;
1501 });
Kévin Petite7d0cce2018-10-31 12:38:56 +00001502}
1503
SJW61531372020-06-09 07:31:08 -05001504bool ReplaceOpenCLBuiltinPass::replaceStep(Function &F, bool is_smooth) {
SJW2c317da2020-03-23 07:39:13 -05001505 // convert to vector versions
1506 Module &M = *F.getParent();
1507 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1508 SmallVector<Value *, 2> ArgsToSplat = {CI->getOperand(0)};
1509 Value *VectorArg = nullptr;
Kévin Petit6b0a9532018-10-30 20:00:39 +00001510
SJW2c317da2020-03-23 07:39:13 -05001511 // First figure out which function we're dealing with
1512 if (is_smooth) {
1513 ArgsToSplat.push_back(CI->getOperand(1));
1514 VectorArg = CI->getOperand(2);
1515 } else {
1516 VectorArg = CI->getOperand(1);
1517 }
1518
1519 // Splat arguments that need to be
1520 SmallVector<Value *, 2> SplatArgs;
James Pricecf53df42020-04-20 14:41:24 -04001521 auto VecType = cast<VectorType>(VectorArg->getType());
SJW2c317da2020-03-23 07:39:13 -05001522
1523 for (auto arg : ArgsToSplat) {
1524 Value *NewVectorArg = UndefValue::get(VecType);
alan-baker5a8c3be2020-09-09 13:44:26 -04001525 for (auto i = 0; i < VecType->getElementCount().getKnownMinValue(); i++) {
SJW2c317da2020-03-23 07:39:13 -05001526 auto index = ConstantInt::get(Type::getInt32Ty(M.getContext()), i);
1527 NewVectorArg =
1528 InsertElementInst::Create(NewVectorArg, arg, index, "", CI);
1529 }
1530 SplatArgs.push_back(NewVectorArg);
1531 }
1532
1533 // Replace the call with the vector/vector flavour
1534 SmallVector<Type *, 3> NewArgTypes(ArgsToSplat.size() + 1, VecType);
1535 const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
1536
SJW61531372020-06-09 07:31:08 -05001537 std::string NewFName = Builtins::GetMangledFunctionName(
1538 is_smooth ? "smoothstep" : "step", NewFType);
1539
SJW2c317da2020-03-23 07:39:13 -05001540 const auto NewF = M.getOrInsertFunction(NewFName, NewFType);
1541
1542 SmallVector<Value *, 3> NewArgs;
1543 for (auto arg : SplatArgs) {
1544 NewArgs.push_back(arg);
1545 }
1546 NewArgs.push_back(VectorArg);
1547
1548 return CallInst::Create(NewF, NewArgs, "", CI);
1549 });
Kévin Petit6b0a9532018-10-30 20:00:39 +00001550}
1551
SJW2c317da2020-03-23 07:39:13 -05001552bool ReplaceOpenCLBuiltinPass::replaceSignbit(Function &F, bool is_vec) {
SJW2c317da2020-03-23 07:39:13 -05001553 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1554 auto Arg = CI->getOperand(0);
1555 auto Op = is_vec ? Instruction::AShr : Instruction::LShr;
David Neto22f144c2017-06-12 14:26:21 -04001556
SJW2c317da2020-03-23 07:39:13 -05001557 auto Bitcast = CastInst::CreateZExtOrBitCast(Arg, CI->getType(), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001558
SJW2c317da2020-03-23 07:39:13 -05001559 return BinaryOperator::Create(Op, Bitcast,
1560 ConstantInt::get(CI->getType(), 31), "", CI);
1561 });
David Neto22f144c2017-06-12 14:26:21 -04001562}
1563
SJW2c317da2020-03-23 07:39:13 -05001564bool ReplaceOpenCLBuiltinPass::replaceMul(Function &F, bool is_float,
1565 bool is_mad) {
SJW2c317da2020-03-23 07:39:13 -05001566 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1567 // The multiply instruction to use.
1568 auto MulInst = is_float ? Instruction::FMul : Instruction::Mul;
David Neto22f144c2017-06-12 14:26:21 -04001569
SJW2c317da2020-03-23 07:39:13 -05001570 SmallVector<Value *, 8> Args(CI->arg_begin(), CI->arg_end());
David Neto22f144c2017-06-12 14:26:21 -04001571
SJW2c317da2020-03-23 07:39:13 -05001572 Value *V = BinaryOperator::Create(MulInst, CI->getArgOperand(0),
1573 CI->getArgOperand(1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001574
SJW2c317da2020-03-23 07:39:13 -05001575 if (is_mad) {
1576 // The add instruction to use.
1577 auto AddInst = is_float ? Instruction::FAdd : Instruction::Add;
David Neto22f144c2017-06-12 14:26:21 -04001578
SJW2c317da2020-03-23 07:39:13 -05001579 V = BinaryOperator::Create(AddInst, V, CI->getArgOperand(2), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001580 }
David Neto22f144c2017-06-12 14:26:21 -04001581
SJW2c317da2020-03-23 07:39:13 -05001582 return V;
1583 });
David Neto22f144c2017-06-12 14:26:21 -04001584}
1585
SJW2c317da2020-03-23 07:39:13 -05001586bool ReplaceOpenCLBuiltinPass::replaceVstore(Function &F) {
SJW2c317da2020-03-23 07:39:13 -05001587 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1588 Value *V = nullptr;
1589 auto data = CI->getOperand(0);
Derek Chowcfd368b2017-10-19 20:58:45 -07001590
SJW2c317da2020-03-23 07:39:13 -05001591 auto data_type = data->getType();
1592 if (!data_type->isVectorTy())
1593 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001594
James Pricecf53df42020-04-20 14:41:24 -04001595 auto vec_data_type = cast<VectorType>(data_type);
1596
alan-baker5a8c3be2020-09-09 13:44:26 -04001597 auto elems = vec_data_type->getElementCount().getKnownMinValue();
SJW2c317da2020-03-23 07:39:13 -05001598 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 && elems != 16)
1599 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001600
SJW2c317da2020-03-23 07:39:13 -05001601 auto offset = CI->getOperand(1);
1602 auto ptr = CI->getOperand(2);
1603 auto ptr_type = ptr->getType();
1604 auto pointee_type = ptr_type->getPointerElementType();
James Pricecf53df42020-04-20 14:41:24 -04001605 if (pointee_type != vec_data_type->getElementType())
SJW2c317da2020-03-23 07:39:13 -05001606 return V;
alan-bakerf795f392019-06-11 18:24:34 -04001607
SJW2c317da2020-03-23 07:39:13 -05001608 // Avoid pointer casts. Instead generate the correct number of stores
1609 // and rely on drivers to coalesce appropriately.
1610 IRBuilder<> builder(CI);
1611 auto elems_const = builder.getInt32(elems);
1612 auto adjust = builder.CreateMul(offset, elems_const);
1613 for (auto i = 0; i < elems; ++i) {
1614 auto idx = builder.getInt32(i);
1615 auto add = builder.CreateAdd(adjust, idx);
1616 auto gep = builder.CreateGEP(ptr, add);
1617 auto extract = builder.CreateExtractElement(data, i);
1618 V = builder.CreateStore(extract, gep);
Derek Chowcfd368b2017-10-19 20:58:45 -07001619 }
SJW2c317da2020-03-23 07:39:13 -05001620 return V;
1621 });
Derek Chowcfd368b2017-10-19 20:58:45 -07001622}
1623
SJW2c317da2020-03-23 07:39:13 -05001624bool ReplaceOpenCLBuiltinPass::replaceVload(Function &F) {
SJW2c317da2020-03-23 07:39:13 -05001625 return replaceCallsWithValue(F, [&](CallInst *CI) -> llvm::Value * {
1626 Value *V = nullptr;
1627 auto ret_type = F.getReturnType();
1628 if (!ret_type->isVectorTy())
1629 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001630
James Pricecf53df42020-04-20 14:41:24 -04001631 auto vec_ret_type = cast<VectorType>(ret_type);
1632
alan-baker5a8c3be2020-09-09 13:44:26 -04001633 auto elems = vec_ret_type->getElementCount().getKnownMinValue();
SJW2c317da2020-03-23 07:39:13 -05001634 if (elems != 2 && elems != 3 && elems != 4 && elems != 8 && elems != 16)
1635 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001636
SJW2c317da2020-03-23 07:39:13 -05001637 auto offset = CI->getOperand(0);
1638 auto ptr = CI->getOperand(1);
1639 auto ptr_type = ptr->getType();
1640 auto pointee_type = ptr_type->getPointerElementType();
James Pricecf53df42020-04-20 14:41:24 -04001641 if (pointee_type != vec_ret_type->getElementType())
SJW2c317da2020-03-23 07:39:13 -05001642 return V;
Derek Chowcfd368b2017-10-19 20:58:45 -07001643
SJW2c317da2020-03-23 07:39:13 -05001644 // Avoid pointer casts. Instead generate the correct number of loads
1645 // and rely on drivers to coalesce appropriately.
1646 IRBuilder<> builder(CI);
1647 auto elems_const = builder.getInt32(elems);
1648 V = UndefValue::get(ret_type);
1649 auto adjust = builder.CreateMul(offset, elems_const);
1650 for (auto i = 0; i < elems; ++i) {
1651 auto idx = builder.getInt32(i);
1652 auto add = builder.CreateAdd(adjust, idx);
1653 auto gep = builder.CreateGEP(ptr, add);
1654 auto load = builder.CreateLoad(gep);
1655 V = builder.CreateInsertElement(V, load, i);
Derek Chowcfd368b2017-10-19 20:58:45 -07001656 }
SJW2c317da2020-03-23 07:39:13 -05001657 return V;
1658 });
Derek Chowcfd368b2017-10-19 20:58:45 -07001659}
1660
SJW2c317da2020-03-23 07:39:13 -05001661bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Function &F,
1662 const std::string &name,
1663 int vec_size) {
1664 bool is_clspv_version = !name.compare(0, 8, "__clspv_");
1665 if (!vec_size) {
1666 // deduce vec_size from last character of name (e.g. vload_half4)
1667 vec_size = std::atoi(&name.back());
David Neto22f144c2017-06-12 14:26:21 -04001668 }
SJW2c317da2020-03-23 07:39:13 -05001669 switch (vec_size) {
1670 case 2:
1671 return is_clspv_version ? replaceClspvVloadaHalf2(F) : replaceVloadHalf2(F);
1672 case 4:
1673 return is_clspv_version ? replaceClspvVloadaHalf4(F) : replaceVloadHalf4(F);
1674 case 0:
1675 if (!is_clspv_version) {
1676 return replaceVloadHalf(F);
1677 }
1678 default:
1679 llvm_unreachable("Unsupported vload_half vector size");
1680 break;
1681 }
1682 return false;
David Neto22f144c2017-06-12 14:26:21 -04001683}
1684
SJW2c317da2020-03-23 07:39:13 -05001685bool ReplaceOpenCLBuiltinPass::replaceVloadHalf(Function &F) {
1686 Module &M = *F.getParent();
1687 return replaceCallsWithValue(F, [&](CallInst *CI) {
1688 // The index argument from vload_half.
1689 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001690
SJW2c317da2020-03-23 07:39:13 -05001691 // The pointer argument from vload_half.
1692 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001693
SJW2c317da2020-03-23 07:39:13 -05001694 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001695 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
SJW2c317da2020-03-23 07:39:13 -05001696 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
1697
1698 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001699 auto SPIRVIntrinsic = clspv::UnpackFunction();
SJW2c317da2020-03-23 07:39:13 -05001700
1701 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
1702
1703 Value *V = nullptr;
1704
alan-baker7efcaaa2020-05-06 19:33:27 -04001705 bool supports_16bit_storage = true;
1706 switch (Arg1->getType()->getPointerAddressSpace()) {
1707 case clspv::AddressSpace::Global:
1708 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1709 clspv::Option::StorageClass::kSSBO);
1710 break;
1711 case clspv::AddressSpace::Constant:
1712 if (clspv::Option::ConstantArgsInUniformBuffer())
1713 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1714 clspv::Option::StorageClass::kUBO);
1715 else
1716 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
1717 clspv::Option::StorageClass::kSSBO);
1718 break;
1719 default:
1720 // Clspv will emit the Float16 capability if the half type is
1721 // encountered. That capability covers private and local addressspaces.
1722 break;
1723 }
1724
1725 if (supports_16bit_storage) {
SJW2c317da2020-03-23 07:39:13 -05001726 auto ShortTy = Type::getInt16Ty(M.getContext());
1727 auto ShortPointerTy =
1728 PointerType::get(ShortTy, Arg1->getType()->getPointerAddressSpace());
1729
1730 // Cast the half* pointer to short*.
1731 auto Cast = CastInst::CreatePointerCast(Arg1, ShortPointerTy, "", CI);
1732
1733 // Index into the correct address of the casted pointer.
1734 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg0, "", CI);
1735
1736 // Load from the short* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001737 auto Load = new LoadInst(ShortTy, Index, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001738
1739 // ZExt the short -> int.
1740 auto ZExt = CastInst::CreateZExtOrBitCast(Load, IntTy, "", CI);
1741
1742 // Get our float2.
1743 auto Call = CallInst::Create(NewF, ZExt, "", CI);
1744
1745 // Extract out the bottom element which is our float result.
1746 V = ExtractElementInst::Create(Call, ConstantInt::get(IntTy, 0), "", CI);
1747 } else {
1748 // Assume the pointer argument points to storage aligned to 32bits
1749 // or more.
1750 // TODO(dneto): Do more analysis to make sure this is true?
1751 //
1752 // Replace call vstore_half(i32 %index, half addrspace(1) %base)
1753 // with:
1754 //
1755 // %base_i32_ptr = bitcast half addrspace(1)* %base to i32
1756 // addrspace(1)* %index_is_odd32 = and i32 %index, 1 %index_i32 =
1757 // lshr i32 %index, 1 %in_ptr = getlementptr i32, i32
1758 // addrspace(1)* %base_i32_ptr, %index_i32 %value_i32 = load i32,
1759 // i32 addrspace(1)* %in_ptr %converted = call <2 x float>
1760 // @spirv.unpack.v2f16(i32 %value_i32) %value = extractelement <2
1761 // x float> %converted, %index_is_odd32
1762
1763 auto IntPointerTy =
1764 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
1765
1766 // Cast the base pointer to int*.
1767 // In a valid call (according to assumptions), this should get
1768 // optimized away in the simplify GEP pass.
1769 auto Cast = CastInst::CreatePointerCast(Arg1, IntPointerTy, "", CI);
1770
1771 auto One = ConstantInt::get(IntTy, 1);
1772 auto IndexIsOdd = BinaryOperator::CreateAnd(Arg0, One, "", CI);
1773 auto IndexIntoI32 = BinaryOperator::CreateLShr(Arg0, One, "", CI);
1774
1775 // Index into the correct address of the casted pointer.
1776 auto Ptr = GetElementPtrInst::Create(IntTy, Cast, IndexIntoI32, "", CI);
1777
1778 // Load from the int* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001779 auto Load = new LoadInst(IntTy, Ptr, "", CI);
SJW2c317da2020-03-23 07:39:13 -05001780
1781 // Get our float2.
1782 auto Call = CallInst::Create(NewF, Load, "", CI);
1783
1784 // Extract out the float result, where the element number is
1785 // determined by whether the original index was even or odd.
1786 V = ExtractElementInst::Create(Call, IndexIsOdd, "", CI);
1787 }
1788 return V;
1789 });
1790}
1791
1792bool ReplaceOpenCLBuiltinPass::replaceVloadHalf2(Function &F) {
1793 Module &M = *F.getParent();
1794 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001795 // The index argument from vload_half.
1796 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001797
Kévin Petite8edce32019-04-10 14:23:32 +01001798 // The pointer argument from vload_half.
1799 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001800
Kévin Petite8edce32019-04-10 14:23:32 +01001801 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001802 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001803 auto NewPointerTy =
1804 PointerType::get(IntTy, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01001805 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04001806
Kévin Petite8edce32019-04-10 14:23:32 +01001807 // Cast the half* pointer to int*.
1808 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001809
Kévin Petite8edce32019-04-10 14:23:32 +01001810 // Index into the correct address of the casted pointer.
1811 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001812
Kévin Petite8edce32019-04-10 14:23:32 +01001813 // Load from the int* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001814 auto Load = new LoadInst(IntTy, Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001815
Kévin Petite8edce32019-04-10 14:23:32 +01001816 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001817 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto22f144c2017-06-12 14:26:21 -04001818
Kévin Petite8edce32019-04-10 14:23:32 +01001819 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04001820
Kévin Petite8edce32019-04-10 14:23:32 +01001821 // Get our float2.
1822 return CallInst::Create(NewF, Load, "", CI);
1823 });
David Neto22f144c2017-06-12 14:26:21 -04001824}
1825
SJW2c317da2020-03-23 07:39:13 -05001826bool ReplaceOpenCLBuiltinPass::replaceVloadHalf4(Function &F) {
1827 Module &M = *F.getParent();
1828 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001829 // The index argument from vload_half.
1830 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001831
Kévin Petite8edce32019-04-10 14:23:32 +01001832 // The pointer argument from vload_half.
1833 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001834
Kévin Petite8edce32019-04-10 14:23:32 +01001835 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001836 auto Int2Ty = FixedVectorType::get(IntTy, 2);
1837 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001838 auto NewPointerTy =
1839 PointerType::get(Int2Ty, Arg1->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01001840 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto22f144c2017-06-12 14:26:21 -04001841
Kévin Petite8edce32019-04-10 14:23:32 +01001842 // Cast the half* pointer to int2*.
1843 auto Cast = CastInst::CreatePointerCast(Arg1, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001844
Kévin Petite8edce32019-04-10 14:23:32 +01001845 // Index into the correct address of the casted pointer.
1846 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001847
Kévin Petite8edce32019-04-10 14:23:32 +01001848 // Load from the int2* we casted to.
alan-baker741fd1f2020-04-14 17:38:15 -04001849 auto Load = new LoadInst(Int2Ty, Index, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001850
Kévin Petite8edce32019-04-10 14:23:32 +01001851 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001852 auto X =
1853 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
1854 auto Y =
1855 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001856
Kévin Petite8edce32019-04-10 14:23:32 +01001857 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001858 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto22f144c2017-06-12 14:26:21 -04001859
Kévin Petite8edce32019-04-10 14:23:32 +01001860 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04001861
Kévin Petite8edce32019-04-10 14:23:32 +01001862 // Get the lower (x & y) components of our final float4.
1863 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001864
Kévin Petite8edce32019-04-10 14:23:32 +01001865 // Get the higher (z & w) components of our final float4.
1866 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001867
Kévin Petite8edce32019-04-10 14:23:32 +01001868 Constant *ShuffleMask[4] = {
1869 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1870 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04001871
Kévin Petite8edce32019-04-10 14:23:32 +01001872 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001873 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
1874 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01001875 });
David Neto22f144c2017-06-12 14:26:21 -04001876}
1877
SJW2c317da2020-03-23 07:39:13 -05001878bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf2(Function &F) {
David Neto6ad93232018-06-07 15:42:58 -07001879
1880 // Replace __clspv_vloada_half2(uint Index, global uint* Ptr) with:
1881 //
1882 // %u = load i32 %ptr
1883 // %fxy = call <2 x float> Unpack2xHalf(u)
1884 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
SJW2c317da2020-03-23 07:39:13 -05001885 Module &M = *F.getParent();
1886 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001887 auto Index = CI->getOperand(0);
1888 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07001889
Kévin Petite8edce32019-04-10 14:23:32 +01001890 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001891 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001892 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07001893
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001894 auto IndexedPtr = GetElementPtrInst::Create(IntTy, Ptr, Index, "", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04001895 auto Load = new LoadInst(IntTy, IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001896
Kévin Petite8edce32019-04-10 14:23:32 +01001897 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001898 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto6ad93232018-06-07 15:42:58 -07001899
Kévin Petite8edce32019-04-10 14:23:32 +01001900 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07001901
Kévin Petite8edce32019-04-10 14:23:32 +01001902 // Get our final float2.
1903 return CallInst::Create(NewF, Load, "", CI);
1904 });
David Neto6ad93232018-06-07 15:42:58 -07001905}
1906
SJW2c317da2020-03-23 07:39:13 -05001907bool ReplaceOpenCLBuiltinPass::replaceClspvVloadaHalf4(Function &F) {
David Neto6ad93232018-06-07 15:42:58 -07001908
1909 // Replace __clspv_vloada_half4(uint Index, global uint2* Ptr) with:
1910 //
1911 // %u2 = load <2 x i32> %ptr
1912 // %u2xy = extractelement %u2, 0
1913 // %u2zw = extractelement %u2, 1
1914 // %fxy = call <2 x float> Unpack2xHalf(uint)
1915 // %fzw = call <2 x float> Unpack2xHalf(uint)
1916 // %result = shufflevector %fxy %fzw <4 x i32> <0, 1, 2, 3>
SJW2c317da2020-03-23 07:39:13 -05001917 Module &M = *F.getParent();
1918 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001919 auto Index = CI->getOperand(0);
1920 auto Ptr = CI->getOperand(1);
David Neto6ad93232018-06-07 15:42:58 -07001921
Kévin Petite8edce32019-04-10 14:23:32 +01001922 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001923 auto Int2Ty = FixedVectorType::get(IntTy, 2);
1924 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001925 auto NewFType = FunctionType::get(Float2Ty, IntTy, false);
David Neto6ad93232018-06-07 15:42:58 -07001926
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001927 auto IndexedPtr = GetElementPtrInst::Create(Int2Ty, Ptr, Index, "", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04001928 auto Load = new LoadInst(Int2Ty, IndexedPtr, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001929
Kévin Petite8edce32019-04-10 14:23:32 +01001930 // Extract each element from the loaded int2.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001931 auto X =
1932 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 0), "", CI);
1933 auto Y =
1934 ExtractElementInst::Create(Load, ConstantInt::get(IntTy, 1), "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001935
Kévin Petite8edce32019-04-10 14:23:32 +01001936 // Our intrinsic to unpack a float2 from an int.
SJW61531372020-06-09 07:31:08 -05001937 auto SPIRVIntrinsic = clspv::UnpackFunction();
David Neto6ad93232018-06-07 15:42:58 -07001938
Kévin Petite8edce32019-04-10 14:23:32 +01001939 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto6ad93232018-06-07 15:42:58 -07001940
Kévin Petite8edce32019-04-10 14:23:32 +01001941 // Get the lower (x & y) components of our final float4.
1942 auto Lo = CallInst::Create(NewF, X, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001943
Kévin Petite8edce32019-04-10 14:23:32 +01001944 // Get the higher (z & w) components of our final float4.
1945 auto Hi = CallInst::Create(NewF, Y, "", CI);
David Neto6ad93232018-06-07 15:42:58 -07001946
Kévin Petite8edce32019-04-10 14:23:32 +01001947 Constant *ShuffleMask[4] = {
1948 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
1949 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
David Neto6ad93232018-06-07 15:42:58 -07001950
Kévin Petite8edce32019-04-10 14:23:32 +01001951 // Combine our two float2's into one float4.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001952 return new ShuffleVectorInst(Lo, Hi, ConstantVector::get(ShuffleMask), "",
1953 CI);
Kévin Petite8edce32019-04-10 14:23:32 +01001954 });
David Neto6ad93232018-06-07 15:42:58 -07001955}
1956
SJW2c317da2020-03-23 07:39:13 -05001957bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Function &F, int vec_size) {
1958 switch (vec_size) {
1959 case 0:
1960 return replaceVstoreHalf(F);
1961 case 2:
1962 return replaceVstoreHalf2(F);
1963 case 4:
1964 return replaceVstoreHalf4(F);
1965 default:
1966 llvm_unreachable("Unsupported vstore_half vector size");
1967 break;
1968 }
1969 return false;
1970}
David Neto22f144c2017-06-12 14:26:21 -04001971
SJW2c317da2020-03-23 07:39:13 -05001972bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf(Function &F) {
1973 Module &M = *F.getParent();
1974 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01001975 // The value to store.
1976 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04001977
Kévin Petite8edce32019-04-10 14:23:32 +01001978 // The index argument from vstore_half.
1979 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04001980
Kévin Petite8edce32019-04-10 14:23:32 +01001981 // The pointer argument from vstore_half.
1982 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04001983
Kévin Petite8edce32019-04-10 14:23:32 +01001984 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04001985 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Kévin Petite8edce32019-04-10 14:23:32 +01001986 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
1987 auto One = ConstantInt::get(IntTy, 1);
David Neto22f144c2017-06-12 14:26:21 -04001988
Kévin Petite8edce32019-04-10 14:23:32 +01001989 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05001990 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04001991
Kévin Petite8edce32019-04-10 14:23:32 +01001992 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04001993
Kévin Petite8edce32019-04-10 14:23:32 +01001994 // Insert our value into a float2 so that we can pack it.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04001995 auto TempVec = InsertElementInst::Create(
1996 UndefValue::get(Float2Ty), Arg0, ConstantInt::get(IntTy, 0), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04001997
Kévin Petite8edce32019-04-10 14:23:32 +01001998 // Pack the float2 -> half2 (in an int).
1999 auto X = CallInst::Create(NewF, TempVec, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002000
alan-baker7efcaaa2020-05-06 19:33:27 -04002001 bool supports_16bit_storage = true;
2002 switch (Arg2->getType()->getPointerAddressSpace()) {
2003 case clspv::AddressSpace::Global:
2004 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2005 clspv::Option::StorageClass::kSSBO);
2006 break;
2007 case clspv::AddressSpace::Constant:
2008 if (clspv::Option::ConstantArgsInUniformBuffer())
2009 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2010 clspv::Option::StorageClass::kUBO);
2011 else
2012 supports_16bit_storage = clspv::Option::Supports16BitStorageClass(
2013 clspv::Option::StorageClass::kSSBO);
2014 break;
2015 default:
2016 // Clspv will emit the Float16 capability if the half type is
2017 // encountered. That capability covers private and local addressspaces.
2018 break;
2019 }
2020
SJW2c317da2020-03-23 07:39:13 -05002021 Value *V = nullptr;
alan-baker7efcaaa2020-05-06 19:33:27 -04002022 if (supports_16bit_storage) {
Kévin Petite8edce32019-04-10 14:23:32 +01002023 auto ShortTy = Type::getInt16Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002024 auto ShortPointerTy =
2025 PointerType::get(ShortTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002026
Kévin Petite8edce32019-04-10 14:23:32 +01002027 // Truncate our i32 to an i16.
2028 auto Trunc = CastInst::CreateTruncOrBitCast(X, ShortTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002029
Kévin Petite8edce32019-04-10 14:23:32 +01002030 // Cast the half* pointer to short*.
2031 auto Cast = CastInst::CreatePointerCast(Arg2, ShortPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002032
Kévin Petite8edce32019-04-10 14:23:32 +01002033 // Index into the correct address of the casted pointer.
2034 auto Index = GetElementPtrInst::Create(ShortTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002035
Kévin Petite8edce32019-04-10 14:23:32 +01002036 // Store to the int* we casted to.
SJW2c317da2020-03-23 07:39:13 -05002037 V = new StoreInst(Trunc, Index, CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002038 } else {
2039 // We can only write to 32-bit aligned words.
2040 //
2041 // Assuming base is aligned to 32-bits, replace the equivalent of
2042 // vstore_half(value, index, base)
2043 // with:
2044 // uint32_t* target_ptr = (uint32_t*)(base) + index / 2;
2045 // uint32_t write_to_upper_half = index & 1u;
2046 // uint32_t shift = write_to_upper_half << 4;
2047 //
2048 // // Pack the float value as a half number in bottom 16 bits
2049 // // of an i32.
2050 // uint32_t packed = spirv.pack.v2f16((float2)(value, undef));
2051 //
2052 // uint32_t xor_value = (*target_ptr & (0xffff << shift))
2053 // ^ ((packed & 0xffff) << shift)
2054 // // We only need relaxed consistency, but OpenCL 1.2 only has
2055 // // sequentially consistent atomics.
2056 // // TODO(dneto): Use relaxed consistency.
2057 // atomic_xor(target_ptr, xor_value)
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002058 auto IntPointerTy =
2059 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
David Neto22f144c2017-06-12 14:26:21 -04002060
Kévin Petite8edce32019-04-10 14:23:32 +01002061 auto Four = ConstantInt::get(IntTy, 4);
2062 auto FFFF = ConstantInt::get(IntTy, 0xffff);
David Neto17852de2017-05-29 17:29:31 -04002063
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002064 auto IndexIsOdd =
2065 BinaryOperator::CreateAnd(Arg1, One, "index_is_odd_i32", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002066 // Compute index / 2
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002067 auto IndexIntoI32 =
2068 BinaryOperator::CreateLShr(Arg1, One, "index_into_i32", CI);
2069 auto BaseI32Ptr =
2070 CastInst::CreatePointerCast(Arg2, IntPointerTy, "base_i32_ptr", CI);
2071 auto OutPtr = GetElementPtrInst::Create(IntTy, BaseI32Ptr, IndexIntoI32,
2072 "base_i32_ptr", CI);
alan-baker741fd1f2020-04-14 17:38:15 -04002073 auto CurrentValue = new LoadInst(IntTy, OutPtr, "current_value", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002074 auto Shift = BinaryOperator::CreateShl(IndexIsOdd, Four, "shift", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002075 auto MaskBitsToWrite =
2076 BinaryOperator::CreateShl(FFFF, Shift, "mask_bits_to_write", CI);
2077 auto MaskedCurrent = BinaryOperator::CreateAnd(
2078 MaskBitsToWrite, CurrentValue, "masked_current", CI);
David Neto17852de2017-05-29 17:29:31 -04002079
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002080 auto XLowerBits =
2081 BinaryOperator::CreateAnd(X, FFFF, "lower_bits_of_packed", CI);
2082 auto NewBitsToWrite =
2083 BinaryOperator::CreateShl(XLowerBits, Shift, "new_bits_to_write", CI);
2084 auto ValueToXor = BinaryOperator::CreateXor(MaskedCurrent, NewBitsToWrite,
2085 "value_to_xor", CI);
David Neto17852de2017-05-29 17:29:31 -04002086
      // Generate the call to atomic_xor.
2088 SmallVector<Type *, 5> ParamTypes;
2089 // The pointer type.
2090 ParamTypes.push_back(IntPointerTy);
2091 // The Types for memory scope, semantics, and value.
2092 ParamTypes.push_back(IntTy);
2093 ParamTypes.push_back(IntTy);
2094 ParamTypes.push_back(IntTy);
2095 auto NewFType = FunctionType::get(IntTy, ParamTypes, false);
2096 auto NewF = M.getOrInsertFunction("spirv.atomic_xor", NewFType);
David Neto17852de2017-05-29 17:29:31 -04002097
Kévin Petite8edce32019-04-10 14:23:32 +01002098 const auto ConstantScopeDevice =
2099 ConstantInt::get(IntTy, spv::ScopeDevice);
2100 // Assume the pointee is in OpenCL global (SPIR-V Uniform) or local
2101 // (SPIR-V Workgroup).
2102 const auto AddrSpaceSemanticsBits =
2103 IntPointerTy->getPointerAddressSpace() == 1
2104 ? spv::MemorySemanticsUniformMemoryMask
2105 : spv::MemorySemanticsWorkgroupMemoryMask;
David Neto17852de2017-05-29 17:29:31 -04002106
Kévin Petite8edce32019-04-10 14:23:32 +01002107 // We're using relaxed consistency here.
2108 const auto ConstantMemorySemantics =
2109 ConstantInt::get(IntTy, spv::MemorySemanticsUniformMemoryMask |
2110 AddrSpaceSemanticsBits);
David Neto17852de2017-05-29 17:29:31 -04002111
Kévin Petite8edce32019-04-10 14:23:32 +01002112 SmallVector<Value *, 5> Params{OutPtr, ConstantScopeDevice,
2113 ConstantMemorySemantics, ValueToXor};
2114 CallInst::Create(NewF, Params, "store_halfword_xor_trick", CI);
SJW2c317da2020-03-23 07:39:13 -05002115
2116 // Return a Nop so the old Call is removed
2117 Function *donothing = Intrinsic::getDeclaration(&M, Intrinsic::donothing);
2118 V = CallInst::Create(donothing, {}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002119 }
David Neto22f144c2017-06-12 14:26:21 -04002120
SJW2c317da2020-03-23 07:39:13 -05002121 return V;
Kévin Petite8edce32019-04-10 14:23:32 +01002122 });
David Neto22f144c2017-06-12 14:26:21 -04002123}
2124
SJW2c317da2020-03-23 07:39:13 -05002125bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf2(Function &F) {
2126 Module &M = *F.getParent();
2127 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002128 // The value to store.
2129 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002130
Kévin Petite8edce32019-04-10 14:23:32 +01002131 // The index argument from vstore_half.
2132 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002133
Kévin Petite8edce32019-04-10 14:23:32 +01002134 // The pointer argument from vstore_half.
2135 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002136
Kévin Petite8edce32019-04-10 14:23:32 +01002137 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04002138 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002139 auto NewPointerTy =
2140 PointerType::get(IntTy, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002141 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002142
Kévin Petite8edce32019-04-10 14:23:32 +01002143 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05002144 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04002145
Kévin Petite8edce32019-04-10 14:23:32 +01002146 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002147
Kévin Petite8edce32019-04-10 14:23:32 +01002148 // Turn the packed x & y into the final packing.
2149 auto X = CallInst::Create(NewF, Arg0, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002150
Kévin Petite8edce32019-04-10 14:23:32 +01002151 // Cast the half* pointer to int*.
2152 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002153
Kévin Petite8edce32019-04-10 14:23:32 +01002154 // Index into the correct address of the casted pointer.
2155 auto Index = GetElementPtrInst::Create(IntTy, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002156
Kévin Petite8edce32019-04-10 14:23:32 +01002157 // Store to the int* we casted to.
2158 return new StoreInst(X, Index, CI);
2159 });
David Neto22f144c2017-06-12 14:26:21 -04002160}
2161
SJW2c317da2020-03-23 07:39:13 -05002162bool ReplaceOpenCLBuiltinPass::replaceVstoreHalf4(Function &F) {
2163 Module &M = *F.getParent();
2164 return replaceCallsWithValue(F, [&](CallInst *CI) {
Kévin Petite8edce32019-04-10 14:23:32 +01002165 // The value to store.
2166 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002167
Kévin Petite8edce32019-04-10 14:23:32 +01002168 // The index argument from vstore_half.
2169 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002170
Kévin Petite8edce32019-04-10 14:23:32 +01002171 // The pointer argument from vstore_half.
2172 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002173
Kévin Petite8edce32019-04-10 14:23:32 +01002174 auto IntTy = Type::getInt32Ty(M.getContext());
alan-bakerb3e2b6d2020-06-24 23:59:57 -04002175 auto Int2Ty = FixedVectorType::get(IntTy, 2);
2176 auto Float2Ty = FixedVectorType::get(Type::getFloatTy(M.getContext()), 2);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002177 auto NewPointerTy =
2178 PointerType::get(Int2Ty, Arg2->getType()->getPointerAddressSpace());
Kévin Petite8edce32019-04-10 14:23:32 +01002179 auto NewFType = FunctionType::get(IntTy, Float2Ty, false);
David Neto22f144c2017-06-12 14:26:21 -04002180
Kévin Petite8edce32019-04-10 14:23:32 +01002181 Constant *LoShuffleMask[2] = {ConstantInt::get(IntTy, 0),
2182 ConstantInt::get(IntTy, 1)};
David Neto22f144c2017-06-12 14:26:21 -04002183
Kévin Petite8edce32019-04-10 14:23:32 +01002184 // Extract out the x & y components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002185 auto Lo = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2186 ConstantVector::get(LoShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002187
Kévin Petite8edce32019-04-10 14:23:32 +01002188 Constant *HiShuffleMask[2] = {ConstantInt::get(IntTy, 2),
2189 ConstantInt::get(IntTy, 3)};
David Neto22f144c2017-06-12 14:26:21 -04002190
Kévin Petite8edce32019-04-10 14:23:32 +01002191 // Extract out the z & w components of our to store value.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002192 auto Hi = new ShuffleVectorInst(Arg0, UndefValue::get(Arg0->getType()),
2193 ConstantVector::get(HiShuffleMask), "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002194
Kévin Petite8edce32019-04-10 14:23:32 +01002195 // Our intrinsic to pack a float2 to an int.
SJW61531372020-06-09 07:31:08 -05002196 auto SPIRVIntrinsic = clspv::PackFunction();
David Neto22f144c2017-06-12 14:26:21 -04002197
Kévin Petite8edce32019-04-10 14:23:32 +01002198 auto NewF = M.getOrInsertFunction(SPIRVIntrinsic, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002199
Kévin Petite8edce32019-04-10 14:23:32 +01002200 // Turn the packed x & y into the final component of our int2.
2201 auto X = CallInst::Create(NewF, Lo, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002202
Kévin Petite8edce32019-04-10 14:23:32 +01002203 // Turn the packed z & w into the final component of our int2.
2204 auto Y = CallInst::Create(NewF, Hi, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002205
Kévin Petite8edce32019-04-10 14:23:32 +01002206 auto Combine = InsertElementInst::Create(
2207 UndefValue::get(Int2Ty), X, ConstantInt::get(IntTy, 0), "", CI);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002208 Combine = InsertElementInst::Create(Combine, Y, ConstantInt::get(IntTy, 1),
2209 "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002210
Kévin Petite8edce32019-04-10 14:23:32 +01002211 // Cast the half* pointer to int2*.
2212 auto Cast = CastInst::CreatePointerCast(Arg2, NewPointerTy, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002213
Kévin Petite8edce32019-04-10 14:23:32 +01002214 // Index into the correct address of the casted pointer.
2215 auto Index = GetElementPtrInst::Create(Int2Ty, Cast, Arg1, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002216
Kévin Petite8edce32019-04-10 14:23:32 +01002217 // Store to the int2* we casted to.
2218 return new StoreInst(Combine, Index, CI);
2219 });
David Neto22f144c2017-06-12 14:26:21 -04002220}
2221
SJW2c317da2020-03-23 07:39:13 -05002222bool ReplaceOpenCLBuiltinPass::replaceHalfReadImage(Function &F) {
2223 // convert half to float
2224 Module &M = *F.getParent();
2225 return replaceCallsWithValue(F, [&](CallInst *CI) {
2226 SmallVector<Type *, 3> types;
2227 SmallVector<Value *, 3> args;
2228 for (auto i = 0; i < CI->getNumArgOperands(); ++i) {
2229 types.push_back(CI->getArgOperand(i)->getType());
2230 args.push_back(CI->getArgOperand(i));
alan-bakerf7e17cb2020-01-02 07:29:59 -05002231 }
alan-bakerf7e17cb2020-01-02 07:29:59 -05002232
alan-baker5a8c3be2020-09-09 13:44:26 -04002233 auto NewFType =
2234 FunctionType::get(FixedVectorType::get(Type::getFloatTy(M.getContext()),
2235 cast<VectorType>(CI->getType())
2236 ->getElementCount()
2237 .getKnownMinValue()),
2238 types, false);
SJW2c317da2020-03-23 07:39:13 -05002239
SJW61531372020-06-09 07:31:08 -05002240 std::string NewFName =
2241 Builtins::GetMangledFunctionName("read_imagef", NewFType);
SJW2c317da2020-03-23 07:39:13 -05002242
2243 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
2244
2245 auto NewCI = CallInst::Create(NewF, args, "", CI);
2246
2247 // Convert to the half type.
2248 return CastInst::CreateFPCast(NewCI, CI->getType(), "", CI);
2249 });
alan-bakerf7e17cb2020-01-02 07:29:59 -05002250}
2251
SJW2c317da2020-03-23 07:39:13 -05002252bool ReplaceOpenCLBuiltinPass::replaceHalfWriteImage(Function &F) {
2253 // convert half to float
2254 Module &M = *F.getParent();
2255 return replaceCallsWithValue(F, [&](CallInst *CI) {
2256 SmallVector<Type *, 3> types(3);
2257 SmallVector<Value *, 3> args(3);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002258
SJW2c317da2020-03-23 07:39:13 -05002259 // Image
2260 types[0] = CI->getArgOperand(0)->getType();
2261 args[0] = CI->getArgOperand(0);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002262
SJW2c317da2020-03-23 07:39:13 -05002263 // Coord
2264 types[1] = CI->getArgOperand(1)->getType();
2265 args[1] = CI->getArgOperand(1);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002266
SJW2c317da2020-03-23 07:39:13 -05002267 // Data
alan-baker5a8c3be2020-09-09 13:44:26 -04002268 types[2] =
2269 FixedVectorType::get(Type::getFloatTy(M.getContext()),
2270 cast<VectorType>(CI->getArgOperand(2)->getType())
2271 ->getElementCount()
2272 .getKnownMinValue());
alan-bakerf7e17cb2020-01-02 07:29:59 -05002273
SJW2c317da2020-03-23 07:39:13 -05002274 auto NewFType =
2275 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002276
SJW61531372020-06-09 07:31:08 -05002277 std::string NewFName =
2278 Builtins::GetMangledFunctionName("write_imagef", NewFType);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002279
SJW2c317da2020-03-23 07:39:13 -05002280 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
alan-bakerf7e17cb2020-01-02 07:29:59 -05002281
SJW2c317da2020-03-23 07:39:13 -05002282 // Convert data to the float type.
2283 auto Cast = CastInst::CreateFPCast(CI->getArgOperand(2), types[2], "", CI);
2284 args[2] = Cast;
alan-bakerf7e17cb2020-01-02 07:29:59 -05002285
SJW2c317da2020-03-23 07:39:13 -05002286 return CallInst::Create(NewF, args, "", CI);
2287 });
alan-bakerf7e17cb2020-01-02 07:29:59 -05002288}
2289
SJW2c317da2020-03-23 07:39:13 -05002290bool ReplaceOpenCLBuiltinPass::replaceSampledReadImageWithIntCoords(
2291 Function &F) {
2292 // convert read_image with int coords to float coords
2293 Module &M = *F.getParent();
2294 return replaceCallsWithValue(F, [&](CallInst *CI) {
2295 // The image.
2296 auto Arg0 = CI->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -04002297
SJW2c317da2020-03-23 07:39:13 -05002298 // The sampler.
2299 auto Arg1 = CI->getOperand(1);
David Neto22f144c2017-06-12 14:26:21 -04002300
SJW2c317da2020-03-23 07:39:13 -05002301 // The coordinate (integer type that we can't handle).
2302 auto Arg2 = CI->getOperand(2);
David Neto22f144c2017-06-12 14:26:21 -04002303
SJW2c317da2020-03-23 07:39:13 -05002304 uint32_t dim = clspv::ImageDimensionality(Arg0->getType());
2305 uint32_t components =
2306 dim + (clspv::IsArrayImageType(Arg0->getType()) ? 1 : 0);
2307 Type *float_ty = nullptr;
2308 if (components == 1) {
2309 float_ty = Type::getFloatTy(M.getContext());
2310 } else {
alan-baker5a8c3be2020-09-09 13:44:26 -04002311 float_ty = FixedVectorType::get(Type::getFloatTy(M.getContext()),
2312 cast<VectorType>(Arg2->getType())
2313 ->getElementCount()
2314 .getKnownMinValue());
David Neto22f144c2017-06-12 14:26:21 -04002315 }
David Neto22f144c2017-06-12 14:26:21 -04002316
SJW2c317da2020-03-23 07:39:13 -05002317 auto NewFType = FunctionType::get(
2318 CI->getType(), {Arg0->getType(), Arg1->getType(), float_ty}, false);
2319
2320 std::string NewFName = F.getName().str();
2321 NewFName[NewFName.length() - 1] = 'f';
2322
2323 auto NewF = M.getOrInsertFunction(NewFName, NewFType);
2324
2325 auto Cast = CastInst::Create(Instruction::SIToFP, Arg2, float_ty, "", CI);
2326
2327 return CallInst::Create(NewF, {Arg0, Arg1, Cast}, "", CI);
2328 });
David Neto22f144c2017-06-12 14:26:21 -04002329}
2330
SJW2c317da2020-03-23 07:39:13 -05002331bool ReplaceOpenCLBuiltinPass::replaceAtomics(Function &F, spv::Op Op) {
2332 return replaceCallsWithValue(F, [&](CallInst *CI) {
2333 auto IntTy = Type::getInt32Ty(F.getContext());
David Neto22f144c2017-06-12 14:26:21 -04002334
SJW2c317da2020-03-23 07:39:13 -05002335 // We need to map the OpenCL constants to the SPIR-V equivalents.
2336 const auto ConstantScopeDevice = ConstantInt::get(IntTy, spv::ScopeDevice);
2337 const auto ConstantMemorySemantics = ConstantInt::get(
2338 IntTy, spv::MemorySemanticsUniformMemoryMask |
2339 spv::MemorySemanticsSequentiallyConsistentMask);
David Neto22f144c2017-06-12 14:26:21 -04002340
SJW2c317da2020-03-23 07:39:13 -05002341 SmallVector<Value *, 5> Params;
David Neto22f144c2017-06-12 14:26:21 -04002342
SJW2c317da2020-03-23 07:39:13 -05002343 // The pointer.
2344 Params.push_back(CI->getArgOperand(0));
David Neto22f144c2017-06-12 14:26:21 -04002345
SJW2c317da2020-03-23 07:39:13 -05002346 // The memory scope.
2347 Params.push_back(ConstantScopeDevice);
David Neto22f144c2017-06-12 14:26:21 -04002348
SJW2c317da2020-03-23 07:39:13 -05002349 // The memory semantics.
2350 Params.push_back(ConstantMemorySemantics);
David Neto22f144c2017-06-12 14:26:21 -04002351
SJW2c317da2020-03-23 07:39:13 -05002352 if (2 < CI->getNumArgOperands()) {
2353 // The unequal memory semantics.
2354 Params.push_back(ConstantMemorySemantics);
David Neto22f144c2017-06-12 14:26:21 -04002355
SJW2c317da2020-03-23 07:39:13 -05002356 // The value.
2357 Params.push_back(CI->getArgOperand(2));
David Neto22f144c2017-06-12 14:26:21 -04002358
SJW2c317da2020-03-23 07:39:13 -05002359 // The comparator.
2360 Params.push_back(CI->getArgOperand(1));
2361 } else if (1 < CI->getNumArgOperands()) {
2362 // The value.
2363 Params.push_back(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -04002364 }
David Neto22f144c2017-06-12 14:26:21 -04002365
SJW2c317da2020-03-23 07:39:13 -05002366 return clspv::InsertSPIRVOp(CI, Op, {}, CI->getType(), Params);
2367 });
David Neto22f144c2017-06-12 14:26:21 -04002368}
2369
SJW2c317da2020-03-23 07:39:13 -05002370bool ReplaceOpenCLBuiltinPass::replaceAtomics(Function &F,
2371 llvm::AtomicRMWInst::BinOp Op) {
2372 return replaceCallsWithValue(F, [&](CallInst *CI) {
alan-bakerd0eb9052020-07-07 13:12:01 -04002373 auto align = F.getParent()->getDataLayout().getABITypeAlign(
2374 CI->getArgOperand(1)->getType());
SJW2c317da2020-03-23 07:39:13 -05002375 return new AtomicRMWInst(Op, CI->getArgOperand(0), CI->getArgOperand(1),
alan-bakerd0eb9052020-07-07 13:12:01 -04002376 align, AtomicOrdering::SequentiallyConsistent,
SJW2c317da2020-03-23 07:39:13 -05002377 SyncScope::System, CI);
2378 });
2379}
David Neto22f144c2017-06-12 14:26:21 -04002380
SJW2c317da2020-03-23 07:39:13 -05002381bool ReplaceOpenCLBuiltinPass::replaceCross(Function &F) {
2382 Module &M = *F.getParent();
2383 return replaceCallsWithValue(F, [&](CallInst *CI) {
David Neto22f144c2017-06-12 14:26:21 -04002384 auto IntTy = Type::getInt32Ty(M.getContext());
2385 auto FloatTy = Type::getFloatTy(M.getContext());
2386
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002387 Constant *DownShuffleMask[3] = {ConstantInt::get(IntTy, 0),
2388 ConstantInt::get(IntTy, 1),
2389 ConstantInt::get(IntTy, 2)};
David Neto22f144c2017-06-12 14:26:21 -04002390
2391 Constant *UpShuffleMask[4] = {
2392 ConstantInt::get(IntTy, 0), ConstantInt::get(IntTy, 1),
2393 ConstantInt::get(IntTy, 2), ConstantInt::get(IntTy, 3)};
2394
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002395 Constant *FloatVec[3] = {ConstantFP::get(FloatTy, 0.0f),
2396 UndefValue::get(FloatTy),
2397 UndefValue::get(FloatTy)};
David Neto22f144c2017-06-12 14:26:21 -04002398
Kévin Petite8edce32019-04-10 14:23:32 +01002399 auto Vec4Ty = CI->getArgOperand(0)->getType();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002400 auto Arg0 =
2401 new ShuffleVectorInst(CI->getArgOperand(0), UndefValue::get(Vec4Ty),
2402 ConstantVector::get(DownShuffleMask), "", CI);
2403 auto Arg1 =
2404 new ShuffleVectorInst(CI->getArgOperand(1), UndefValue::get(Vec4Ty),
2405 ConstantVector::get(DownShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002406 auto Vec3Ty = Arg0->getType();
David Neto22f144c2017-06-12 14:26:21 -04002407
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002408 auto NewFType = FunctionType::get(Vec3Ty, {Vec3Ty, Vec3Ty}, false);
SJW61531372020-06-09 07:31:08 -05002409 auto NewFName = Builtins::GetMangledFunctionName("cross", NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002410
SJW61531372020-06-09 07:31:08 -05002411 auto Cross3Func = M.getOrInsertFunction(NewFName, NewFType);
David Neto22f144c2017-06-12 14:26:21 -04002412
Kévin Petite8edce32019-04-10 14:23:32 +01002413 auto DownResult = CallInst::Create(Cross3Func, {Arg0, Arg1}, "", CI);
David Neto22f144c2017-06-12 14:26:21 -04002414
Diego Novillo3cc8d7a2019-04-10 13:30:34 -04002415 return new ShuffleVectorInst(DownResult, ConstantVector::get(FloatVec),
2416 ConstantVector::get(UpShuffleMask), "", CI);
Kévin Petite8edce32019-04-10 14:23:32 +01002417 });
David Neto22f144c2017-06-12 14:26:21 -04002418}
David Neto62653202017-10-16 19:05:18 -04002419
SJW2c317da2020-03-23 07:39:13 -05002420bool ReplaceOpenCLBuiltinPass::replaceFract(Function &F, int vec_size) {
David Neto62653202017-10-16 19:05:18 -04002421 // OpenCL's float result = fract(float x, float* ptr)
2422 //
2423 // In the LLVM domain:
2424 //
2425 // %floor_result = call spir_func float @floor(float %x)
2426 // store float %floor_result, float * %ptr
2427 // %fract_intermediate = call spir_func float @clspv.fract(float %x)
2428 // %result = call spir_func float
2429 // @fmin(float %fract_intermediate, float 0x1.fffffep-1f)
2430 //
2431 // Becomes in the SPIR-V domain, where translations of floor, fmin,
2432 // and clspv.fract occur in the SPIR-V generator pass:
2433 //
2434 // %glsl_ext = OpExtInstImport "GLSL.std.450"
2435 // %just_under_1 = OpConstant %float 0x1.fffffep-1f
2436 // ...
2437 // %floor_result = OpExtInst %float %glsl_ext Floor %x
2438 // OpStore %ptr %floor_result
2439 // %fract_intermediate = OpExtInst %float %glsl_ext Fract %x
2440 // %fract_result = OpExtInst %float
Marco Antognini55d51862020-07-21 17:50:07 +01002441 // %glsl_ext Nmin %fract_intermediate %just_under_1
David Neto62653202017-10-16 19:05:18 -04002442
David Neto62653202017-10-16 19:05:18 -04002443 using std::string;
2444
2445 // Mapping from the fract builtin to the floor, fmin, and clspv.fract builtins
2446 // we need. The clspv.fract builtin is the same as GLSL.std.450 Fract.
David Neto62653202017-10-16 19:05:18 -04002447
SJW2c317da2020-03-23 07:39:13 -05002448 Module &M = *F.getParent();
2449 return replaceCallsWithValue(F, [&](CallInst *CI) {
David Neto62653202017-10-16 19:05:18 -04002450
SJW2c317da2020-03-23 07:39:13 -05002451 // This is either float or a float vector. All the float-like
2452 // types are this type.
2453 auto result_ty = F.getReturnType();
2454
SJW61531372020-06-09 07:31:08 -05002455 std::string fmin_name = Builtins::GetMangledFunctionName("fmin", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002456 Function *fmin_fn = M.getFunction(fmin_name);
2457 if (!fmin_fn) {
2458 // Make the fmin function.
2459 FunctionType *fn_ty =
2460 FunctionType::get(result_ty, {result_ty, result_ty}, false);
2461 fmin_fn =
2462 cast<Function>(M.getOrInsertFunction(fmin_name, fn_ty).getCallee());
2463 fmin_fn->addFnAttr(Attribute::ReadNone);
2464 fmin_fn->setCallingConv(CallingConv::SPIR_FUNC);
2465 }
2466
SJW61531372020-06-09 07:31:08 -05002467 std::string floor_name =
2468 Builtins::GetMangledFunctionName("floor", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002469 Function *floor_fn = M.getFunction(floor_name);
2470 if (!floor_fn) {
2471 // Make the floor function.
2472 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2473 floor_fn =
2474 cast<Function>(M.getOrInsertFunction(floor_name, fn_ty).getCallee());
2475 floor_fn->addFnAttr(Attribute::ReadNone);
2476 floor_fn->setCallingConv(CallingConv::SPIR_FUNC);
2477 }
2478
SJW61531372020-06-09 07:31:08 -05002479 std::string clspv_fract_name =
2480 Builtins::GetMangledFunctionName("clspv.fract", result_ty);
SJW2c317da2020-03-23 07:39:13 -05002481 Function *clspv_fract_fn = M.getFunction(clspv_fract_name);
2482 if (!clspv_fract_fn) {
2483 // Make the clspv_fract function.
2484 FunctionType *fn_ty = FunctionType::get(result_ty, {result_ty}, false);
2485 clspv_fract_fn = cast<Function>(
2486 M.getOrInsertFunction(clspv_fract_name, fn_ty).getCallee());
2487 clspv_fract_fn->addFnAttr(Attribute::ReadNone);
2488 clspv_fract_fn->setCallingConv(CallingConv::SPIR_FUNC);
2489 }
2490
2491 // Number of significant significand bits, whether represented or not.
2492 unsigned num_significand_bits;
2493 switch (result_ty->getScalarType()->getTypeID()) {
2494 case Type::HalfTyID:
2495 num_significand_bits = 11;
2496 break;
2497 case Type::FloatTyID:
2498 num_significand_bits = 24;
2499 break;
2500 case Type::DoubleTyID:
2501 num_significand_bits = 53;
2502 break;
2503 default:
2504 llvm_unreachable("Unhandled float type when processing fract builtin");
2505 break;
2506 }
2507 // Beware that the disassembler displays this value as
2508 // OpConstant %float 1
2509 // which is not quite right.
2510 const double kJustUnderOneScalar =
2511 ldexp(double((1 << num_significand_bits) - 1), -num_significand_bits);
2512
2513 Constant *just_under_one =
2514 ConstantFP::get(result_ty->getScalarType(), kJustUnderOneScalar);
2515 if (result_ty->isVectorTy()) {
2516 just_under_one = ConstantVector::getSplat(
alan-baker931253b2020-08-20 17:15:38 -04002517 cast<VectorType>(result_ty)->getElementCount(), just_under_one);
SJW2c317da2020-03-23 07:39:13 -05002518 }
2519
2520 IRBuilder<> Builder(CI);
2521
2522 auto arg = CI->getArgOperand(0);
2523 auto ptr = CI->getArgOperand(1);
2524
2525 // Compute floor result and store it.
2526 auto floor = Builder.CreateCall(floor_fn, {arg});
2527 Builder.CreateStore(floor, ptr);
2528
2529 auto fract_intermediate = Builder.CreateCall(clspv_fract_fn, arg);
2530 auto fract_result =
2531 Builder.CreateCall(fmin_fn, {fract_intermediate, just_under_one});
2532
2533 return fract_result;
2534 });
David Neto62653202017-10-16 19:05:18 -04002535}
alan-bakera52b7312020-10-26 08:58:51 -04002536
Kévin Petit8576f682020-11-02 14:51:32 +00002537bool ReplaceOpenCLBuiltinPass::replaceHadd(Function &F, bool is_signed,
alan-bakerb6da5132020-10-29 15:59:06 -04002538 Instruction::BinaryOps join_opcode) {
Kévin Petit8576f682020-11-02 14:51:32 +00002539 return replaceCallsWithValue(F, [is_signed, join_opcode](CallInst *Call) {
alan-bakerb6da5132020-10-29 15:59:06 -04002540 // a_shr = a >> 1
2541 // b_shr = b >> 1
2542 // add1 = a_shr + b_shr
2543 // join = a |join_opcode| b
2544 // and = join & 1
2545 // add = add1 + and
2546 const auto a = Call->getArgOperand(0);
2547 const auto b = Call->getArgOperand(1);
2548 IRBuilder<> builder(Call);
Kévin Petit8576f682020-11-02 14:51:32 +00002549 Value *a_shift, *b_shift;
2550 if (is_signed) {
2551 a_shift = builder.CreateAShr(a, 1);
2552 b_shift = builder.CreateAShr(b, 1);
2553 } else {
2554 a_shift = builder.CreateLShr(a, 1);
2555 b_shift = builder.CreateLShr(b, 1);
2556 }
alan-bakerb6da5132020-10-29 15:59:06 -04002557 auto add = builder.CreateAdd(a_shift, b_shift);
2558 auto join = BinaryOperator::Create(join_opcode, a, b, "", Call);
2559 auto constant_one = ConstantInt::get(a->getType(), 1);
2560 auto and_bit = builder.CreateAnd(join, constant_one);
2561 return builder.CreateAdd(add, and_bit);
2562 });
2563}
2564
alan-bakera52b7312020-10-26 08:58:51 -04002565bool ReplaceOpenCLBuiltinPass::replaceAddSat(Function &F, bool is_signed) {
2566 Module *module = F.getParent();
alan-baker6b9d1ee2020-11-03 23:11:32 -05002567 return replaceCallsWithValue(F, [&module, is_signed, this](CallInst *Call) {
alan-bakera52b7312020-10-26 08:58:51 -04002568 // SPIR-V OpIAddCarry interprets inputs as unsigned. We use that
2569 // instruction for unsigned additions. For signed addition, it is more
2570 // complicated. For values with bit widths less than 32 bits, we extend
2571 // to the next power of two and perform the addition. For 32- and
2572 // 64-bit values we test the signedness of op1 to determine how to clamp
2573 // the addition.
2574 Type *ty = Call->getType();
2575 Value *op0 = Call->getArgOperand(0);
2576 Value *op1 = Call->getArgOperand(1);
2577 Value *result = nullptr;
2578 if (is_signed) {
2579 unsigned bitwidth = ty->getScalarSizeInBits();
2580 if (bitwidth < 32) {
2581 // sext_op0 = sext op0
2582 // sext_op1 = sext op1
2583 // add = add sext_op0 sext_op1
2584 // clamp = clamp(add, min, max)
2585 // result = trunc clamp
2586 unsigned extended_bits = static_cast<unsigned>(bitwidth << 1);
2587 // The clamp values are the signed min and max of the original bitwidth
2588 // sign extended to the extended bitwidth.
2589 Constant *scalar_min = ConstantInt::get(
2590 Call->getContext(),
2591 APInt::getSignedMinValue(bitwidth).sext(extended_bits));
2592 Constant *scalar_max = ConstantInt::get(
2593 Call->getContext(),
2594 APInt::getSignedMaxValue(bitwidth).sext(extended_bits));
2595 Constant *min = scalar_min;
2596 Constant *max = scalar_max;
2597 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2598 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2599 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2600 }
2601 Type *extended_scalar_ty =
2602 IntegerType::get(Call->getContext(), extended_bits);
2603 Type *extended_ty = extended_scalar_ty;
2604 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2605 extended_ty =
2606 VectorType::get(extended_scalar_ty, vec_ty->getElementCount());
2607 }
2608 auto sext_op0 =
2609 CastInst::Create(Instruction::SExt, op0, extended_ty, "", Call);
2610 auto sext_op1 =
2611 CastInst::Create(Instruction::SExt, op1, extended_ty, "", Call);
2612 // Add the nsw flag since we know no overflow can occur.
2613 auto add = BinaryOperator::CreateNSW(Instruction::Add, sext_op0,
2614 sext_op1, "", Call);
2615 FunctionType *func_ty = FunctionType::get(
2616 extended_ty, {extended_ty, extended_ty, extended_ty}, false);
2617
2618 // Don't use the type in GetMangledFunctionName to ensure we get
2619 // signed parameters.
2620 std::string sclamp_name = Builtins::GetMangledFunctionName("clamp");
2621 uint32_t vec_width = 1;
2622 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2623 vec_width = vec_ty->getElementCount().getKnownMinValue();
2624 }
2625 if (extended_bits == 32) {
2626 if (vec_width == 1) {
2627 sclamp_name += "iii";
2628 } else {
2629 sclamp_name += "Dv" + std::to_string(vec_width) + "_iS_S_";
2630 }
2631 } else {
2632 if (vec_width == 1) {
2633 sclamp_name += "sss";
2634 } else {
2635 sclamp_name += "Dv" + std::to_string(vec_width) + "_sS_S_";
2636 }
2637 }
2638 auto sclamp_callee = module->getOrInsertFunction(sclamp_name, func_ty);
2639 auto clamp = CallInst::Create(sclamp_callee, {add, min, max}, "", Call);
2640 result = CastInst::Create(Instruction::Trunc, clamp, ty, "", Call);
2641 } else {
2642 // Pseudo-code:
2643 // c = a + b;
2644 // if (b < 0)
2645 // c = c > a ? min : c;
2646 // else
2647 // c = c < a ? max : c;
2648 //
2649 unsigned bitwidth = ty->getScalarSizeInBits();
2650 Constant *scalar_min = ConstantInt::get(
2651 Call->getContext(), APInt::getSignedMinValue(bitwidth));
2652 Constant *scalar_max = ConstantInt::get(
2653 Call->getContext(), APInt::getSignedMaxValue(bitwidth));
2654 Constant *min = scalar_min;
2655 Constant *max = scalar_max;
2656 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2657 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2658 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2659 }
2660 auto zero = Constant::getNullValue(ty);
2661 // Cannot add the nsw flag.
2662 auto add = BinaryOperator::Create(Instruction::Add, op0, op1, "", Call);
2663 auto add_gt_op0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT,
2664 add, op0, "", Call);
2665 auto min_clamp = SelectInst::Create(add_gt_op0, min, add, "", Call);
2666 auto add_lt_op0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT,
2667 add, op0, "", Call);
2668 auto max_clamp = SelectInst::Create(add_lt_op0, max, add, "", Call);
2669 auto op1_lt_0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT,
2670 op1, zero, "", Call);
2671 result = SelectInst::Create(op1_lt_0, min_clamp, max_clamp, "", Call);
2672 }
2673 } else {
2674 // Just use OpIAddCarry and use the carry to clamp the result.
alan-baker6b9d1ee2020-11-03 23:11:32 -05002675 auto ret_ty = GetPairStruct(ty);
alan-bakera52b7312020-10-26 08:58:51 -04002676 auto add = clspv::InsertSPIRVOp(
2677 Call, spv::OpIAddCarry, {Attribute::ReadNone}, ret_ty, {op0, op1});
2678 auto ex0 = ExtractValueInst::Create(add, {0}, "", Call);
2679 auto ex1 = ExtractValueInst::Create(add, {1}, "", Call);
2680 auto cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, ex1,
2681 Constant::getNullValue(ty), "", Call);
2682 result =
2683 SelectInst::Create(cmp, ex0, Constant::getAllOnesValue(ty), "", Call);
2684 }
2685
2686 return result;
2687 });
2688}
alan-baker4986eff2020-10-29 13:38:00 -04002689
2690bool ReplaceOpenCLBuiltinPass::replaceAtomicLoad(Function &F) {
2691 return replaceCallsWithValue(F, [](CallInst *Call) {
2692 auto pointer = Call->getArgOperand(0);
2693 // Clang emits an address space cast to the generic address space. Skip the
2694 // cast and use the input directly.
2695 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2696 pointer = cast->getPointerOperand();
2697 }
2698 Value *order_arg =
2699 Call->getNumArgOperands() > 1 ? Call->getArgOperand(1) : nullptr;
2700 Value *scope_arg =
2701 Call->getNumArgOperands() > 2 ? Call->getArgOperand(2) : nullptr;
2702 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2703 clspv::AddressSpace::Global;
2704 auto order = MemoryOrderSemantics(order_arg, is_global, Call,
2705 spv::MemorySemanticsAcquireMask);
2706 auto scope = MemoryScope(scope_arg, is_global, Call);
2707 return InsertSPIRVOp(Call, spv::OpAtomicLoad, {Attribute::Convergent},
2708 Call->getType(), {pointer, scope, order});
2709 });
2710}
2711
2712bool ReplaceOpenCLBuiltinPass::replaceExplicitAtomics(
2713 Function &F, spv::Op Op, spv::MemorySemanticsMask semantics) {
2714 return replaceCallsWithValue(F, [Op, semantics](CallInst *Call) {
2715 auto pointer = Call->getArgOperand(0);
2716 // Clang emits an address space cast to the generic address space. Skip the
2717 // cast and use the input directly.
2718 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2719 pointer = cast->getPointerOperand();
2720 }
2721 Value *value = Call->getArgOperand(1);
2722 Value *order_arg =
2723 Call->getNumArgOperands() > 2 ? Call->getArgOperand(2) : nullptr;
2724 Value *scope_arg =
2725 Call->getNumArgOperands() > 3 ? Call->getArgOperand(3) : nullptr;
2726 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2727 clspv::AddressSpace::Global;
2728 auto scope = MemoryScope(scope_arg, is_global, Call);
2729 auto order = MemoryOrderSemantics(order_arg, is_global, Call, semantics);
2730 return InsertSPIRVOp(Call, Op, {Attribute::Convergent}, Call->getType(),
2731 {pointer, scope, order, value});
2732 });
2733}
2734
2735bool ReplaceOpenCLBuiltinPass::replaceAtomicCompareExchange(Function &F) {
2736 return replaceCallsWithValue(F, [](CallInst *Call) {
2737 auto pointer = Call->getArgOperand(0);
2738 // Clang emits an address space cast to the generic address space. Skip the
2739 // cast and use the input directly.
2740 if (auto cast = dyn_cast<AddrSpaceCastOperator>(pointer)) {
2741 pointer = cast->getPointerOperand();
2742 }
2743 auto expected = Call->getArgOperand(1);
2744 if (auto cast = dyn_cast<AddrSpaceCastOperator>(expected)) {
2745 expected = cast->getPointerOperand();
2746 }
2747 auto value = Call->getArgOperand(2);
2748 bool is_global = pointer->getType()->getPointerAddressSpace() ==
2749 clspv::AddressSpace::Global;
2750 Value *success_arg =
2751 Call->getNumArgOperands() > 3 ? Call->getArgOperand(3) : nullptr;
2752 Value *failure_arg =
2753 Call->getNumArgOperands() > 4 ? Call->getArgOperand(4) : nullptr;
2754 Value *scope_arg =
2755 Call->getNumArgOperands() > 5 ? Call->getArgOperand(5) : nullptr;
2756 auto scope = MemoryScope(scope_arg, is_global, Call);
2757 auto success = MemoryOrderSemantics(success_arg, is_global, Call,
2758 spv::MemorySemanticsAcquireReleaseMask);
2759 auto failure = MemoryOrderSemantics(failure_arg, is_global, Call,
2760 spv::MemorySemanticsAcquireMask);
2761
2762 // If the value pointed to by |expected| equals the value pointed to by
2763 // |pointer|, |value| is written into |pointer|, otherwise the value in
2764 // |pointer| is written into |expected|. In order to avoid extra stores,
2765 // the basic block with the original atomic is split and the store is
2766 // performed in the |then| block. The condition is the inversion of the
2767 // comparison result.
2768 IRBuilder<> builder(Call);
2769 auto load = builder.CreateLoad(expected);
2770 auto cmp_xchg = InsertSPIRVOp(
2771 Call, spv::OpAtomicCompareExchange, {Attribute::Convergent},
2772 value->getType(), {pointer, scope, success, failure, value, load});
2773 auto cmp = builder.CreateICmpEQ(cmp_xchg, load);
2774 auto not_cmp = builder.CreateNot(cmp);
2775 auto then_branch = SplitBlockAndInsertIfThen(not_cmp, Call, false);
2776 builder.SetInsertPoint(then_branch);
2777 builder.CreateStore(cmp_xchg, expected);
2778 return cmp;
2779 });
2780}
alan-bakercc2bafb2020-11-02 08:30:18 -05002781
2782bool ReplaceOpenCLBuiltinPass::replaceClz(Function &F) {
2783 if (!isa<IntegerType>(F.getReturnType()->getScalarType()))
2784 return false;
2785
2786 auto bitwidth = F.getReturnType()->getScalarSizeInBits();
2787 if (bitwidth == 32 || bitwidth > 64)
2788 return false;
2789
2790 return replaceCallsWithValue(F, [&F, bitwidth](CallInst *Call) {
2791 auto in = Call->getArgOperand(0);
2792 IRBuilder<> builder(Call);
2793 auto int32_ty = builder.getInt32Ty();
2794 Type *ty = int32_ty;
2795 if (auto vec_ty = dyn_cast<VectorType>(Call->getType())) {
2796 ty = VectorType::get(ty, vec_ty->getElementCount());
2797 }
2798 auto clz_32bit_ty = FunctionType::get(ty, {ty}, false);
2799 std::string clz_32bit_name = Builtins::GetMangledFunctionName("clz", ty);
2800 auto clz_32bit =
2801 F.getParent()->getOrInsertFunction(clz_32bit_name, clz_32bit_ty);
2802 if (bitwidth < 32) {
2803 // Extend the input to 32-bits and perform a clz. The clz for 32-bit is
2804 // translated as 31 - FindUMsb(in). Adjust that result to the right size.
2805 auto zext = builder.CreateZExt(in, ty);
2806 auto clz = builder.CreateCall(clz_32bit, {zext});
2807 Constant *sub_const = builder.getInt32(32 - bitwidth);
2808 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2809 sub_const =
2810 ConstantVector::getSplat(vec_ty->getElementCount(), sub_const);
2811 }
2812 auto sub = builder.CreateSub(clz, sub_const);
2813 return builder.CreateTrunc(sub, Call->getType());
2814 } else {
2815 // Split the input into top and bottom parts and perform clz on both. If
2816 // the most significant 1 is in the upper 32-bits, return the top result
2817 // directly. Otherwise return 32 + the bottom result to adjust for the
2818 // correct size.
2819 auto lshr = builder.CreateLShr(in, 32);
2820 auto top_bits = builder.CreateTrunc(lshr, ty);
2821 auto bot_bits = builder.CreateTrunc(in, ty);
2822 auto top_clz = builder.CreateCall(clz_32bit, {top_bits});
2823 auto bot_clz = builder.CreateCall(clz_32bit, {bot_bits});
2824 Constant *c32 = builder.getInt32(32);
2825 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2826 c32 = ConstantVector::getSplat(vec_ty->getElementCount(), c32);
2827 }
2828 auto cmp = builder.CreateICmpEQ(top_clz, c32);
2829 auto bot_adjust = builder.CreateAdd(bot_clz, c32);
2830 auto sel = builder.CreateSelect(cmp, bot_adjust, top_clz);
2831 return builder.CreateZExt(sel, Call->getType());
2832 }
2833 });
2834}
alan-baker6b9d1ee2020-11-03 23:11:32 -05002835
2836bool ReplaceOpenCLBuiltinPass::replaceMadSat(Function &F, bool is_signed) {
2837 return replaceCallsWithValue(F, [&F, is_signed, this](CallInst *Call) {
2838 const auto ty = Call->getType();
2839 const auto a = Call->getArgOperand(0);
2840 const auto b = Call->getArgOperand(1);
2841 const auto c = Call->getArgOperand(2);
2842 IRBuilder<> builder(Call);
2843 if (is_signed) {
2844 unsigned bitwidth = Call->getType()->getScalarSizeInBits();
2845 if (bitwidth < 32) {
2846 // mul = sext(a) * sext(b)
2847 // add = mul + sext(c)
2848 // res = clamp(add, MIN, MAX)
2849 unsigned extended_width = bitwidth << 1;
2850 Type *extended_ty = IntegerType::get(F.getContext(), extended_width);
2851 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2852 extended_ty = VectorType::get(extended_ty, vec_ty->getElementCount());
2853 }
2854 auto a_sext = builder.CreateSExt(a, extended_ty);
2855 auto b_sext = builder.CreateSExt(b, extended_ty);
2856 auto c_sext = builder.CreateSExt(c, extended_ty);
2857 // Extended the size so no overflows occur.
2858 auto mul = builder.CreateMul(a_sext, b_sext, "", true, true);
2859 auto add = builder.CreateAdd(mul, c_sext, "", true, true);
2860 auto func_ty = FunctionType::get(
2861 extended_ty, {extended_ty, extended_ty, extended_ty}, false);
2862 // Don't use function type because we need signed parameters.
2863 std::string clamp_name = Builtins::GetMangledFunctionName("clamp");
2864 // The clamp values are the signed min and max of the original bitwidth
2865 // sign extended to the extended bitwidth.
2866 Constant *min = ConstantInt::get(
2867 Call->getContext(),
2868 APInt::getSignedMinValue(bitwidth).sext(extended_width));
2869 Constant *max = ConstantInt::get(
2870 Call->getContext(),
2871 APInt::getSignedMaxValue(bitwidth).sext(extended_width));
2872 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2873 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2874 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2875 unsigned vec_width = vec_ty->getElementCount().getKnownMinValue();
2876 if (extended_width == 32)
2877 clamp_name += "Dv" + std::to_string(vec_width) + "_iS_S_";
2878 else
2879 clamp_name += "Dv" + std::to_string(vec_width) + "_sS_S_";
2880 } else {
2881 if (extended_width == 32)
2882 clamp_name += "iii";
2883 else
2884 clamp_name += "sss";
2885 }
2886 auto callee = F.getParent()->getOrInsertFunction(clamp_name, func_ty);
2887 auto clamp = builder.CreateCall(callee, {add, min, max});
2888 return builder.CreateTrunc(clamp, ty);
2889 } else {
2890 auto struct_ty = GetPairStruct(ty);
2891 // Compute
2892 // {hi, lo} = smul_extended(a, b)
2893 // add = lo + c
2894 auto mul_ext = InsertSPIRVOp(Call, spv::OpSMulExtended,
2895 {Attribute::ReadNone}, struct_ty, {a, b});
2896 auto mul_lo = builder.CreateExtractValue(mul_ext, {0});
2897 auto mul_hi = builder.CreateExtractValue(mul_ext, {1});
2898 auto add = builder.CreateAdd(mul_lo, c);
2899
2900 // Constants for use in the calculation.
2901 Constant *min = ConstantInt::get(Call->getContext(),
2902 APInt::getSignedMinValue(bitwidth));
2903 Constant *max = ConstantInt::get(Call->getContext(),
2904 APInt::getSignedMaxValue(bitwidth));
2905 Constant *max_plus_1 = ConstantInt::get(
2906 Call->getContext(),
2907 APInt::getSignedMaxValue(bitwidth) + APInt(bitwidth, 1));
2908 if (auto vec_ty = dyn_cast<VectorType>(ty)) {
2909 min = ConstantVector::getSplat(vec_ty->getElementCount(), min);
2910 max = ConstantVector::getSplat(vec_ty->getElementCount(), max);
2911 max_plus_1 =
2912 ConstantVector::getSplat(vec_ty->getElementCount(), max_plus_1);
2913 }
2914
2915 auto a_xor_b = builder.CreateXor(a, b);
2916 auto same_sign =
2917 builder.CreateICmpSGT(a_xor_b, Constant::getAllOnesValue(ty));
2918 auto different_sign = builder.CreateNot(same_sign);
2919 auto hi_eq_0 = builder.CreateICmpEQ(mul_hi, Constant::getNullValue(ty));
2920 auto hi_ne_0 = builder.CreateNot(hi_eq_0);
2921 auto lo_ge_max = builder.CreateICmpUGE(mul_lo, max);
2922 auto c_gt_0 = builder.CreateICmpSGT(c, Constant::getNullValue(ty));
2923 auto c_lt_0 = builder.CreateICmpSLT(c, Constant::getNullValue(ty));
2924 auto add_gt_max = builder.CreateICmpUGT(add, max);
2925 auto hi_eq_m1 =
2926 builder.CreateICmpEQ(mul_hi, Constant::getAllOnesValue(ty));
2927 auto hi_ne_m1 = builder.CreateNot(hi_eq_m1);
2928 auto lo_le_max_plus_1 = builder.CreateICmpULE(mul_lo, max_plus_1);
2929 auto max_sub_lo = builder.CreateSub(max, mul_lo);
2930 auto c_lt_max_sub_lo = builder.CreateICmpULT(c, max_sub_lo);
2931
2932 // Equivalent to:
2933 // if (((x < 0) == (y < 0)) && mul_hi != 0)
2934 // return MAX
2935 // if (mul_hi == 0 && mul_lo >= MAX && (z > 0 || add > MAX))
2936 // return MAX
2937 // if (((x < 0) != (y < 0)) && mul_hi != -1)
2938 // return MIN
2939 // if (hi == -1 && mul_lo <= (MAX + 1) && (z < 0 || z < (MAX - mul_lo))
2940 // return MIN
2941 // return add
2942 auto max_clamp_1 = builder.CreateAnd(same_sign, hi_ne_0);
2943 auto max_clamp_2 = builder.CreateOr(c_gt_0, add_gt_max);
2944 auto tmp = builder.CreateAnd(hi_eq_0, lo_ge_max);
2945 max_clamp_2 = builder.CreateAnd(tmp, max_clamp_2);
2946 auto max_clamp = builder.CreateOr(max_clamp_1, max_clamp_2);
2947 auto min_clamp_1 = builder.CreateAnd(different_sign, hi_ne_m1);
2948 auto min_clamp_2 = builder.CreateOr(c_lt_0, c_lt_max_sub_lo);
2949 tmp = builder.CreateAnd(hi_eq_m1, lo_le_max_plus_1);
2950 min_clamp_2 = builder.CreateAnd(tmp, min_clamp_2);
2951 auto min_clamp = builder.CreateOr(min_clamp_1, min_clamp_2);
2952 auto sel = builder.CreateSelect(min_clamp, min, add);
2953 return builder.CreateSelect(max_clamp, max, sel);
2954 }
2955 } else {
2956 // {lo, hi} = mul_extended(a, b)
2957 // {add, carry} = add_carry(lo, c)
2958 // cmp = (mul_hi | carry) == 0
2959 // mad_sat = cmp ? add : MAX
2960 auto struct_ty = GetPairStruct(ty);
2961 auto mul_ext = InsertSPIRVOp(Call, spv::OpUMulExtended,
2962 {Attribute::ReadNone}, struct_ty, {a, b});
2963 auto mul_lo = builder.CreateExtractValue(mul_ext, {0});
2964 auto mul_hi = builder.CreateExtractValue(mul_ext, {1});
2965 auto add_carry =
2966 InsertSPIRVOp(Call, spv::OpIAddCarry, {Attribute::ReadNone},
2967 struct_ty, {mul_lo, c});
2968 auto add = builder.CreateExtractValue(add_carry, {0});
2969 auto carry = builder.CreateExtractValue(add_carry, {1});
2970 auto or_value = builder.CreateOr(mul_hi, carry);
2971 auto cmp = builder.CreateICmpEQ(or_value, Constant::getNullValue(ty));
2972 return builder.CreateSelect(cmp, add, Constant::getAllOnesValue(ty));
2973 }
2974 });
2975}