blob: 341b1153c21bd5a60c783fdecd48d8dd42ecfa09 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
alan-bakerbccf62c2019-03-29 10:32:41 -040018#include "llvm/IR/IntrinsicInst.h"
David Netoe345e0e2018-06-15 11:38:32 -040019#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040023
alan-bakere0902602020-03-23 08:43:40 -040024#include "spirv/unified1/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040025
SJW61531372020-06-09 07:31:08 -050026#include "Constants.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040027#include "Passes.h"
28
David Neto22f144c2017-06-12 14:26:21 -040029using namespace llvm;
30
31#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
32
33namespace {
34struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
35 static char ID;
36 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
37
38 bool runOnModule(Module &M) override;
James Price3a116142020-10-16 06:52:18 -040039 bool replaceFshl(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040040 bool replaceMemset(Module &M);
41 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040042 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040043};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040044} // namespace
David Neto22f144c2017-06-12 14:26:21 -040045
46char ReplaceLLVMIntrinsicsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040047INITIALIZE_PASS(ReplaceLLVMIntrinsicsPass, "ReplaceLLVMIntrinsics",
48 "Replace LLVM intrinsics Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040049
50namespace clspv {
51ModulePass *createReplaceLLVMIntrinsicsPass() {
52 return new ReplaceLLVMIntrinsicsPass();
53}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040054} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040055
56bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
57 bool Changed = false;
58
James Price3a116142020-10-16 06:52:18 -040059 // Remove lifetime annotations first. They could be using memset
David Netoe345e0e2018-06-15 11:38:32 -040060 // and memcpy calls.
61 Changed |= removeLifetimeDeclarations(M);
James Price3a116142020-10-16 06:52:18 -040062 Changed |= replaceFshl(M);
David Neto22f144c2017-06-12 14:26:21 -040063 Changed |= replaceMemset(M);
64 Changed |= replaceMemcpy(M);
65
66 return Changed;
67}
68
James Price3a116142020-10-16 06:52:18 -040069bool ReplaceLLVMIntrinsicsPass::replaceFshl(Module &M) {
70 bool changed = false;
71
72 // Get list of fshl intrinsic declarations.
73 SmallVector<Function *, 8> intrinsics;
74 for (auto &func : M) {
75 if (func.getName().startswith("llvm.fshl")) {
76 intrinsics.push_back(&func);
77 }
78 }
79
80 for (auto func : intrinsics) {
81 // Get list of callsites.
82 SmallVector<CallInst *, 8> callsites;
83 for (auto user : func->users()) {
84 if (auto call = dyn_cast<CallInst>(user)) {
85 callsites.push_back(call);
86 }
87 }
88
89 // Replace each callsite with a manual implementation.
90 for (auto call : callsites) {
91 auto arg_hi = call->getArgOperand(0);
92 auto arg_lo = call->getArgOperand(1);
93 auto arg_shift = call->getArgOperand(2);
94
95 // Validate argument types.
96 auto type = arg_hi->getType();
97 if ((type->getScalarSizeInBits() != 8) &&
98 (type->getScalarSizeInBits() != 16) &&
99 (type->getScalarSizeInBits() != 32) &&
100 (type->getScalarSizeInBits() != 64)) {
101 llvm_unreachable("Invalid integer width in llvm.fshl intrinsic");
102 return false;
103 }
104
105 changed = true;
106
107 // We shift the bottom bits of the first argument up, the top bits of the
108 // second argument down, and then OR the two shifted values.
109
110 // The shift amount is treated modulo the element size.
111 auto mod_mask = ConstantInt::get(type, type->getScalarSizeInBits() - 1);
112 auto shift_amount = BinaryOperator::Create(Instruction::And, arg_shift,
113 mod_mask, "", call);
114
115 // Calculate the amount by which to shift the second argument down.
116 auto scalar_size = ConstantInt::get(type, type->getScalarSizeInBits());
117 auto down_amount = BinaryOperator::Create(Instruction::Sub, scalar_size,
118 shift_amount, "", call);
119
120 // Shift the two arguments and OR the results together.
121 auto hi_bits = BinaryOperator::Create(Instruction::Shl, arg_hi,
122 shift_amount, "", call);
123 auto lo_bits = BinaryOperator::Create(Instruction::LShr, arg_lo,
124 down_amount, "", call);
125 auto result =
126 BinaryOperator::Create(Instruction::Or, lo_bits, hi_bits, "", call);
127
128 // Replace the original call with the manually computed result.
129 call->replaceAllUsesWith(result);
130 call->eraseFromParent();
131 }
132
133 func->eraseFromParent();
134 }
135
136 return changed;
137}
138
David Neto22f144c2017-06-12 14:26:21 -0400139bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
140 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -0400141 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -0400142
143 for (auto &F : M) {
144 if (F.getName().startswith("llvm.memset")) {
145 SmallVector<CallInst *, 8> CallsToReplace;
146
147 for (auto U : F.users()) {
148 if (auto CI = dyn_cast<CallInst>(U)) {
149 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
150
151 // We only handle cases where the initializer is a constant int that
152 // is 0.
153 if (!Initializer || (0 != Initializer->getZExtValue())) {
154 Initializer->print(errs());
155 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
156 "non-0 initializer!");
157 }
158
159 CallsToReplace.push_back(CI);
160 }
161 }
162
163 for (auto CI : CallsToReplace) {
164 auto NewArg = CI->getArgOperand(0);
Kévin Petit70944912019-04-17 23:22:28 +0100165 auto Bitcast = dyn_cast<BitCastInst>(NewArg);
166 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -0400167 NewArg = Bitcast->getOperand(0);
168 }
169
David Netod3f59382017-10-18 18:30:30 -0400170 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -0400171 auto Ty = NewArg->getType();
172 auto PointeeTy = Ty->getPointerElementType();
David Neto22f144c2017-06-12 14:26:21 -0400173 auto Zero = Constant::getNullValue(PointeeTy);
174
David Netod3f59382017-10-18 18:30:30 -0400175 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
176 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
177 "Null memset can't be divided evenly across multiple stores.");
178 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400179
David Netod3f59382017-10-18 18:30:30 -0400180 // Generate the first store.
Kévin Petit58c445c2019-06-18 18:09:46 +0100181 new StoreInst(Zero, NewArg, CI);
David Netod3f59382017-10-18 18:30:30 -0400182
183 // Generate subsequent stores, but only if needed.
184 if (num_stores) {
185 auto I32Ty = Type::getInt32Ty(M.getContext());
186 auto One = ConstantInt::get(I32Ty, 1);
187 auto Ptr = NewArg;
188 for (uint32_t i = 1; i < num_stores; i++) {
189 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
Kévin Petit58c445c2019-06-18 18:09:46 +0100190 new StoreInst(Zero, Ptr, CI);
David Netod3f59382017-10-18 18:30:30 -0400191 }
192 }
193
David Neto22f144c2017-06-12 14:26:21 -0400194 CI->eraseFromParent();
195
Kévin Petit70944912019-04-17 23:22:28 +0100196 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -0400197 Bitcast->eraseFromParent();
198 }
199 }
200 }
201 }
202
203 return Changed;
204}
205
206bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
207 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400208 auto Layout = M.getDataLayout();
209
210 // Unpack source and destination types until we find a matching
211 // element type. Count the number of levels we unpack for the
212 // source and destination types. So far this only works for
213 // array types, but could be generalized to other regular types
214 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400215 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
216 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400217 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400218 auto descend_type = [](Type *InType) {
219 Type *OutType = InType;
220 if (OutType->isStructTy()) {
221 OutType = OutType->getStructElementType(0);
222 } else if (OutType->isArrayTy()) {
223 OutType = OutType->getArrayElementType();
James Pricecf53df42020-04-20 14:41:24 -0400224 } else if (auto vec_type = dyn_cast<VectorType>(OutType)) {
225 OutType = vec_type->getElementType();
Alan Baker7dea8842018-10-22 10:15:41 -0400226 } else {
227 assert(false && "Don't know how to descend into type");
228 }
229
230 return OutType;
231 };
232
David Netob84ba342017-06-19 17:55:37 -0400233 while (*SrcElemTy != *DstElemTy) {
234 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
235 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
236 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400237 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400238 (*NumSrcUnpackings)++;
239 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400240 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400241 (*NumDstUnpackings)++;
242 } else {
243 errs() << "Don't know how to unpack types for memcpy: " << CI
244 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
245 assert(false && "Don't know how to unpack these types");
246 }
247 }
Alan Baker7dea8842018-10-22 10:15:41 -0400248
249 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
250 while (Size < DstElemSize) {
251 *DstElemTy = descend_type(*DstElemTy);
252 *SrcElemTy = descend_type(*SrcElemTy);
253 (*NumDstUnpackings)++;
254 (*NumSrcUnpackings)++;
255 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
256 }
David Netob84ba342017-06-19 17:55:37 -0400257 };
David Neto22f144c2017-06-12 14:26:21 -0400258
259 for (auto &F : M) {
260 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400261 SmallPtrSet<Instruction *, 8> BitCastsToForget;
262 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400263
264 for (auto U : F.users()) {
265 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500266 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400267 auto Dst =
268 dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400269
alan-bakered80f572019-02-11 17:28:26 -0500270 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400271 auto Src =
272 dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400273
274 // The original type of Dst we get from the argument to the bitcast
275 // instruction.
276 auto DstTy = Dst->getType();
277 assert(DstTy->isPointerTy());
278
279 // The original type of Src we get from the argument to the bitcast
280 // instruction.
281 auto SrcTy = Src->getType();
282 assert(SrcTy->isPointerTy());
283
David Neto22f144c2017-06-12 14:26:21 -0400284 // Check that the size is a constant integer.
285 assert(isa<ConstantInt>(CI->getArgOperand(2)));
286 auto Size =
287 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
288
Alan Baker7dea8842018-10-22 10:15:41 -0400289 auto DstElemTy = DstTy->getPointerElementType();
290 auto SrcElemTy = SrcTy->getPointerElementType();
291 unsigned NumDstUnpackings = 0;
292 unsigned NumSrcUnpackings = 0;
293 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
294 &NumSrcUnpackings);
295
296 // Check that the pointee types match.
297 assert(DstElemTy == SrcElemTy);
298
David Netob84ba342017-06-19 17:55:37 -0400299 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
alan-baker4a757f62020-04-22 08:17:49 -0400300 (void)DstElemSize;
David Neto22f144c2017-06-12 14:26:21 -0400301
David Netob84ba342017-06-19 17:55:37 -0400302 // Check that the size is a multiple of the size of the pointee type.
303 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400304
alan-bakerbccf62c2019-03-29 10:32:41 -0400305 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400306 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
alan-baker4a757f62020-04-22 08:17:49 -0400307 (void)Alignment;
308 (void)TypeAlignment;
David Neto22f144c2017-06-12 14:26:21 -0400309
310 // Check that the alignment is at least the alignment of the pointee
311 // type.
312 assert(Alignment >= TypeAlignment);
313
314 // Check that the alignment is a multiple of the alignment of the
315 // pointee type.
316 assert(0 == (Alignment % TypeAlignment));
317
318 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400319 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400320
David Netob84ba342017-06-19 17:55:37 -0400321 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400322 }
323 }
324
David Netob84ba342017-06-19 17:55:37 -0400325 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500326 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
327 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400328 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400329
330 auto I32Ty = Type::getInt32Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400331 auto Alignment =
332 ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
alan-bakerbccf62c2019-03-29 10:32:41 -0400333 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400334
alan-bakered80f572019-02-11 17:28:26 -0500335 auto Dst = Arg0->getOperand(0);
336 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400337
338 auto DstElemTy = Dst->getType()->getPointerElementType();
339 auto SrcElemTy = Src->getType()->getPointerElementType();
340 unsigned NumDstUnpackings = 0;
341 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400342 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400343 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
344 &NumSrcUnpackings);
SJW61531372020-06-09 07:31:08 -0500345 auto SPIRVIntrinsic = clspv::CopyMemoryFunction();
David Neto22f144c2017-06-12 14:26:21 -0400346
David Netob84ba342017-06-19 17:55:37 -0400347 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400348
David Netob84ba342017-06-19 17:55:37 -0400349 IRBuilder<> Builder(CI);
350
351 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
352 auto NewFType = FunctionType::get(
353 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
354 false);
355 auto NewF =
356 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
357 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
358 } else {
359 auto Zero = ConstantInt::get(I32Ty, 0);
360 SmallVector<Value *, 3> SrcIndices;
361 SmallVector<Value *, 3> DstIndices;
362 // Make unpacking indices.
363 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
364 ++unpacking) {
365 SrcIndices.push_back(Zero);
366 }
367 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
368 ++unpacking) {
369 DstIndices.push_back(Zero);
370 }
371 // Add a placeholder for the final index.
372 SrcIndices.push_back(Zero);
373 DstIndices.push_back(Zero);
374
375 // Build the function and function type only once.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400376 FunctionType *NewFType = nullptr;
377 Function *NewF = nullptr;
David Netob84ba342017-06-19 17:55:37 -0400378
379 IRBuilder<> Builder(CI);
380 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
381 auto Index = ConstantInt::get(I32Ty, i);
382 SrcIndices.back() = Index;
383 DstIndices.back() = Index;
384
alan-bakered80f572019-02-11 17:28:26 -0500385 // Avoid the builder for Src in order to prevent the folder from
386 // creating constant expressions for constant memcpys.
387 auto SrcElemPtr =
388 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400389 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
390 NewFType =
391 NewFType != nullptr
392 ? NewFType
393 : FunctionType::get(F.getReturnType(),
394 {DstElemPtr->getType(),
395 SrcElemPtr->getType(), I32Ty, I32Ty},
396 false);
397 NewF = NewF != nullptr ? NewF
398 : Function::Create(NewFType, F.getLinkage(),
399 SPIRVIntrinsic, &M);
400 Builder.CreateCall(
401 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
402 }
403 }
404
405 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400406 CI->eraseFromParent();
407
David Netob84ba342017-06-19 17:55:37 -0400408 // Erase the bitcasts. A particular bitcast might be used
409 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500410 if (isa<BitCastInst>(Arg0))
411 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
412 if (isa<BitCastInst>(Arg1))
413 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400414 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400415 for (auto *Inst : BitCastsToForget) {
David Netob84ba342017-06-19 17:55:37 -0400416 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400417 }
418 }
419 }
420
421 return Changed;
422}
David Netoe345e0e2018-06-15 11:38:32 -0400423
424bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
425 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
426 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
427
428 bool Changed = false;
429
430 SmallVector<Function *, 2> WorkList;
431 for (auto &F : M) {
432 if (F.getName().startswith("llvm.lifetime.")) {
433 WorkList.push_back(&F);
434 }
435 }
436
437 for (auto *F : WorkList) {
438 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500439 // Copy users to avoid modifying the list in place.
440 SmallVector<User *, 8> users(F->users());
441 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400442 if (auto *CI = dyn_cast<CallInst>(U)) {
443 CI->eraseFromParent();
444 }
445 }
446 F->eraseFromParent();
447 }
448
449 return Changed;
450}