blob: 43efe80c1a114a099eb4d6f2de6e2396b0b08566 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Netoe345e0e2018-06-15 11:38:32 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
alan-bakerbccf62c2019-03-29 10:32:41 -040018#include "llvm/IR/IntrinsicInst.h"
David Netoe345e0e2018-06-15 11:38:32 -040019#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040023
David Netoe345e0e2018-06-15 11:38:32 -040024#include "spirv/1.0/spirv.hpp"
David Neto22f144c2017-06-12 14:26:21 -040025
26using namespace llvm;
27
28#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
29
30namespace {
31struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
32 static char ID;
33 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
34
35 bool runOnModule(Module &M) override;
36 bool replaceMemset(Module &M);
37 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040038 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040039};
40}
41
42char ReplaceLLVMIntrinsicsPass::ID = 0;
43static RegisterPass<ReplaceLLVMIntrinsicsPass>
44 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
45
46namespace clspv {
47ModulePass *createReplaceLLVMIntrinsicsPass() {
48 return new ReplaceLLVMIntrinsicsPass();
49}
50}
51
52bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
53 bool Changed = false;
54
David Netoe345e0e2018-06-15 11:38:32 -040055 // Remove lifetime annotations first. They coulud be using memset
56 // and memcpy calls.
57 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040058 Changed |= replaceMemset(M);
59 Changed |= replaceMemcpy(M);
60
61 return Changed;
62}
63
64bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
65 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040066 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040067
68 for (auto &F : M) {
69 if (F.getName().startswith("llvm.memset")) {
70 SmallVector<CallInst *, 8> CallsToReplace;
71
72 for (auto U : F.users()) {
73 if (auto CI = dyn_cast<CallInst>(U)) {
74 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
75
76 // We only handle cases where the initializer is a constant int that
77 // is 0.
78 if (!Initializer || (0 != Initializer->getZExtValue())) {
79 Initializer->print(errs());
80 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
81 "non-0 initializer!");
82 }
83
84 CallsToReplace.push_back(CI);
85 }
86 }
87
88 for (auto CI : CallsToReplace) {
89 auto NewArg = CI->getArgOperand(0);
90
alan-bakered80f572019-02-11 17:28:26 -050091 if (auto Bitcast = dyn_cast<BitCastOperator>(NewArg)) {
David Neto22f144c2017-06-12 14:26:21 -040092 NewArg = Bitcast->getOperand(0);
93 }
94
David Netod3f59382017-10-18 18:30:30 -040095 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
96
David Neto22f144c2017-06-12 14:26:21 -040097 auto Ty = NewArg->getType();
98 auto PointeeTy = Ty->getPointerElementType();
99
100 auto NewFType =
101 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
102
103 // Create our fake intrinsic to initialize it to 0.
104 auto SPIRVIntrinsic = "spirv.store_null";
105
106 auto NewF =
107 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
108
109 auto Zero = Constant::getNullValue(PointeeTy);
110
David Netod3f59382017-10-18 18:30:30 -0400111 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
112 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
113 "Null memset can't be divided evenly across multiple stores.");
114 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400115
David Netod3f59382017-10-18 18:30:30 -0400116 // Generate the first store.
117 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
118
119 // Generate subsequent stores, but only if needed.
120 if (num_stores) {
121 auto I32Ty = Type::getInt32Ty(M.getContext());
122 auto One = ConstantInt::get(I32Ty, 1);
123 auto Ptr = NewArg;
124 for (uint32_t i = 1; i < num_stores; i++) {
125 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
126 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
127 }
128 }
129
David Neto22f144c2017-06-12 14:26:21 -0400130 CI->eraseFromParent();
131
132 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
133 Bitcast->eraseFromParent();
134 }
135 }
136 }
137 }
138
139 return Changed;
140}
141
142bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
143 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400144 auto Layout = M.getDataLayout();
145
146 // Unpack source and destination types until we find a matching
147 // element type. Count the number of levels we unpack for the
148 // source and destination types. So far this only works for
149 // array types, but could be generalized to other regular types
150 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400151 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
152 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400153 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400154 auto descend_type = [](Type *InType) {
155 Type *OutType = InType;
156 if (OutType->isStructTy()) {
157 OutType = OutType->getStructElementType(0);
158 } else if (OutType->isArrayTy()) {
159 OutType = OutType->getArrayElementType();
160 } else if (OutType->isVectorTy()) {
161 OutType = OutType->getVectorElementType();
162 } else {
163 assert(false && "Don't know how to descend into type");
164 }
165
166 return OutType;
167 };
168
David Netob84ba342017-06-19 17:55:37 -0400169 unsigned *numSrcUnpackings = 0;
170 unsigned *numDstUnpackings = 0;
171 while (*SrcElemTy != *DstElemTy) {
172 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
173 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
174 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400175 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400176 (*NumSrcUnpackings)++;
177 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400178 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400179 (*NumDstUnpackings)++;
180 } else {
181 errs() << "Don't know how to unpack types for memcpy: " << CI
182 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
183 assert(false && "Don't know how to unpack these types");
184 }
185 }
Alan Baker7dea8842018-10-22 10:15:41 -0400186
187 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
188 while (Size < DstElemSize) {
189 *DstElemTy = descend_type(*DstElemTy);
190 *SrcElemTy = descend_type(*SrcElemTy);
191 (*NumDstUnpackings)++;
192 (*NumSrcUnpackings)++;
193 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
194 }
David Netob84ba342017-06-19 17:55:37 -0400195 };
David Neto22f144c2017-06-12 14:26:21 -0400196
197 for (auto &F : M) {
198 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400199 SmallPtrSet<Instruction *, 8> BitCastsToForget;
200 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400201
202 for (auto U : F.users()) {
203 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500204 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
205 auto Dst = dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400206
alan-bakered80f572019-02-11 17:28:26 -0500207 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
208 auto Src = dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400209
210 // The original type of Dst we get from the argument to the bitcast
211 // instruction.
212 auto DstTy = Dst->getType();
213 assert(DstTy->isPointerTy());
214
215 // The original type of Src we get from the argument to the bitcast
216 // instruction.
217 auto SrcTy = Src->getType();
218 assert(SrcTy->isPointerTy());
219
David Neto22f144c2017-06-12 14:26:21 -0400220 // Check that the size is a constant integer.
221 assert(isa<ConstantInt>(CI->getArgOperand(2)));
222 auto Size =
223 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
224
Alan Baker7dea8842018-10-22 10:15:41 -0400225 auto DstElemTy = DstTy->getPointerElementType();
226 auto SrcElemTy = SrcTy->getPointerElementType();
227 unsigned NumDstUnpackings = 0;
228 unsigned NumSrcUnpackings = 0;
229 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
230 &NumSrcUnpackings);
231
232 // Check that the pointee types match.
233 assert(DstElemTy == SrcElemTy);
234
David Netob84ba342017-06-19 17:55:37 -0400235 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400236
David Netob84ba342017-06-19 17:55:37 -0400237 // Check that the size is a multiple of the size of the pointee type.
238 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400239
alan-bakerbccf62c2019-03-29 10:32:41 -0400240 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400241 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400242
243 // Check that the alignment is at least the alignment of the pointee
244 // type.
245 assert(Alignment >= TypeAlignment);
246
247 // Check that the alignment is a multiple of the alignment of the
248 // pointee type.
249 assert(0 == (Alignment % TypeAlignment));
250
251 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400252 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400253
David Netob84ba342017-06-19 17:55:37 -0400254 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400255 }
256 }
257
David Netob84ba342017-06-19 17:55:37 -0400258 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500259 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
260 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400261 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400262
263 auto I32Ty = Type::getInt32Ty(M.getContext());
alan-bakerbccf62c2019-03-29 10:32:41 -0400264 auto Alignment = ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
265 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400266
alan-bakered80f572019-02-11 17:28:26 -0500267 auto Dst = Arg0->getOperand(0);
268 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400269
270 auto DstElemTy = Dst->getType()->getPointerElementType();
271 auto SrcElemTy = Src->getType()->getPointerElementType();
272 unsigned NumDstUnpackings = 0;
273 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400274 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400275 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
276 &NumSrcUnpackings);
277 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400278
David Netob84ba342017-06-19 17:55:37 -0400279 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400280
David Netob84ba342017-06-19 17:55:37 -0400281 IRBuilder<> Builder(CI);
282
283 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
284 auto NewFType = FunctionType::get(
285 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
286 false);
287 auto NewF =
288 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
289 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
290 } else {
291 auto Zero = ConstantInt::get(I32Ty, 0);
292 SmallVector<Value *, 3> SrcIndices;
293 SmallVector<Value *, 3> DstIndices;
294 // Make unpacking indices.
295 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
296 ++unpacking) {
297 SrcIndices.push_back(Zero);
298 }
299 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
300 ++unpacking) {
301 DstIndices.push_back(Zero);
302 }
303 // Add a placeholder for the final index.
304 SrcIndices.push_back(Zero);
305 DstIndices.push_back(Zero);
306
307 // Build the function and function type only once.
308 FunctionType* NewFType = nullptr;
309 Function* NewF = nullptr;
310
311 IRBuilder<> Builder(CI);
312 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
313 auto Index = ConstantInt::get(I32Ty, i);
314 SrcIndices.back() = Index;
315 DstIndices.back() = Index;
316
alan-bakered80f572019-02-11 17:28:26 -0500317 // Avoid the builder for Src in order to prevent the folder from
318 // creating constant expressions for constant memcpys.
319 auto SrcElemPtr =
320 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400321 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
322 NewFType =
323 NewFType != nullptr
324 ? NewFType
325 : FunctionType::get(F.getReturnType(),
326 {DstElemPtr->getType(),
327 SrcElemPtr->getType(), I32Ty, I32Ty},
328 false);
329 NewF = NewF != nullptr ? NewF
330 : Function::Create(NewFType, F.getLinkage(),
331 SPIRVIntrinsic, &M);
332 Builder.CreateCall(
333 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
334 }
335 }
336
337 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400338 CI->eraseFromParent();
339
David Netob84ba342017-06-19 17:55:37 -0400340 // Erase the bitcasts. A particular bitcast might be used
341 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500342 if (isa<BitCastInst>(Arg0))
343 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
344 if (isa<BitCastInst>(Arg1))
345 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400346 }
347 for (auto* Inst : BitCastsToForget) {
348 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400349 }
350 }
351 }
352
353 return Changed;
354}
David Netoe345e0e2018-06-15 11:38:32 -0400355
356bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
357 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
358 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
359
360 bool Changed = false;
361
362 SmallVector<Function *, 2> WorkList;
363 for (auto &F : M) {
364 if (F.getName().startswith("llvm.lifetime.")) {
365 WorkList.push_back(&F);
366 }
367 }
368
369 for (auto *F : WorkList) {
370 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500371 // Copy users to avoid modifying the list in place.
372 SmallVector<User *, 8> users(F->users());
373 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400374 if (auto *CI = dyn_cast<CallInst>(U)) {
375 CI->eraseFromParent();
376 }
377 }
378 F->eraseFromParent();
379 }
380
381 return Changed;
382}