blob: 1cc9bd22cc1b65da35df4c016d28e70a2b66f98e [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "spirv/unified1/spirv.hpp"

#include "Passes.h"
27
David Neto22f144c2017-06-12 14:26:21 -040028using namespace llvm;
29
30#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
31
32namespace {
33struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
34 static char ID;
35 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
36
37 bool runOnModule(Module &M) override;
38 bool replaceMemset(Module &M);
39 bool replaceMemcpy(Module &M);
David Netoe345e0e2018-06-15 11:38:32 -040040 bool removeLifetimeDeclarations(Module &M);
David Neto22f144c2017-06-12 14:26:21 -040041};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace
David Neto22f144c2017-06-12 14:26:21 -040043
44char ReplaceLLVMIntrinsicsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040045INITIALIZE_PASS(ReplaceLLVMIntrinsicsPass, "ReplaceLLVMIntrinsics",
46 "Replace LLVM intrinsics Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040047
48namespace clspv {
49ModulePass *createReplaceLLVMIntrinsicsPass() {
50 return new ReplaceLLVMIntrinsicsPass();
51}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040052} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040053
54bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
55 bool Changed = false;
56
David Netoe345e0e2018-06-15 11:38:32 -040057 // Remove lifetime annotations first. They coulud be using memset
58 // and memcpy calls.
59 Changed |= removeLifetimeDeclarations(M);
David Neto22f144c2017-06-12 14:26:21 -040060 Changed |= replaceMemset(M);
61 Changed |= replaceMemcpy(M);
62
63 return Changed;
64}
65
66bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
67 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040068 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040069
70 for (auto &F : M) {
71 if (F.getName().startswith("llvm.memset")) {
72 SmallVector<CallInst *, 8> CallsToReplace;
73
74 for (auto U : F.users()) {
75 if (auto CI = dyn_cast<CallInst>(U)) {
76 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
77
78 // We only handle cases where the initializer is a constant int that
79 // is 0.
80 if (!Initializer || (0 != Initializer->getZExtValue())) {
81 Initializer->print(errs());
82 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
83 "non-0 initializer!");
84 }
85
86 CallsToReplace.push_back(CI);
87 }
88 }
89
90 for (auto CI : CallsToReplace) {
91 auto NewArg = CI->getArgOperand(0);
Kévin Petit70944912019-04-17 23:22:28 +010092 auto Bitcast = dyn_cast<BitCastInst>(NewArg);
93 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -040094 NewArg = Bitcast->getOperand(0);
95 }
96
David Netod3f59382017-10-18 18:30:30 -040097 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -040098 auto Ty = NewArg->getType();
99 auto PointeeTy = Ty->getPointerElementType();
David Neto22f144c2017-06-12 14:26:21 -0400100 auto Zero = Constant::getNullValue(PointeeTy);
101
David Netod3f59382017-10-18 18:30:30 -0400102 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
103 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
104 "Null memset can't be divided evenly across multiple stores.");
105 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400106
David Netod3f59382017-10-18 18:30:30 -0400107 // Generate the first store.
Kévin Petit58c445c2019-06-18 18:09:46 +0100108 new StoreInst(Zero, NewArg, CI);
David Netod3f59382017-10-18 18:30:30 -0400109
110 // Generate subsequent stores, but only if needed.
111 if (num_stores) {
112 auto I32Ty = Type::getInt32Ty(M.getContext());
113 auto One = ConstantInt::get(I32Ty, 1);
114 auto Ptr = NewArg;
115 for (uint32_t i = 1; i < num_stores; i++) {
116 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
Kévin Petit58c445c2019-06-18 18:09:46 +0100117 new StoreInst(Zero, Ptr, CI);
David Netod3f59382017-10-18 18:30:30 -0400118 }
119 }
120
David Neto22f144c2017-06-12 14:26:21 -0400121 CI->eraseFromParent();
122
Kévin Petit70944912019-04-17 23:22:28 +0100123 if (Bitcast != nullptr) {
David Neto22f144c2017-06-12 14:26:21 -0400124 Bitcast->eraseFromParent();
125 }
126 }
127 }
128 }
129
130 return Changed;
131}
132
133bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
134 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400135 auto Layout = M.getDataLayout();
136
137 // Unpack source and destination types until we find a matching
138 // element type. Count the number of levels we unpack for the
139 // source and destination types. So far this only works for
140 // array types, but could be generalized to other regular types
141 // like vectors.
Alan Baker7dea8842018-10-22 10:15:41 -0400142 auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
143 Type **SrcElemTy, unsigned *NumDstUnpackings,
David Netob84ba342017-06-19 17:55:37 -0400144 unsigned *NumSrcUnpackings) {
Alan Baker7dea8842018-10-22 10:15:41 -0400145 auto descend_type = [](Type *InType) {
146 Type *OutType = InType;
147 if (OutType->isStructTy()) {
148 OutType = OutType->getStructElementType(0);
149 } else if (OutType->isArrayTy()) {
150 OutType = OutType->getArrayElementType();
151 } else if (OutType->isVectorTy()) {
152 OutType = OutType->getVectorElementType();
153 } else {
154 assert(false && "Don't know how to descend into type");
155 }
156
157 return OutType;
158 };
159
David Netob84ba342017-06-19 17:55:37 -0400160 unsigned *numSrcUnpackings = 0;
161 unsigned *numDstUnpackings = 0;
162 while (*SrcElemTy != *DstElemTy) {
163 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
164 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
165 if (SrcElemSize >= DstElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400166 *SrcElemTy = descend_type(*SrcElemTy);
David Netob84ba342017-06-19 17:55:37 -0400167 (*NumSrcUnpackings)++;
168 } else if (DstElemSize >= SrcElemSize) {
Alan Baker7dea8842018-10-22 10:15:41 -0400169 *DstElemTy = descend_type(*DstElemTy);
David Netob84ba342017-06-19 17:55:37 -0400170 (*NumDstUnpackings)++;
171 } else {
172 errs() << "Don't know how to unpack types for memcpy: " << CI
173 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
174 assert(false && "Don't know how to unpack these types");
175 }
176 }
Alan Baker7dea8842018-10-22 10:15:41 -0400177
178 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
179 while (Size < DstElemSize) {
180 *DstElemTy = descend_type(*DstElemTy);
181 *SrcElemTy = descend_type(*SrcElemTy);
182 (*NumDstUnpackings)++;
183 (*NumSrcUnpackings)++;
184 DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
185 }
David Netob84ba342017-06-19 17:55:37 -0400186 };
David Neto22f144c2017-06-12 14:26:21 -0400187
188 for (auto &F : M) {
189 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400190 SmallPtrSet<Instruction *, 8> BitCastsToForget;
191 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400192
193 for (auto U : F.users()) {
194 if (auto CI = dyn_cast<CallInst>(U)) {
alan-bakered80f572019-02-11 17:28:26 -0500195 assert(isa<BitCastOperator>(CI->getArgOperand(0)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400196 auto Dst =
197 dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400198
alan-bakered80f572019-02-11 17:28:26 -0500199 assert(isa<BitCastOperator>(CI->getArgOperand(1)));
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400200 auto Src =
201 dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);
David Neto22f144c2017-06-12 14:26:21 -0400202
203 // The original type of Dst we get from the argument to the bitcast
204 // instruction.
205 auto DstTy = Dst->getType();
206 assert(DstTy->isPointerTy());
207
208 // The original type of Src we get from the argument to the bitcast
209 // instruction.
210 auto SrcTy = Src->getType();
211 assert(SrcTy->isPointerTy());
212
David Neto22f144c2017-06-12 14:26:21 -0400213 // Check that the size is a constant integer.
214 assert(isa<ConstantInt>(CI->getArgOperand(2)));
215 auto Size =
216 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
217
Alan Baker7dea8842018-10-22 10:15:41 -0400218 auto DstElemTy = DstTy->getPointerElementType();
219 auto SrcElemTy = SrcTy->getPointerElementType();
220 unsigned NumDstUnpackings = 0;
221 unsigned NumSrcUnpackings = 0;
222 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
223 &NumSrcUnpackings);
224
225 // Check that the pointee types match.
226 assert(DstElemTy == SrcElemTy);
227
David Netob84ba342017-06-19 17:55:37 -0400228 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400229
David Netob84ba342017-06-19 17:55:37 -0400230 // Check that the size is a multiple of the size of the pointee type.
231 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400232
alan-bakerbccf62c2019-03-29 10:32:41 -0400233 auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
David Netob84ba342017-06-19 17:55:37 -0400234 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400235
236 // Check that the alignment is at least the alignment of the pointee
237 // type.
238 assert(Alignment >= TypeAlignment);
239
240 // Check that the alignment is a multiple of the alignment of the
241 // pointee type.
242 assert(0 == (Alignment % TypeAlignment));
243
244 // Check that volatile is a constant.
alan-bakerbccf62c2019-03-29 10:32:41 -0400245 assert(isa<ConstantInt>(CI->getArgOperand(3)));
David Neto22f144c2017-06-12 14:26:21 -0400246
David Netob84ba342017-06-19 17:55:37 -0400247 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400248 }
249 }
250
David Netob84ba342017-06-19 17:55:37 -0400251 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
alan-bakered80f572019-02-11 17:28:26 -0500252 auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
253 auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400254 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
David Neto22f144c2017-06-12 14:26:21 -0400255
256 auto I32Ty = Type::getInt32Ty(M.getContext());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400257 auto Alignment =
258 ConstantInt::get(I32Ty, cast<MemIntrinsic>(CI)->getDestAlignment());
alan-bakerbccf62c2019-03-29 10:32:41 -0400259 auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400260
alan-bakered80f572019-02-11 17:28:26 -0500261 auto Dst = Arg0->getOperand(0);
262 auto Src = Arg1->getOperand(0);
David Netob84ba342017-06-19 17:55:37 -0400263
264 auto DstElemTy = Dst->getType()->getPointerElementType();
265 auto SrcElemTy = Src->getType()->getPointerElementType();
266 unsigned NumDstUnpackings = 0;
267 unsigned NumSrcUnpackings = 0;
David Netob84ba342017-06-19 17:55:37 -0400268 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
Alan Baker7dea8842018-10-22 10:15:41 -0400269 match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
270 &NumSrcUnpackings);
271 auto SPIRVIntrinsic = "spirv.copy_memory";
David Neto22f144c2017-06-12 14:26:21 -0400272
David Netob84ba342017-06-19 17:55:37 -0400273 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400274
David Netob84ba342017-06-19 17:55:37 -0400275 IRBuilder<> Builder(CI);
276
277 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
278 auto NewFType = FunctionType::get(
279 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
280 false);
281 auto NewF =
282 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
283 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
284 } else {
285 auto Zero = ConstantInt::get(I32Ty, 0);
286 SmallVector<Value *, 3> SrcIndices;
287 SmallVector<Value *, 3> DstIndices;
288 // Make unpacking indices.
289 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
290 ++unpacking) {
291 SrcIndices.push_back(Zero);
292 }
293 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
294 ++unpacking) {
295 DstIndices.push_back(Zero);
296 }
297 // Add a placeholder for the final index.
298 SrcIndices.push_back(Zero);
299 DstIndices.push_back(Zero);
300
301 // Build the function and function type only once.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400302 FunctionType *NewFType = nullptr;
303 Function *NewF = nullptr;
David Netob84ba342017-06-19 17:55:37 -0400304
305 IRBuilder<> Builder(CI);
306 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
307 auto Index = ConstantInt::get(I32Ty, i);
308 SrcIndices.back() = Index;
309 DstIndices.back() = Index;
310
alan-bakered80f572019-02-11 17:28:26 -0500311 // Avoid the builder for Src in order to prevent the folder from
312 // creating constant expressions for constant memcpys.
313 auto SrcElemPtr =
314 GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
David Netob84ba342017-06-19 17:55:37 -0400315 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
316 NewFType =
317 NewFType != nullptr
318 ? NewFType
319 : FunctionType::get(F.getReturnType(),
320 {DstElemPtr->getType(),
321 SrcElemPtr->getType(), I32Ty, I32Ty},
322 false);
323 NewF = NewF != nullptr ? NewF
324 : Function::Create(NewFType, F.getLinkage(),
325 SPIRVIntrinsic, &M);
326 Builder.CreateCall(
327 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
328 }
329 }
330
331 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400332 CI->eraseFromParent();
333
David Netob84ba342017-06-19 17:55:37 -0400334 // Erase the bitcasts. A particular bitcast might be used
335 // in more than one memcpy, so defer actual deleting until later.
alan-bakered80f572019-02-11 17:28:26 -0500336 if (isa<BitCastInst>(Arg0))
337 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
338 if (isa<BitCastInst>(Arg1))
339 BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
David Netob84ba342017-06-19 17:55:37 -0400340 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400341 for (auto *Inst : BitCastsToForget) {
David Netob84ba342017-06-19 17:55:37 -0400342 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400343 }
344 }
345 }
346
347 return Changed;
348}
David Netoe345e0e2018-06-15 11:38:32 -0400349
350bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
351 // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
352 // Vulkan doesn't support that, so remove all lifteime bounds declarations.
353
354 bool Changed = false;
355
356 SmallVector<Function *, 2> WorkList;
357 for (auto &F : M) {
358 if (F.getName().startswith("llvm.lifetime.")) {
359 WorkList.push_back(&F);
360 }
361 }
362
363 for (auto *F : WorkList) {
364 Changed = true;
alan-bakera5ff28e2018-11-21 16:27:20 -0500365 // Copy users to avoid modifying the list in place.
366 SmallVector<User *, 8> users(F->users());
367 for (auto U : users) {
David Netoe345e0e2018-06-15 11:38:32 -0400368 if (auto *CI = dyn_cast<CallInst>(U)) {
369 CI->eraseFromParent();
370 }
371 }
372 F->eraseFromParent();
373 }
374
375 return Changed;
376}