blob: ccd599bc79a090cb4ddd7386b741276529caff57 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
alan-bakerbccf62c2019-03-29 10:32:41 -040018#include "llvm/IR/IntrinsicInst.h"
David Netoe345e0e2018-06-15 11:38:32 -040019#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040023
alan-bakere0902602020-03-23 08:43:40 -040024#include "spirv/unified1/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040025
Diego Novilloa4c44fa2019-04-11 10:56:15 -040026#include "Passes.h"
27
David Neto22f144c2017-06-12 14:26:21 -040028using namespace llvm;
29
30#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
31
32namespace {
33struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
34 static char ID;
35 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
36
37 bool runOnModule(Module &M) override;
38 bool replaceMemset(Module &M);
39 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040040 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040041};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace
David Neto22f144c2017-06-12 14:26:21 -040043
44char ReplaceLLVMIntrinsicsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040045INITIALIZE_PASS(ReplaceLLVMIntrinsicsPass, "ReplaceLLVMIntrinsics",
46 "Replace LLVM intrinsics Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040047
48namespace clspv {
49ModulePass *createReplaceLLVMIntrinsicsPass() {
50 return new ReplaceLLVMIntrinsicsPass();
51}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040052} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040053
54bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
55 bool Changed = false;
56
David Netoe345e0e2018-06-15 11:38:32 -040057 // Remove lifetime annotations first. They coulud be using memset
58 // and memcpy calls.
59 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040060 Changed |= replaceMemset(M);
61 Changed |= replaceMemcpy(M);
62
63 return Changed;
64}
65
66bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
67 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040068 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040069
70 for (auto &F : M) {
71 if (F.getName().startswith("llvm.memset")) {
72 SmallVector<CallInst *, 8> CallsToReplace;
73
74 for (auto U : F.users()) {
75 if (auto CI = dyn_cast<CallInst>(U)) {
76 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
77
78 // We only handle cases where the initializer is a constant int that
79 // is 0.
80 if (!Initializer || (0 != Initializer->getZExtValue())) {
81 Initializer->print(errs());
82 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
83 "non-0 initializer!");
84 }
85
86 CallsToReplace.push_back(CI);
87 }
88 }
89
90 for (auto CI : CallsToReplace) {
91 auto NewArg = CI->getArgOperand(0);
Kévin Petit70944912019-04-17 23:22:28 +010092 auto Bitcast = dyn_cast<BitCastInst>(NewArg);
93 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -040094 NewArg = Bitcast->getOperand(0);
95 }
96
David Netod3f59382017-10-18 18:30:30 -040097 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -040098 auto Ty = NewArg->getType();
99 auto PointeeTy = Ty->getPointerElementType();
David Neto22f144c2017-06-12 14:26:21 -0400100 auto Zero = Constant::getNullValue(PointeeTy);
101
David Netod3f59382017-10-18 18:30:30 -0400102 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
103 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
104 "Null memset can't be divided evenly across multiple stores.");
105 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400106
David Netod3f59382017-10-18 18:30:30 -0400107 // Generate the first store.
Kévin Petit58c445c2019-06-18 18:09:46 +0100108 new StoreInst(Zero, NewArg, CI);
David Netod3f59382017-10-18 18:30:30 -0400109
110 // Generate subsequent stores, but only if needed.
111 if (num_stores) {
112 auto I32Ty = Type::getInt32Ty(M.getContext());
113 auto One = ConstantInt::get(I32Ty, 1);
114 auto Ptr = NewArg;
115 for (uint32_t i = 1; i < num_stores; i++) {
116 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
Kévin Petit58c445c2019-06-18 18:09:46 +0100117 new StoreInst(Zero, Ptr, CI);
David Netod3f59382017-10-18 18:30:30 -0400118 }
119 }
120
David Neto22f144c2017-06-12 14:26:21 -0400121 CI->eraseFromParent();
122
Kévin Petit70944912019-04-17 23:22:28 +0100123 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -0400124 Bitcast->eraseFromParent();
125 }
126 }
127 }
128 }
129
130 return Changed;
131}
132
133bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
134 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400135 auto Layout = M.getDataLayout();
136
137 // Unpack source and destination types until we find a matching
138 // element type. Count the number of levels we unpack for the
139 // source and destination types. So far this only works for
140 // array types, but could be generalized to other regular types
141 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400142 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
143 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400144 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400145 auto descend_type = [](Type *InType) {
146 Type *OutType = InType;
147 if (OutType->isStructTy()) {
148 OutType = OutType->getStructElementType(0);
149 } else if (OutType->isArrayTy()) {
150 OutType = OutType->getArrayElementType();
James Pricecf53df42020-04-20 14:41:24 -0400151 } else if (auto vec_type = dyn_cast<VectorType>(OutType)) {
152 OutType = vec_type->getElementType();
Alan Baker7dea8842018-10-22 10:15:41 -0400153 } else {
154 assert(false && "Don't know how to descend into type");
155 }
156
157 return OutType;
158 };
159
David Netob84ba342017-06-19 17:55:37 -0400160 while (*SrcElemTy != *DstElemTy) {
161 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
162 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
163 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400164 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400165 (*NumSrcUnpackings)++;
166 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400167 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400168 (*NumDstUnpackings)++;
169 } else {
170 errs() << "Don't know how to unpack types for memcpy: " << CI
171 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
172 assert(false && "Don't know how to unpack these types");
173 }
174 }
Alan Baker7dea8842018-10-22 10:15:41 -0400175
176 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
177 while (Size < DstElemSize) {
178 *DstElemTy = descend_type(*DstElemTy);
179 *SrcElemTy = descend_type(*SrcElemTy);
180 (*NumDstUnpackings)++;
181 (*NumSrcUnpackings)++;
182 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
183 }
David Netob84ba342017-06-19 17:55:37 -0400184 };
David Neto22f144c2017-06-12 14:26:21 -0400185
186 for (auto &F : M) {
187 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400188 SmallPtrSet<Instruction *, 8> BitCastsToForget;
189 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400190
191 for (auto U : F.users()) {
192 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500193 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400194 auto Dst =
195 dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400196
alan-bakered80f572019-02-11 17:28:26 -0500197 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400198 auto Src =
199 dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400200
201 // The original type of Dst we get from the argument to the bitcast
202 // instruction.
203 auto DstTy = Dst->getType();
204 assert(DstTy->isPointerTy());
205
206 // The original type of Src we get from the argument to the bitcast
207 // instruction.
208 auto SrcTy = Src->getType();
209 assert(SrcTy->isPointerTy());
210
David Neto22f144c2017-06-12 14:26:21 -0400211 // Check that the size is a constant integer.
212 assert(isa<ConstantInt>(CI->getArgOperand(2)));
213 auto Size =
214 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
215
Alan Baker7dea8842018-10-22 10:15:41 -0400216 auto DstElemTy = DstTy->getPointerElementType();
217 auto SrcElemTy = SrcTy->getPointerElementType();
218 unsigned NumDstUnpackings = 0;
219 unsigned NumSrcUnpackings = 0;
220 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
221 &NumSrcUnpackings);
222
223 // Check that the pointee types match.
224 assert(DstElemTy == SrcElemTy);
225
David Netob84ba342017-06-19 17:55:37 -0400226 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
alan-baker4a757f62020-04-22 08:17:49 -0400227 (void)DstElemSize;
David Neto22f144c2017-06-12 14:26:21 -0400228
David Netob84ba342017-06-19 17:55:37 -0400229 // Check that the size is a multiple of the size of the pointee type.
230 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400231
alan-bakerbccf62c2019-03-29 10:32:41 -0400232 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400233 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
alan-baker4a757f62020-04-22 08:17:49 -0400234 (void)Alignment;
235 (void)TypeAlignment;
David Neto22f144c2017-06-12 14:26:21 -0400236
237 // Check that the alignment is at least the alignment of the pointee
238 // type.
239 assert(Alignment >= TypeAlignment);
240
241 // Check that the alignment is a multiple of the alignment of the
242 // pointee type.
243 assert(0 == (Alignment % TypeAlignment));
244
245 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400246 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400247
David Netob84ba342017-06-19 17:55:37 -0400248 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400249 }
250 }
251
David Netob84ba342017-06-19 17:55:37 -0400252 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500253 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
254 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400255 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400256
257 auto I32Ty = Type::getInt32Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400258 auto Alignment =
259 ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
alan-bakerbccf62c2019-03-29 10:32:41 -0400260 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400261
alan-bakered80f572019-02-11 17:28:26 -0500262 auto Dst = Arg0->getOperand(0);
263 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400264
265 auto DstElemTy = Dst->getType()->getPointerElementType();
266 auto SrcElemTy = Src->getType()->getPointerElementType();
267 unsigned NumDstUnpackings = 0;
268 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400269 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400270 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
271 &NumSrcUnpackings);
272 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400273
David Netob84ba342017-06-19 17:55:37 -0400274 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400275
David Netob84ba342017-06-19 17:55:37 -0400276 IRBuilder<> Builder(CI);
277
278 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
279 auto NewFType = FunctionType::get(
280 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
281 false);
282 auto NewF =
283 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
284 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
285 } else {
286 auto Zero = ConstantInt::get(I32Ty, 0);
287 SmallVector<Value *, 3> SrcIndices;
288 SmallVector<Value *, 3> DstIndices;
289 // Make unpacking indices.
290 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
291 ++unpacking) {
292 SrcIndices.push_back(Zero);
293 }
294 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
295 ++unpacking) {
296 DstIndices.push_back(Zero);
297 }
298 // Add a placeholder for the final index.
299 SrcIndices.push_back(Zero);
300 DstIndices.push_back(Zero);
301
302 // Build the function and function type only once.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400303 FunctionType *NewFType = nullptr;
304 Function *NewF = nullptr;
David Netob84ba342017-06-19 17:55:37 -0400305
306 IRBuilder<> Builder(CI);
307 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
308 auto Index = ConstantInt::get(I32Ty, i);
309 SrcIndices.back() = Index;
310 DstIndices.back() = Index;
311
alan-bakered80f572019-02-11 17:28:26 -0500312 // Avoid the builder for Src in order to prevent the folder from
313 // creating constant expressions for constant memcpys.
314 auto SrcElemPtr =
315 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400316 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
317 NewFType =
318 NewFType != nullptr
319 ? NewFType
320 : FunctionType::get(F.getReturnType(),
321 {DstElemPtr->getType(),
322 SrcElemPtr->getType(), I32Ty, I32Ty},
323 false);
324 NewF = NewF != nullptr ? NewF
325 : Function::Create(NewFType, F.getLinkage(),
326 SPIRVIntrinsic, &M);
327 Builder.CreateCall(
328 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
329 }
330 }
331
332 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400333 CI->eraseFromParent();
334
David Netob84ba342017-06-19 17:55:37 -0400335 // Erase the bitcasts. A particular bitcast might be used
336 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500337 if (isa<BitCastInst>(Arg0))
338 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
339 if (isa<BitCastInst>(Arg1))
340 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400341 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400342 for (auto *Inst : BitCastsToForget) {
David Netob84ba342017-06-19 17:55:37 -0400343 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400344 }
345 }
346 }
347
348 return Changed;
349}
David Netoe345e0e2018-06-15 11:38:32 -0400350
351bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
352 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
353 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
354
355 bool Changed = false;
356
357 SmallVector<Function *, 2> WorkList;
358 for (auto &F : M) {
359 if (F.getName().startswith("llvm.lifetime.")) {
360 WorkList.push_back(&F);
361 }
362 }
363
364 for (auto *F : WorkList) {
365 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500366 // Copy users to avoid modifying the list in place.
367 SmallVector<User *, 8> users(F->users());
368 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400369 if (auto *CI = dyn_cast<CallInst>(U)) {
370 CI->eraseFromParent();
371 }
372 }
373 F->eraseFromParent();
374 }
375
376 return Changed;
377}