blob: 651c04c12c80dfb5e751fe6ff2cbdf0f9205a78d [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <llvm/IR/Constants.h>
David Netob84ba342017-06-19 17:55:37 -040016#include <llvm/IR/IRBuilder.h>
David Neto22f144c2017-06-12 14:26:21 -040017#include <llvm/IR/Instructions.h>
18#include <llvm/IR/Module.h>
19#include <llvm/Pass.h>
20#include <llvm/Support/raw_ostream.h>
21#include <llvm/Transforms/Utils/Cloning.h>
22
23#include <spirv/1.0/spirv.hpp>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
28
29namespace {
30struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
31 static char ID;
32 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
33
34 bool runOnModule(Module &M) override;
35 bool replaceMemset(Module &M);
36 bool replaceMemcpy(Module &M);
37};
38}
39
40char ReplaceLLVMIntrinsicsPass::ID = 0;
41static RegisterPass<ReplaceLLVMIntrinsicsPass>
42 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
43
44namespace clspv {
45ModulePass *createReplaceLLVMIntrinsicsPass() {
46 return new ReplaceLLVMIntrinsicsPass();
47}
48}
49
50bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
51 bool Changed = false;
52
53 Changed |= replaceMemset(M);
54 Changed |= replaceMemcpy(M);
55
56 return Changed;
57}
58
59bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
60 bool Changed = false;
David Netod3f59382017-10-18 18:30:30 -040061 auto Layout = M.getDataLayout();
David Neto22f144c2017-06-12 14:26:21 -040062
63 for (auto &F : M) {
64 if (F.getName().startswith("llvm.memset")) {
65 SmallVector<CallInst *, 8> CallsToReplace;
66
67 for (auto U : F.users()) {
68 if (auto CI = dyn_cast<CallInst>(U)) {
69 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
70
71 // We only handle cases where the initializer is a constant int that
72 // is 0.
73 if (!Initializer || (0 != Initializer->getZExtValue())) {
74 Initializer->print(errs());
75 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
76 "non-0 initializer!");
77 }
78
79 CallsToReplace.push_back(CI);
80 }
81 }
82
83 for (auto CI : CallsToReplace) {
84 auto NewArg = CI->getArgOperand(0);
85
86 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
87 NewArg = Bitcast->getOperand(0);
88 }
89
David Netod3f59382017-10-18 18:30:30 -040090 auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
91
David Neto22f144c2017-06-12 14:26:21 -040092 auto Ty = NewArg->getType();
93 auto PointeeTy = Ty->getPointerElementType();
94
95 auto NewFType =
96 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
97
98 // Create our fake intrinsic to initialize it to 0.
99 auto SPIRVIntrinsic = "spirv.store_null";
100
101 auto NewF =
102 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
103
104 auto Zero = Constant::getNullValue(PointeeTy);
105
David Netod3f59382017-10-18 18:30:30 -0400106 const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
107 assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
108 "Null memset can't be divided evenly across multiple stores.");
109 assert((num_stores & 0xFFFFFFFF) == num_stores);
David Neto22f144c2017-06-12 14:26:21 -0400110
David Netod3f59382017-10-18 18:30:30 -0400111 // Generate the first store.
112 CallInst::Create(NewF, {NewArg, Zero}, "", CI);
113
114 // Generate subsequent stores, but only if needed.
115 if (num_stores) {
116 auto I32Ty = Type::getInt32Ty(M.getContext());
117 auto One = ConstantInt::get(I32Ty, 1);
118 auto Ptr = NewArg;
119 for (uint32_t i = 1; i < num_stores; i++) {
120 Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
121 CallInst::Create(NewF, {Ptr, Zero}, "", CI);
122 }
123 }
124
David Neto22f144c2017-06-12 14:26:21 -0400125 CI->eraseFromParent();
126
127 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
128 Bitcast->eraseFromParent();
129 }
130 }
131 }
132 }
133
134 return Changed;
135}
136
137bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
138 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400139 auto Layout = M.getDataLayout();
140
141 // Unpack source and destination types until we find a matching
142 // element type. Count the number of levels we unpack for the
143 // source and destination types. So far this only works for
144 // array types, but could be generalized to other regular types
145 // like vectors.
146 auto match_types = [&Layout](CallInst &CI, Type **DstElemTy, Type **SrcElemTy,
147 unsigned *NumDstUnpackings,
148 unsigned *NumSrcUnpackings) {
149 unsigned *numSrcUnpackings = 0;
150 unsigned *numDstUnpackings = 0;
151 while (*SrcElemTy != *DstElemTy) {
152 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
153 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
154 if (SrcElemSize >= DstElemSize) {
155 assert((*SrcElemTy)->isArrayTy());
156 *SrcElemTy = (*SrcElemTy)->getArrayElementType();
157 (*NumSrcUnpackings)++;
158 } else if (DstElemSize >= SrcElemSize) {
159 assert((*DstElemTy)->isArrayTy());
160 *DstElemTy = (*DstElemTy)->getArrayElementType();
161 (*NumDstUnpackings)++;
162 } else {
163 errs() << "Don't know how to unpack types for memcpy: " << CI
164 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
165 assert(false && "Don't know how to unpack these types");
166 }
167 }
168 };
David Neto22f144c2017-06-12 14:26:21 -0400169
170 for (auto &F : M) {
171 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400172 SmallPtrSet<Instruction *, 8> BitCastsToForget;
173 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400174
175 for (auto U : F.users()) {
176 if (auto CI = dyn_cast<CallInst>(U)) {
177 assert(isa<BitCastInst>(CI->getArgOperand(0)));
178 auto Dst = dyn_cast<BitCastInst>(CI->getArgOperand(0))->getOperand(0);
179
180 assert(isa<BitCastInst>(CI->getArgOperand(1)));
181 auto Src = dyn_cast<BitCastInst>(CI->getArgOperand(1))->getOperand(0);
182
183 // The original type of Dst we get from the argument to the bitcast
184 // instruction.
185 auto DstTy = Dst->getType();
186 assert(DstTy->isPointerTy());
187
188 // The original type of Src we get from the argument to the bitcast
189 // instruction.
190 auto SrcTy = Src->getType();
191 assert(SrcTy->isPointerTy());
192
David Netob84ba342017-06-19 17:55:37 -0400193 auto DstElemTy = DstTy->getPointerElementType();
194 auto SrcElemTy = SrcTy->getPointerElementType();
195 unsigned NumDstUnpackings = 0;
196 unsigned NumSrcUnpackings = 0;
197 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
198 &NumSrcUnpackings);
199
David Neto22f144c2017-06-12 14:26:21 -0400200 // Check that the pointee types match.
David Netob84ba342017-06-19 17:55:37 -0400201 assert(DstElemTy == SrcElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400202
203 // Check that the size is a constant integer.
204 assert(isa<ConstantInt>(CI->getArgOperand(2)));
205 auto Size =
206 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
207
David Netob84ba342017-06-19 17:55:37 -0400208 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400209
David Netob84ba342017-06-19 17:55:37 -0400210 // Check that the size is a multiple of the size of the pointee type.
211 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400212
213 // Check that the alignment is a constant integer.
214 assert(isa<ConstantInt>(CI->getArgOperand(3)));
215 auto Alignment =
216 dyn_cast<ConstantInt>(CI->getArgOperand(3))->getZExtValue();
217
David Netob84ba342017-06-19 17:55:37 -0400218 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400219
220 // Check that the alignment is at least the alignment of the pointee
221 // type.
222 assert(Alignment >= TypeAlignment);
223
224 // Check that the alignment is a multiple of the alignment of the
225 // pointee type.
226 assert(0 == (Alignment % TypeAlignment));
227
228 // Check that volatile is a constant.
229 assert(isa<ConstantInt>(CI->getArgOperand(4)));
230
David Netob84ba342017-06-19 17:55:37 -0400231 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400232 }
233 }
234
David Netob84ba342017-06-19 17:55:37 -0400235 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
David Neto22f144c2017-06-12 14:26:21 -0400236 auto Arg0 = dyn_cast<BitCastInst>(CI->getArgOperand(0));
237 auto Arg1 = dyn_cast<BitCastInst>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400238 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
239 auto Arg4 = dyn_cast<ConstantInt>(CI->getArgOperand(4));
240
241 auto I32Ty = Type::getInt32Ty(M.getContext());
David Netob84ba342017-06-19 17:55:37 -0400242 auto Alignment = ConstantInt::get(I32Ty, Arg3->getZExtValue());
243 auto Volatile = ConstantInt::get(I32Ty, Arg4->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400244
David Netob84ba342017-06-19 17:55:37 -0400245 auto Dst = dyn_cast<BitCastInst>(Arg0)->getOperand(0);
246 auto Src = dyn_cast<BitCastInst>(Arg1)->getOperand(0);
247
248 auto DstElemTy = Dst->getType()->getPointerElementType();
249 auto SrcElemTy = Src->getType()->getPointerElementType();
250 unsigned NumDstUnpackings = 0;
251 unsigned NumSrcUnpackings = 0;
252 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
253 &NumSrcUnpackings);
254
255 assert(NumDstUnpackings < 2 && "Need to generalize dst unpacking case");
256 assert(NumSrcUnpackings < 2 && "Need to generalize src unpacking case");
257 assert((NumDstUnpackings == 0 || NumSrcUnpackings == 0) &&
258 "Need to generalize unpackings in both dimensions");
David Neto22f144c2017-06-12 14:26:21 -0400259
260 auto SPIRVIntrinsic = "spirv.copy_memory";
261
David Netob84ba342017-06-19 17:55:37 -0400262 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -0400263
David Netob84ba342017-06-19 17:55:37 -0400264 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400265
David Netob84ba342017-06-19 17:55:37 -0400266 IRBuilder<> Builder(CI);
267
268 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
269 auto NewFType = FunctionType::get(
270 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
271 false);
272 auto NewF =
273 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
274 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
275 } else {
276 auto Zero = ConstantInt::get(I32Ty, 0);
277 SmallVector<Value *, 3> SrcIndices;
278 SmallVector<Value *, 3> DstIndices;
279 // Make unpacking indices.
280 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
281 ++unpacking) {
282 SrcIndices.push_back(Zero);
283 }
284 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
285 ++unpacking) {
286 DstIndices.push_back(Zero);
287 }
288 // Add a placeholder for the final index.
289 SrcIndices.push_back(Zero);
290 DstIndices.push_back(Zero);
291
292 // Build the function and function type only once.
293 FunctionType* NewFType = nullptr;
294 Function* NewF = nullptr;
295
296 IRBuilder<> Builder(CI);
297 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
298 auto Index = ConstantInt::get(I32Ty, i);
299 SrcIndices.back() = Index;
300 DstIndices.back() = Index;
301
302 auto SrcElemPtr = Builder.CreateGEP(Src, SrcIndices);
303 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
304 NewFType =
305 NewFType != nullptr
306 ? NewFType
307 : FunctionType::get(F.getReturnType(),
308 {DstElemPtr->getType(),
309 SrcElemPtr->getType(), I32Ty, I32Ty},
310 false);
311 NewF = NewF != nullptr ? NewF
312 : Function::Create(NewFType, F.getLinkage(),
313 SPIRVIntrinsic, &M);
314 Builder.CreateCall(
315 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
316 }
317 }
318
319 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400320 CI->eraseFromParent();
321
David Netob84ba342017-06-19 17:55:37 -0400322 // Erase the bitcasts. A particular bitcast might be used
323 // in more than one memcpy, so defer actual deleting until later.
324 BitCastsToForget.insert(Arg0);
325 BitCastsToForget.insert(Arg1);
326 }
327 for (auto* Inst : BitCastsToForget) {
328 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400329 }
330 }
331 }
332
333 return Changed;
334}