blob: a260a8452757133e15c28696673fae5b9b876061 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/raw_ostream.h"
21#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040022
David Netoe345e0e2018-06-15 11:38:32 -040023#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040024
25using namespace llvm;
26
27#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
28
29namespace {
30struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
31 static char ID;
32 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
33
34 bool runOnModule(Module &M) override;
35 bool replaceMemset(Module &M);
36 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040037 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040038};
39}
40
41char ReplaceLLVMIntrinsicsPass::ID = 0;
42static RegisterPass<ReplaceLLVMIntrinsicsPass>
43 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
44
45namespace clspv {
46ModulePass *createReplaceLLVMIntrinsicsPass() {
47 return new ReplaceLLVMIntrinsicsPass();
48}
49}
50
51bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
52 bool Changed = false;
53
David Netoe345e0e2018-06-15 11:38:32 -040054 // Remove lifetime annotations first. They coulud be using memset
55 // and memcpy calls.
56 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040057 Changed |= replaceMemset(M);
58 Changed |= replaceMemcpy(M);
59
60 return Changed;
61}
62
63bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
64 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040065 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040066
67 for (auto &F : M) {
68 if (F.getName().startswith("llvm.memset")) {
69 SmallVector<CallInst *, 8> CallsToReplace;
70
71 for (auto U : F.users()) {
72 if (auto CI = dyn_cast<CallInst>(U)) {
73 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
74
75 // We only handle cases where the initializer is a constant int that
76 // is 0.
77 if (!Initializer || (0 != Initializer->getZExtValue())) {
78 Initializer->print(errs());
79 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
80 "non-0 initializer!");
81 }
82
83 CallsToReplace.push_back(CI);
84 }
85 }
86
87 for (auto CI : CallsToReplace) {
88 auto NewArg = CI->getArgOperand(0);
89
alan-bakered80f572019-02-11 17:28:26 -050090 if (auto Bitcast = dyn_cast<BitCastOperator>(NewArg)) {
David Neto22f144c2017-06-12 14:26:21 -040091 NewArg = Bitcast->getOperand(0);
92 }
93
David Netod3f59382017-10-18 18:30:30 -040094 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
95
David Neto22f144c2017-06-12 14:26:21 -040096 auto Ty = NewArg->getType();
97 auto PointeeTy = Ty->getPointerElementType();
98
99 auto NewFType =
100 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
101
102 // Create our fake intrinsic to initialize it to 0.
103 auto SPIRVIntrinsic = "spirv.store_null";
104
105 auto NewF =
106 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
107
108 auto Zero = Constant::getNullValue(PointeeTy);
109
David Netod3f59382017-10-18 18:30:30 -0400110 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
111 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
112 "Null memset can't be divided evenly across multiple stores.");
113 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400114
David Netod3f59382017-10-18 18:30:30 -0400115 // Generate the first store.
116 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
117
118 // Generate subsequent stores, but only if needed.
119 if (num_stores) {
120 auto I32Ty = Type::getInt32Ty(M.getContext());
121 auto One = ConstantInt::get(I32Ty, 1);
122 auto Ptr = NewArg;
123 for (uint32_t i = 1; i < num_stores; i++) {
124 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
125 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
126 }
127 }
128
David Neto22f144c2017-06-12 14:26:21 -0400129 CI->eraseFromParent();
130
131 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
132 Bitcast->eraseFromParent();
133 }
134 }
135 }
136 }
137
138 return Changed;
139}
140
141bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
142 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400143 auto Layout = M.getDataLayout();
144
145 // Unpack source and destination types until we find a matching
146 // element type. Count the number of levels we unpack for the
147 // source and destination types. So far this only works for
148 // array types, but could be generalized to other regular types
149 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400150 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
151 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400152 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400153 auto descend_type = [](Type *InType) {
154 Type *OutType = InType;
155 if (OutType->isStructTy()) {
156 OutType = OutType->getStructElementType(0);
157 } else if (OutType->isArrayTy()) {
158 OutType = OutType->getArrayElementType();
159 } else if (OutType->isVectorTy()) {
160 OutType = OutType->getVectorElementType();
161 } else {
162 assert(false && "Don't know how to descend into type");
163 }
164
165 return OutType;
166 };
167
David Netob84ba342017-06-19 17:55:37 -0400168 unsigned *numSrcUnpackings = 0;
169 unsigned *numDstUnpackings = 0;
170 while (*SrcElemTy != *DstElemTy) {
171 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
172 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
173 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400174 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400175 (*NumSrcUnpackings)++;
176 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400177 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400178 (*NumDstUnpackings)++;
179 } else {
180 errs() << "Don't know how to unpack types for memcpy: " << CI
181 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
182 assert(false && "Don't know how to unpack these types");
183 }
184 }
Alan Baker7dea8842018-10-22 10:15:41 -0400185
186 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
187 while (Size < DstElemSize) {
188 *DstElemTy = descend_type(*DstElemTy);
189 *SrcElemTy = descend_type(*SrcElemTy);
190 (*NumDstUnpackings)++;
191 (*NumSrcUnpackings)++;
192 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
193 }
David Netob84ba342017-06-19 17:55:37 -0400194 };
David Neto22f144c2017-06-12 14:26:21 -0400195
196 for (auto &F : M) {
197 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400198 SmallPtrSet<Instruction *, 8> BitCastsToForget;
199 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400200
201 for (auto U : F.users()) {
202 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500203 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
204 auto Dst = dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400205
alan-bakered80f572019-02-11 17:28:26 -0500206 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
207 auto Src = dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400208
209 // The original type of Dst we get from the argument to the bitcast
210 // instruction.
211 auto DstTy = Dst->getType();
212 assert(DstTy->isPointerTy());
213
214 // The original type of Src we get from the argument to the bitcast
215 // instruction.
216 auto SrcTy = Src->getType();
217 assert(SrcTy->isPointerTy());
218
David Neto22f144c2017-06-12 14:26:21 -0400219 // Check that the size is a constant integer.
220 assert(isa<ConstantInt>(CI->getArgOperand(2)));
221 auto Size =
222 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
223
Alan Baker7dea8842018-10-22 10:15:41 -0400224 auto DstElemTy = DstTy->getPointerElementType();
225 auto SrcElemTy = SrcTy->getPointerElementType();
226 unsigned NumDstUnpackings = 0;
227 unsigned NumSrcUnpackings = 0;
228 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
229 &NumSrcUnpackings);
230
231 // Check that the pointee types match.
232 assert(DstElemTy == SrcElemTy);
233
David Netob84ba342017-06-19 17:55:37 -0400234 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400235
David Netob84ba342017-06-19 17:55:37 -0400236 // Check that the size is a multiple of the size of the pointee type.
237 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400238
239 // Check that the alignment is a constant integer.
240 assert(isa<ConstantInt>(CI->getArgOperand(3)));
241 auto Alignment =
242 dyn_cast<ConstantInt>(CI->getArgOperand(3))->getZExtValue();
243
David Netob84ba342017-06-19 17:55:37 -0400244 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400245
246 // Check that the alignment is at least the alignment of the pointee
247 // type.
248 assert(Alignment >= TypeAlignment);
249
250 // Check that the alignment is a multiple of the alignment of the
251 // pointee type.
252 assert(0 == (Alignment % TypeAlignment));
253
254 // Check that volatile is a constant.
255 assert(isa<ConstantInt>(CI->getArgOperand(4)));
256
David Netob84ba342017-06-19 17:55:37 -0400257 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400258 }
259 }
260
David Netob84ba342017-06-19 17:55:37 -0400261 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500262 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
263 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400264 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
265 auto Arg4 = dyn_cast<ConstantInt>(CI->getArgOperand(4));
266
267 auto I32Ty = Type::getInt32Ty(M.getContext());
David Netob84ba342017-06-19 17:55:37 -0400268 auto Alignment = ConstantInt::get(I32Ty, Arg3->getZExtValue());
269 auto Volatile = ConstantInt::get(I32Ty, Arg4->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400270
alan-bakered80f572019-02-11 17:28:26 -0500271 auto Dst = Arg0->getOperand(0);
272 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400273
274 auto DstElemTy = Dst->getType()->getPointerElementType();
275 auto SrcElemTy = Src->getType()->getPointerElementType();
276 unsigned NumDstUnpackings = 0;
277 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400278 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400279 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
280 &NumSrcUnpackings);
281 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400282
David Netob84ba342017-06-19 17:55:37 -0400283 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400284
David Netob84ba342017-06-19 17:55:37 -0400285 IRBuilder<> Builder(CI);
286
287 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
288 auto NewFType = FunctionType::get(
289 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
290 false);
291 auto NewF =
292 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
293 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
294 } else {
295 auto Zero = ConstantInt::get(I32Ty, 0);
296 SmallVector<Value *, 3> SrcIndices;
297 SmallVector<Value *, 3> DstIndices;
298 // Make unpacking indices.
299 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
300 ++unpacking) {
301 SrcIndices.push_back(Zero);
302 }
303 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
304 ++unpacking) {
305 DstIndices.push_back(Zero);
306 }
307 // Add a placeholder for the final index.
308 SrcIndices.push_back(Zero);
309 DstIndices.push_back(Zero);
310
311 // Build the function and function type only once.
312 FunctionType* NewFType = nullptr;
313 Function* NewF = nullptr;
314
315 IRBuilder<> Builder(CI);
316 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
317 auto Index = ConstantInt::get(I32Ty, i);
318 SrcIndices.back() = Index;
319 DstIndices.back() = Index;
320
alan-bakered80f572019-02-11 17:28:26 -0500321 // Avoid the builder for Src in order to prevent the folder from
322 // creating constant expressions for constant memcpys.
323 auto SrcElemPtr =
324 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400325 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
326 NewFType =
327 NewFType != nullptr
328 ? NewFType
329 : FunctionType::get(F.getReturnType(),
330 {DstElemPtr->getType(),
331 SrcElemPtr->getType(), I32Ty, I32Ty},
332 false);
333 NewF = NewF != nullptr ? NewF
334 : Function::Create(NewFType, F.getLinkage(),
335 SPIRVIntrinsic, &M);
336 Builder.CreateCall(
337 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
338 }
339 }
340
341 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400342 CI->eraseFromParent();
343
David Netob84ba342017-06-19 17:55:37 -0400344 // Erase the bitcasts. A particular bitcast might be used
345 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500346 if (isa<BitCastInst>(Arg0))
347 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
348 if (isa<BitCastInst>(Arg1))
349 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400350 }
351 for (auto* Inst : BitCastsToForget) {
352 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400353 }
354 }
355 }
356
357 return Changed;
358}
David Netoe345e0e2018-06-15 11:38:32 -0400359
360bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
361 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
362 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
363
364 bool Changed = false;
365
366 SmallVector<Function *, 2> WorkList;
367 for (auto &F : M) {
368 if (F.getName().startswith("llvm.lifetime.")) {
369 WorkList.push_back(&F);
370 }
371 }
372
373 for (auto *F : WorkList) {
374 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500375 // Copy users to avoid modifying the list in place.
376 SmallVector<User *, 8> users(F->users());
377 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400378 if (auto *CI = dyn_cast<CallInst>(U)) {
379 CI->eraseFromParent();
380 }
381 }
382 F->eraseFromParent();
383 }
384
385 return Changed;
386}