blob: e747ba84674240613e64c455ce356695adcfc358 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
alan-bakerbccf62c2019-03-29 10:32:41 -040018#include "llvm/IR/IntrinsicInst.h"
David Netoe345e0e2018-06-15 11:38:32 -040019#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040023
David Netoe345e0e2018-06-15 11:38:32 -040024#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040025
Diego Novilloa4c44fa2019-04-11 10:56:15 -040026#include "Passes.h"
27
David Neto22f144c2017-06-12 14:26:21 -040028using namespace llvm;
29
30#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
31
32namespace {
33struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
34 static char ID;
35 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
36
37 bool runOnModule(Module &M) override;
38 bool replaceMemset(Module &M);
39 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040040 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040041};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace
David Neto22f144c2017-06-12 14:26:21 -040043
44char ReplaceLLVMIntrinsicsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040045INITIALIZE_PASS(ReplaceLLVMIntrinsicsPass, "ReplaceLLVMIntrinsics",
46 "Replace LLVM intrinsics Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040047
48namespace clspv {
49ModulePass *createReplaceLLVMIntrinsicsPass() {
50 return new ReplaceLLVMIntrinsicsPass();
51}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040052} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040053
54bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
55 bool Changed = false;
56
David Netoe345e0e2018-06-15 11:38:32 -040057 // Remove lifetime annotations first. They coulud be using memset
58 // and memcpy calls.
59 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040060 Changed |= replaceMemset(M);
61 Changed |= replaceMemcpy(M);
62
63 return Changed;
64}
65
66bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
67 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040068 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040069
70 for (auto &F : M) {
71 if (F.getName().startswith("llvm.memset")) {
72 SmallVector<CallInst *, 8> CallsToReplace;
73
74 for (auto U : F.users()) {
75 if (auto CI = dyn_cast<CallInst>(U)) {
76 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
77
78 // We only handle cases where the initializer is a constant int that
79 // is 0.
80 if (!Initializer || (0 != Initializer->getZExtValue())) {
81 Initializer->print(errs());
82 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
83 "non-0 initializer!");
84 }
85
86 CallsToReplace.push_back(CI);
87 }
88 }
89
90 for (auto CI : CallsToReplace) {
91 auto NewArg = CI->getArgOperand(0);
92
alan-bakered80f572019-02-11 17:28:26 -050093 if (auto Bitcast = dyn_cast<BitCastOperator>(NewArg)) {
David Neto22f144c2017-06-12 14:26:21 -040094 NewArg = Bitcast->getOperand(0);
95 }
96
David Netod3f59382017-10-18 18:30:30 -040097 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
98
David Neto22f144c2017-06-12 14:26:21 -040099 auto Ty = NewArg->getType();
100 auto PointeeTy = Ty->getPointerElementType();
101
102 auto NewFType =
103 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
104
105 // Create our fake intrinsic to initialize it to 0.
106 auto SPIRVIntrinsic = "spirv.store_null";
107
108 auto NewF =
109 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
110
111 auto Zero = Constant::getNullValue(PointeeTy);
112
David Netod3f59382017-10-18 18:30:30 -0400113 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
114 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
115 "Null memset can't be divided evenly across multiple stores.");
116 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400117
David Netod3f59382017-10-18 18:30:30 -0400118 // Generate the first store.
119 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
120
121 // Generate subsequent stores, but only if needed.
122 if (num_stores) {
123 auto I32Ty = Type::getInt32Ty(M.getContext());
124 auto One = ConstantInt::get(I32Ty, 1);
125 auto Ptr = NewArg;
126 for (uint32_t i = 1; i < num_stores; i++) {
127 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
128 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
129 }
130 }
131
David Neto22f144c2017-06-12 14:26:21 -0400132 CI->eraseFromParent();
133
134 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
135 Bitcast->eraseFromParent();
136 }
137 }
138 }
139 }
140
141 return Changed;
142}
143
144bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
145 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400146 auto Layout = M.getDataLayout();
147
148 // Unpack source and destination types until we find a matching
149 // element type. Count the number of levels we unpack for the
150 // source and destination types. So far this only works for
151 // array types, but could be generalized to other regular types
152 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400153 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
154 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400155 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400156 auto descend_type = [](Type *InType) {
157 Type *OutType = InType;
158 if (OutType->isStructTy()) {
159 OutType = OutType->getStructElementType(0);
160 } else if (OutType->isArrayTy()) {
161 OutType = OutType->getArrayElementType();
162 } else if (OutType->isVectorTy()) {
163 OutType = OutType->getVectorElementType();
164 } else {
165 assert(false && "Don't know how to descend into type");
166 }
167
168 return OutType;
169 };
170
David Netob84ba342017-06-19 17:55:37 -0400171 unsigned *numSrcUnpackings = 0;
172 unsigned *numDstUnpackings = 0;
173 while (*SrcElemTy != *DstElemTy) {
174 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
175 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
176 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400177 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400178 (*NumSrcUnpackings)++;
179 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400180 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400181 (*NumDstUnpackings)++;
182 } else {
183 errs() << "Don't know how to unpack types for memcpy: " << CI
184 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
185 assert(false && "Don't know how to unpack these types");
186 }
187 }
Alan Baker7dea8842018-10-22 10:15:41 -0400188
189 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
190 while (Size < DstElemSize) {
191 *DstElemTy = descend_type(*DstElemTy);
192 *SrcElemTy = descend_type(*SrcElemTy);
193 (*NumDstUnpackings)++;
194 (*NumSrcUnpackings)++;
195 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
196 }
David Netob84ba342017-06-19 17:55:37 -0400197 };
David Neto22f144c2017-06-12 14:26:21 -0400198
199 for (auto &F : M) {
200 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400201 SmallPtrSet<Instruction *, 8> BitCastsToForget;
202 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400203
204 for (auto U : F.users()) {
205 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500206 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400207 auto Dst =
208 dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400209
alan-bakered80f572019-02-11 17:28:26 -0500210 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400211 auto Src =
212 dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400213
214 // The original type of Dst we get from the argument to the bitcast
215 // instruction.
216 auto DstTy = Dst->getType();
217 assert(DstTy->isPointerTy());
218
219 // The original type of Src we get from the argument to the bitcast
220 // instruction.
221 auto SrcTy = Src->getType();
222 assert(SrcTy->isPointerTy());
223
David Neto22f144c2017-06-12 14:26:21 -0400224 // Check that the size is a constant integer.
225 assert(isa<ConstantInt>(CI->getArgOperand(2)));
226 auto Size =
227 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
228
Alan Baker7dea8842018-10-22 10:15:41 -0400229 auto DstElemTy = DstTy->getPointerElementType();
230 auto SrcElemTy = SrcTy->getPointerElementType();
231 unsigned NumDstUnpackings = 0;
232 unsigned NumSrcUnpackings = 0;
233 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
234 &NumSrcUnpackings);
235
236 // Check that the pointee types match.
237 assert(DstElemTy == SrcElemTy);
238
David Netob84ba342017-06-19 17:55:37 -0400239 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400240
David Netob84ba342017-06-19 17:55:37 -0400241 // Check that the size is a multiple of the size of the pointee type.
242 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400243
alan-bakerbccf62c2019-03-29 10:32:41 -0400244 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400245 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400246
247 // Check that the alignment is at least the alignment of the pointee
248 // type.
249 assert(Alignment >= TypeAlignment);
250
251 // Check that the alignment is a multiple of the alignment of the
252 // pointee type.
253 assert(0 == (Alignment % TypeAlignment));
254
255 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400256 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400257
David Netob84ba342017-06-19 17:55:37 -0400258 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400259 }
260 }
261
David Netob84ba342017-06-19 17:55:37 -0400262 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500263 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
264 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400265 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400266
267 auto I32Ty = Type::getInt32Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400268 auto Alignment =
269 ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
alan-bakerbccf62c2019-03-29 10:32:41 -0400270 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400271
alan-bakered80f572019-02-11 17:28:26 -0500272 auto Dst = Arg0->getOperand(0);
273 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400274
275 auto DstElemTy = Dst->getType()->getPointerElementType();
276 auto SrcElemTy = Src->getType()->getPointerElementType();
277 unsigned NumDstUnpackings = 0;
278 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400279 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400280 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
281 &NumSrcUnpackings);
282 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400283
David Netob84ba342017-06-19 17:55:37 -0400284 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400285
David Netob84ba342017-06-19 17:55:37 -0400286 IRBuilder<> Builder(CI);
287
288 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
289 auto NewFType = FunctionType::get(
290 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
291 false);
292 auto NewF =
293 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
294 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
295 } else {
296 auto Zero = ConstantInt::get(I32Ty, 0);
297 SmallVector<Value *, 3> SrcIndices;
298 SmallVector<Value *, 3> DstIndices;
299 // Make unpacking indices.
300 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
301 ++unpacking) {
302 SrcIndices.push_back(Zero);
303 }
304 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
305 ++unpacking) {
306 DstIndices.push_back(Zero);
307 }
308 // Add a placeholder for the final index.
309 SrcIndices.push_back(Zero);
310 DstIndices.push_back(Zero);
311
312 // Build the function and function type only once.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400313 FunctionType *NewFType = nullptr;
314 Function *NewF = nullptr;
David Netob84ba342017-06-19 17:55:37 -0400315
316 IRBuilder<> Builder(CI);
317 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
318 auto Index = ConstantInt::get(I32Ty, i);
319 SrcIndices.back() = Index;
320 DstIndices.back() = Index;
321
alan-bakered80f572019-02-11 17:28:26 -0500322 // Avoid the builder for Src in order to prevent the folder from
323 // creating constant expressions for constant memcpys.
324 auto SrcElemPtr =
325 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400326 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
327 NewFType =
328 NewFType != nullptr
329 ? NewFType
330 : FunctionType::get(F.getReturnType(),
331 {DstElemPtr->getType(),
332 SrcElemPtr->getType(), I32Ty, I32Ty},
333 false);
334 NewF = NewF != nullptr ? NewF
335 : Function::Create(NewFType, F.getLinkage(),
336 SPIRVIntrinsic, &M);
337 Builder.CreateCall(
338 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
339 }
340 }
341
342 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400343 CI->eraseFromParent();
344
David Netob84ba342017-06-19 17:55:37 -0400345 // Erase the bitcasts. A particular bitcast might be used
346 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500347 if (isa<BitCastInst>(Arg0))
348 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
349 if (isa<BitCastInst>(Arg1))
350 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400351 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400352 for (auto *Inst : BitCastsToForget) {
David Netob84ba342017-06-19 17:55:37 -0400353 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400354 }
355 }
356 }
357
358 return Changed;
359}
David Netoe345e0e2018-06-15 11:38:32 -0400360
361bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
362 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
363 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
364
365 bool Changed = false;
366
367 SmallVector<Function *, 2> WorkList;
368 for (auto &F : M) {
369 if (F.getName().startswith("llvm.lifetime.")) {
370 WorkList.push_back(&F);
371 }
372 }
373
374 for (auto *F : WorkList) {
375 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500376 // Copy users to avoid modifying the list in place.
377 SmallVector<User *, 8> users(F->users());
378 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400379 if (auto *CI = dyn_cast<CallInst>(U)) {
380 CI->eraseFromParent();
381 }
382 }
383 F->eraseFromParent();
384 }
385
386 return Changed;
387}