blob: 6d4a684d62c300de9f8da1ac6de477575f22bbb1 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
alan-bakerbccf62c2019-03-29 10:32:41 -040018#include "llvm/IR/IntrinsicInst.h"
David Netoe345e0e2018-06-15 11:38:32 -040019#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040023
David Netoe345e0e2018-06-15 11:38:32 -040024#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040025
26using namespace llvm;
27
28#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
29
30namespace {
31struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
32 static char ID;
33 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
34
35 bool runOnModule(Module &M) override;
36 bool replaceMemset(Module &M);
37 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040038 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040039};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040040} // namespace
David Neto22f144c2017-06-12 14:26:21 -040041
42char ReplaceLLVMIntrinsicsPass::ID = 0;
43static RegisterPass<ReplaceLLVMIntrinsicsPass>
44 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
45
46namespace clspv {
47ModulePass *createReplaceLLVMIntrinsicsPass() {
48 return new ReplaceLLVMIntrinsicsPass();
49}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040050} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040051
52bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
53 bool Changed = false;
54
David Netoe345e0e2018-06-15 11:38:32 -040055 // Remove lifetime annotations first. They coulud be using memset
56 // and memcpy calls.
57 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040058 Changed |= replaceMemset(M);
59 Changed |= replaceMemcpy(M);
60
61 return Changed;
62}
63
64bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
65 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040066 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040067
68 for (auto &F : M) {
69 if (F.getName().startswith("llvm.memset")) {
70 SmallVector<CallInst *, 8> CallsToReplace;
71
72 for (auto U : F.users()) {
73 if (auto CI = dyn_cast<CallInst>(U)) {
74 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
75
76 // We only handle cases where the initializer is a constant int that
77 // is 0.
78 if (!Initializer || (0 != Initializer->getZExtValue())) {
79 Initializer->print(errs());
80 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
81 "non-0 initializer!");
82 }
83
84 CallsToReplace.push_back(CI);
85 }
86 }
87
88 for (auto CI : CallsToReplace) {
89 auto NewArg = CI->getArgOperand(0);
90
alan-bakered80f572019-02-11 17:28:26 -050091 if (auto Bitcast = dyn_cast<BitCastOperator>(NewArg)) {
David Neto22f144c2017-06-12 14:26:21 -040092 NewArg = Bitcast->getOperand(0);
93 }
94
David Netod3f59382017-10-18 18:30:30 -040095 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
96
David Neto22f144c2017-06-12 14:26:21 -040097 auto Ty = NewArg->getType();
98 auto PointeeTy = Ty->getPointerElementType();
99
100 auto NewFType =
101 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
102
103 // Create our fake intrinsic to initialize it to 0.
104 auto SPIRVIntrinsic = "spirv.store_null";
105
106 auto NewF =
107 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
108
109 auto Zero = Constant::getNullValue(PointeeTy);
110
David Netod3f59382017-10-18 18:30:30 -0400111 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
112 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
113 "Null memset can't be divided evenly across multiple stores.");
114 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400115
David Netod3f59382017-10-18 18:30:30 -0400116 // Generate the first store.
117 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
118
119 // Generate subsequent stores, but only if needed.
120 if (num_stores) {
121 auto I32Ty = Type::getInt32Ty(M.getContext());
122 auto One = ConstantInt::get(I32Ty, 1);
123 auto Ptr = NewArg;
124 for (uint32_t i = 1; i < num_stores; i++) {
125 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
126 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
127 }
128 }
129
David Neto22f144c2017-06-12 14:26:21 -0400130 CI->eraseFromParent();
131
132 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
133 Bitcast->eraseFromParent();
134 }
135 }
136 }
137 }
138
139 return Changed;
140}
141
142bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
143 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400144 auto Layout = M.getDataLayout();
145
146 // Unpack source and destination types until we find a matching
147 // element type. Count the number of levels we unpack for the
148 // source and destination types. So far this only works for
149 // array types, but could be generalized to other regular types
150 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400151 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
152 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400153 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400154 auto descend_type = [](Type *InType) {
155 Type *OutType = InType;
156 if (OutType->isStructTy()) {
157 OutType = OutType->getStructElementType(0);
158 } else if (OutType->isArrayTy()) {
159 OutType = OutType->getArrayElementType();
160 } else if (OutType->isVectorTy()) {
161 OutType = OutType->getVectorElementType();
162 } else {
163 assert(false && "Don't know how to descend into type");
164 }
165
166 return OutType;
167 };
168
David Netob84ba342017-06-19 17:55:37 -0400169 unsigned *numSrcUnpackings = 0;
170 unsigned *numDstUnpackings = 0;
171 while (*SrcElemTy != *DstElemTy) {
172 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
173 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
174 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400175 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400176 (*NumSrcUnpackings)++;
177 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400178 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400179 (*NumDstUnpackings)++;
180 } else {
181 errs() << "Don't know how to unpack types for memcpy: " << CI
182 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
183 assert(false && "Don't know how to unpack these types");
184 }
185 }
Alan Baker7dea8842018-10-22 10:15:41 -0400186
187 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
188 while (Size < DstElemSize) {
189 *DstElemTy = descend_type(*DstElemTy);
190 *SrcElemTy = descend_type(*SrcElemTy);
191 (*NumDstUnpackings)++;
192 (*NumSrcUnpackings)++;
193 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
194 }
David Netob84ba342017-06-19 17:55:37 -0400195 };
David Neto22f144c2017-06-12 14:26:21 -0400196
197 for (auto &F : M) {
198 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400199 SmallPtrSet<Instruction *, 8> BitCastsToForget;
200 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400201
202 for (auto U : F.users()) {
203 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500204 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400205 auto Dst =
206 dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400207
alan-bakered80f572019-02-11 17:28:26 -0500208 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400209 auto Src =
210 dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400211
212 // The original type of Dst we get from the argument to the bitcast
213 // instruction.
214 auto DstTy = Dst->getType();
215 assert(DstTy->isPointerTy());
216
217 // The original type of Src we get from the argument to the bitcast
218 // instruction.
219 auto SrcTy = Src->getType();
220 assert(SrcTy->isPointerTy());
221
David Neto22f144c2017-06-12 14:26:21 -0400222 // Check that the size is a constant integer.
223 assert(isa<ConstantInt>(CI->getArgOperand(2)));
224 auto Size =
225 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
226
Alan Baker7dea8842018-10-22 10:15:41 -0400227 auto DstElemTy = DstTy->getPointerElementType();
228 auto SrcElemTy = SrcTy->getPointerElementType();
229 unsigned NumDstUnpackings = 0;
230 unsigned NumSrcUnpackings = 0;
231 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
232 &NumSrcUnpackings);
233
234 // Check that the pointee types match.
235 assert(DstElemTy == SrcElemTy);
236
David Netob84ba342017-06-19 17:55:37 -0400237 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400238
David Netob84ba342017-06-19 17:55:37 -0400239 // Check that the size is a multiple of the size of the pointee type.
240 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400241
alan-bakerbccf62c2019-03-29 10:32:41 -0400242 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400243 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400244
245 // Check that the alignment is at least the alignment of the pointee
246 // type.
247 assert(Alignment >= TypeAlignment);
248
249 // Check that the alignment is a multiple of the alignment of the
250 // pointee type.
251 assert(0 == (Alignment % TypeAlignment));
252
253 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400254 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400255
David Netob84ba342017-06-19 17:55:37 -0400256 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400257 }
258 }
259
David Netob84ba342017-06-19 17:55:37 -0400260 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500261 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
262 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400263 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400264
265 auto I32Ty = Type::getInt32Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400266 auto Alignment =
267 ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
alan-bakerbccf62c2019-03-29 10:32:41 -0400268 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400269
alan-bakered80f572019-02-11 17:28:26 -0500270 auto Dst = Arg0->getOperand(0);
271 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400272
273 auto DstElemTy = Dst->getType()->getPointerElementType();
274 auto SrcElemTy = Src->getType()->getPointerElementType();
275 unsigned NumDstUnpackings = 0;
276 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400277 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400278 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
279 &NumSrcUnpackings);
280 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400281
David Netob84ba342017-06-19 17:55:37 -0400282 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400283
David Netob84ba342017-06-19 17:55:37 -0400284 IRBuilder<> Builder(CI);
285
286 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
287 auto NewFType = FunctionType::get(
288 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
289 false);
290 auto NewF =
291 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
292 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
293 } else {
294 auto Zero = ConstantInt::get(I32Ty, 0);
295 SmallVector<Value *, 3> SrcIndices;
296 SmallVector<Value *, 3> DstIndices;
297 // Make unpacking indices.
298 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
299 ++unpacking) {
300 SrcIndices.push_back(Zero);
301 }
302 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
303 ++unpacking) {
304 DstIndices.push_back(Zero);
305 }
306 // Add a placeholder for the final index.
307 SrcIndices.push_back(Zero);
308 DstIndices.push_back(Zero);
309
310 // Build the function and function type only once.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400311 FunctionType *NewFType = nullptr;
312 Function *NewF = nullptr;
David Netob84ba342017-06-19 17:55:37 -0400313
314 IRBuilder<> Builder(CI);
315 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
316 auto Index = ConstantInt::get(I32Ty, i);
317 SrcIndices.back() = Index;
318 DstIndices.back() = Index;
319
alan-bakered80f572019-02-11 17:28:26 -0500320 // Avoid the builder for Src in order to prevent the folder from
321 // creating constant expressions for constant memcpys.
322 auto SrcElemPtr =
323 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400324 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
325 NewFType =
326 NewFType != nullptr
327 ? NewFType
328 : FunctionType::get(F.getReturnType(),
329 {DstElemPtr->getType(),
330 SrcElemPtr->getType(), I32Ty, I32Ty},
331 false);
332 NewF = NewF != nullptr ? NewF
333 : Function::Create(NewFType, F.getLinkage(),
334 SPIRVIntrinsic, &M);
335 Builder.CreateCall(
336 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
337 }
338 }
339
340 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400341 CI->eraseFromParent();
342
David Netob84ba342017-06-19 17:55:37 -0400343 // Erase the bitcasts. A particular bitcast might be used
344 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500345 if (isa<BitCastInst>(Arg0))
346 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
347 if (isa<BitCastInst>(Arg1))
348 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400349 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400350 for (auto *Inst : BitCastsToForget) {
David Netob84ba342017-06-19 17:55:37 -0400351 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400352 }
353 }
354 }
355
356 return Changed;
357}
David Netoe345e0e2018-06-15 11:38:32 -0400358
359bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
360 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
361 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
362
363 bool Changed = false;
364
365 SmallVector<Function *, 2> WorkList;
366 for (auto &F : M) {
367 if (F.getName().startswith("llvm.lifetime.")) {
368 WorkList.push_back(&F);
369 }
370 }
371
372 for (auto *F : WorkList) {
373 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500374 // Copy users to avoid modifying the list in place.
375 SmallVector<User *, 8> users(F->users());
376 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400377 if (auto *CI = dyn_cast<CallInst>(U)) {
378 CI->eraseFromParent();
379 }
380 }
381 F->eraseFromParent();
382 }
383
384 return Changed;
385}