blob: a89405664ddc6b33ba73eee7d48268348f5eb0bb [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/raw_ostream.h"
21#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040022
David Netoe345e0e2018-06-15 11:38:32 -040023#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040024
25using namespace llvm;
26
27#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
28
29namespace {
30struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
31 static char ID;
32 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
33
34 bool runOnModule(Module &M) override;
35 bool replaceMemset(Module &M);
36 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040037 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040038};
39}
40
41char ReplaceLLVMIntrinsicsPass::ID = 0;
42static RegisterPass<ReplaceLLVMIntrinsicsPass>
43 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
44
45namespace clspv {
46ModulePass *createReplaceLLVMIntrinsicsPass() {
47 return new ReplaceLLVMIntrinsicsPass();
48}
49}
50
51bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
52 bool Changed = false;
53
David Netoe345e0e2018-06-15 11:38:32 -040054 // Remove lifetime annotations first. They coulud be using memset
55 // and memcpy calls.
56 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040057 Changed |= replaceMemset(M);
58 Changed |= replaceMemcpy(M);
59
60 return Changed;
61}
62
63bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
64 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040065 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040066
67 for (auto &F : M) {
68 if (F.getName().startswith("llvm.memset")) {
69 SmallVector<CallInst *, 8> CallsToReplace;
70
71 for (auto U : F.users()) {
72 if (auto CI = dyn_cast<CallInst>(U)) {
73 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
74
75 // We only handle cases where the initializer is a constant int that
76 // is 0.
77 if (!Initializer || (0 != Initializer->getZExtValue())) {
78 Initializer->print(errs());
79 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
80 "non-0 initializer!");
81 }
82
83 CallsToReplace.push_back(CI);
84 }
85 }
86
87 for (auto CI : CallsToReplace) {
88 auto NewArg = CI->getArgOperand(0);
89
90 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
91 NewArg = Bitcast->getOperand(0);
92 }
93
David Netod3f59382017-10-18 18:30:30 -040094 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
95
David Neto22f144c2017-06-12 14:26:21 -040096 auto Ty = NewArg->getType();
97 auto PointeeTy = Ty->getPointerElementType();
98
99 auto NewFType =
100 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
101
102 // Create our fake intrinsic to initialize it to 0.
103 auto SPIRVIntrinsic = "spirv.store_null";
104
105 auto NewF =
106 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
107
108 auto Zero = Constant::getNullValue(PointeeTy);
109
David Netod3f59382017-10-18 18:30:30 -0400110 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
111 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
112 "Null memset can't be divided evenly across multiple stores.");
113 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400114
David Netod3f59382017-10-18 18:30:30 -0400115 // Generate the first store.
116 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
117
118 // Generate subsequent stores, but only if needed.
119 if (num_stores) {
120 auto I32Ty = Type::getInt32Ty(M.getContext());
121 auto One = ConstantInt::get(I32Ty, 1);
122 auto Ptr = NewArg;
123 for (uint32_t i = 1; i < num_stores; i++) {
124 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
125 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
126 }
127 }
128
David Neto22f144c2017-06-12 14:26:21 -0400129 CI->eraseFromParent();
130
131 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
132 Bitcast->eraseFromParent();
133 }
134 }
135 }
136 }
137
138 return Changed;
139}
140
141bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
142 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400143 auto Layout = M.getDataLayout();
144
145 // Unpack source and destination types until we find a matching
146 // element type. Count the number of levels we unpack for the
147 // source and destination types. So far this only works for
148 // array types, but could be generalized to other regular types
149 // like vectors.
150 auto match_types = [&Layout](CallInst &CI, Type **DstElemTy, Type **SrcElemTy,
151 unsigned *NumDstUnpackings,
152 unsigned *NumSrcUnpackings) {
153 unsigned *numSrcUnpackings = 0;
154 unsigned *numDstUnpackings = 0;
155 while (*SrcElemTy != *DstElemTy) {
156 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
157 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
158 if (SrcElemSize >= DstElemSize) {
159 assert((*SrcElemTy)->isArrayTy());
160 *SrcElemTy = (*SrcElemTy)->getArrayElementType();
161 (*NumSrcUnpackings)++;
162 } else if (DstElemSize >= SrcElemSize) {
163 assert((*DstElemTy)->isArrayTy());
164 *DstElemTy = (*DstElemTy)->getArrayElementType();
165 (*NumDstUnpackings)++;
166 } else {
167 errs() << "Don't know how to unpack types for memcpy: " << CI
168 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
169 assert(false && "Don't know how to unpack these types");
170 }
171 }
172 };
David Neto22f144c2017-06-12 14:26:21 -0400173
174 for (auto &F : M) {
175 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400176 SmallPtrSet<Instruction *, 8> BitCastsToForget;
177 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400178
179 for (auto U : F.users()) {
180 if (auto CI = dyn_cast<CallInst>(U)) {
181 assert(isa<BitCastInst>(CI->getArgOperand(0)));
182 auto Dst = dyn_cast<BitCastInst>(CI->getArgOperand(0))->getOperand(0);
183
184 assert(isa<BitCastInst>(CI->getArgOperand(1)));
185 auto Src = dyn_cast<BitCastInst>(CI->getArgOperand(1))->getOperand(0);
186
187 // The original type of Dst we get from the argument to the bitcast
188 // instruction.
189 auto DstTy = Dst->getType();
190 assert(DstTy->isPointerTy());
191
192 // The original type of Src we get from the argument to the bitcast
193 // instruction.
194 auto SrcTy = Src->getType();
195 assert(SrcTy->isPointerTy());
196
David Netob84ba342017-06-19 17:55:37 -0400197 auto DstElemTy = DstTy->getPointerElementType();
198 auto SrcElemTy = SrcTy->getPointerElementType();
199 unsigned NumDstUnpackings = 0;
200 unsigned NumSrcUnpackings = 0;
201 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
202 &NumSrcUnpackings);
203
David Neto22f144c2017-06-12 14:26:21 -0400204 // Check that the pointee types match.
David Netob84ba342017-06-19 17:55:37 -0400205 assert(DstElemTy == SrcElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400206
207 // Check that the size is a constant integer.
208 assert(isa<ConstantInt>(CI->getArgOperand(2)));
209 auto Size =
210 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
211
David Netob84ba342017-06-19 17:55:37 -0400212 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400213
David Netob84ba342017-06-19 17:55:37 -0400214 // Check that the size is a multiple of the size of the pointee type.
215 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400216
217 // Check that the alignment is a constant integer.
218 assert(isa<ConstantInt>(CI->getArgOperand(3)));
219 auto Alignment =
220 dyn_cast<ConstantInt>(CI->getArgOperand(3))->getZExtValue();
221
David Netob84ba342017-06-19 17:55:37 -0400222 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400223
224 // Check that the alignment is at least the alignment of the pointee
225 // type.
226 assert(Alignment >= TypeAlignment);
227
228 // Check that the alignment is a multiple of the alignment of the
229 // pointee type.
230 assert(0 == (Alignment % TypeAlignment));
231
232 // Check that volatile is a constant.
233 assert(isa<ConstantInt>(CI->getArgOperand(4)));
234
David Netob84ba342017-06-19 17:55:37 -0400235 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400236 }
237 }
238
David Netob84ba342017-06-19 17:55:37 -0400239 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
David Neto22f144c2017-06-12 14:26:21 -0400240 auto Arg0 = dyn_cast<BitCastInst>(CI->getArgOperand(0));
241 auto Arg1 = dyn_cast<BitCastInst>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400242 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
243 auto Arg4 = dyn_cast<ConstantInt>(CI->getArgOperand(4));
244
245 auto I32Ty = Type::getInt32Ty(M.getContext());
David Netob84ba342017-06-19 17:55:37 -0400246 auto Alignment = ConstantInt::get(I32Ty, Arg3->getZExtValue());
247 auto Volatile = ConstantInt::get(I32Ty, Arg4->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400248
David Netob84ba342017-06-19 17:55:37 -0400249 auto Dst = dyn_cast<BitCastInst>(Arg0)->getOperand(0);
250 auto Src = dyn_cast<BitCastInst>(Arg1)->getOperand(0);
251
252 auto DstElemTy = Dst->getType()->getPointerElementType();
253 auto SrcElemTy = Src->getType()->getPointerElementType();
254 unsigned NumDstUnpackings = 0;
255 unsigned NumSrcUnpackings = 0;
256 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
257 &NumSrcUnpackings);
258
259 assert(NumDstUnpackings < 2 && "Need to generalize dst unpacking case");
260 assert(NumSrcUnpackings < 2 && "Need to generalize src unpacking case");
261 assert((NumDstUnpackings == 0 || NumSrcUnpackings == 0) &&
262 "Need to generalize unpackings in both dimensions");
David Neto22f144c2017-06-12 14:26:21 -0400263
264 auto SPIRVIntrinsic = "spirv.copy_memory";
265
David Netob84ba342017-06-19 17:55:37 -0400266 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -0400267
David Netob84ba342017-06-19 17:55:37 -0400268 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400269
David Netob84ba342017-06-19 17:55:37 -0400270 IRBuilder<> Builder(CI);
271
272 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
273 auto NewFType = FunctionType::get(
274 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
275 false);
276 auto NewF =
277 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
278 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
279 } else {
280 auto Zero = ConstantInt::get(I32Ty, 0);
281 SmallVector<Value *, 3> SrcIndices;
282 SmallVector<Value *, 3> DstIndices;
283 // Make unpacking indices.
284 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
285 ++unpacking) {
286 SrcIndices.push_back(Zero);
287 }
288 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
289 ++unpacking) {
290 DstIndices.push_back(Zero);
291 }
292 // Add a placeholder for the final index.
293 SrcIndices.push_back(Zero);
294 DstIndices.push_back(Zero);
295
296 // Build the function and function type only once.
297 FunctionType* NewFType = nullptr;
298 Function* NewF = nullptr;
299
300 IRBuilder<> Builder(CI);
301 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
302 auto Index = ConstantInt::get(I32Ty, i);
303 SrcIndices.back() = Index;
304 DstIndices.back() = Index;
305
306 auto SrcElemPtr = Builder.CreateGEP(Src, SrcIndices);
307 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
308 NewFType =
309 NewFType != nullptr
310 ? NewFType
311 : FunctionType::get(F.getReturnType(),
312 {DstElemPtr->getType(),
313 SrcElemPtr->getType(), I32Ty, I32Ty},
314 false);
315 NewF = NewF != nullptr ? NewF
316 : Function::Create(NewFType, F.getLinkage(),
317 SPIRVIntrinsic, &M);
318 Builder.CreateCall(
319 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
320 }
321 }
322
323 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400324 CI->eraseFromParent();
325
David Netob84ba342017-06-19 17:55:37 -0400326 // Erase the bitcasts. A particular bitcast might be used
327 // in more than one memcpy, so defer actual deleting until later.
328 BitCastsToForget.insert(Arg0);
329 BitCastsToForget.insert(Arg1);
330 }
331 for (auto* Inst : BitCastsToForget) {
332 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400333 }
334 }
335 }
336
337 return Changed;
338}
David Netoe345e0e2018-06-15 11:38:32 -0400339
340bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
341 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
342 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
343
344 bool Changed = false;
345
346 SmallVector<Function *, 2> WorkList;
347 for (auto &F : M) {
348 if (F.getName().startswith("llvm.lifetime.")) {
349 WorkList.push_back(&F);
350 }
351 }
352
353 for (auto *F : WorkList) {
354 Changed = true;
355 for (auto U : F->users()) {
356 if (auto *CI = dyn_cast<CallInst>(U)) {
357 CI->eraseFromParent();
358 }
359 }
360 F->eraseFromParent();
361 }
362
363 return Changed;
364}