blob: 4b4309202531190cd631eed6b45ca04a9d120534 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm/IR/Constants.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Transforms/Utils/Cloning.h>

#include <spirv/1.0/spirv.hpp>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
28
29namespace {
30struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
31 static char ID;
32 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
33
34 bool runOnModule(Module &M) override;
35 bool replaceMemset(Module &M);
36 bool replaceMemcpy(Module &M);
37};
38}
39
40char ReplaceLLVMIntrinsicsPass::ID = 0;
41static RegisterPass<ReplaceLLVMIntrinsicsPass>
42 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
43
44namespace clspv {
45ModulePass *createReplaceLLVMIntrinsicsPass() {
46 return new ReplaceLLVMIntrinsicsPass();
47}
48}
49
50bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
51 bool Changed = false;
52
53 Changed |= replaceMemset(M);
54 Changed |= replaceMemcpy(M);
55
56 return Changed;
57}
58
59bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
60 bool Changed = false;
61
62 for (auto &F : M) {
63 if (F.getName().startswith("llvm.memset")) {
64 SmallVector<CallInst *, 8> CallsToReplace;
65
66 for (auto U : F.users()) {
67 if (auto CI = dyn_cast<CallInst>(U)) {
68 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
69
70 // We only handle cases where the initializer is a constant int that
71 // is 0.
72 if (!Initializer || (0 != Initializer->getZExtValue())) {
73 Initializer->print(errs());
74 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
75 "non-0 initializer!");
76 }
77
78 CallsToReplace.push_back(CI);
79 }
80 }
81
82 for (auto CI : CallsToReplace) {
83 auto NewArg = CI->getArgOperand(0);
84
85 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
86 NewArg = Bitcast->getOperand(0);
87 }
88
89 auto Ty = NewArg->getType();
90 auto PointeeTy = Ty->getPointerElementType();
91
92 auto NewFType =
93 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
94
95 // Create our fake intrinsic to initialize it to 0.
96 auto SPIRVIntrinsic = "spirv.store_null";
97
98 auto NewF =
99 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
100
101 auto Zero = Constant::getNullValue(PointeeTy);
102
103 auto NewCI = CallInst::Create(NewF, {NewArg, Zero}, "", CI);
104
105 CI->replaceAllUsesWith(NewCI);
106 CI->eraseFromParent();
107
108 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
109 Bitcast->eraseFromParent();
110 }
111 }
112 }
113 }
114
115 return Changed;
116}
117
118bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
119 bool Changed = false;
David Netob84ba342017-06-19 17:55:37 -0400120 auto Layout = M.getDataLayout();
121
122 // Unpack source and destination types until we find a matching
123 // element type. Count the number of levels we unpack for the
124 // source and destination types. So far this only works for
125 // array types, but could be generalized to other regular types
126 // like vectors.
127 auto match_types = [&Layout](CallInst &CI, Type **DstElemTy, Type **SrcElemTy,
128 unsigned *NumDstUnpackings,
129 unsigned *NumSrcUnpackings) {
130 unsigned *numSrcUnpackings = 0;
131 unsigned *numDstUnpackings = 0;
132 while (*SrcElemTy != *DstElemTy) {
133 auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
134 auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
135 if (SrcElemSize >= DstElemSize) {
136 assert((*SrcElemTy)->isArrayTy());
137 *SrcElemTy = (*SrcElemTy)->getArrayElementType();
138 (*NumSrcUnpackings)++;
139 } else if (DstElemSize >= SrcElemSize) {
140 assert((*DstElemTy)->isArrayTy());
141 *DstElemTy = (*DstElemTy)->getArrayElementType();
142 (*NumDstUnpackings)++;
143 } else {
144 errs() << "Don't know how to unpack types for memcpy: " << CI
145 << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
146 assert(false && "Don't know how to unpack these types");
147 }
148 }
149 };
David Neto22f144c2017-06-12 14:26:21 -0400150
151 for (auto &F : M) {
152 if (F.getName().startswith("llvm.memcpy")) {
David Netob84ba342017-06-19 17:55:37 -0400153 SmallPtrSet<Instruction *, 8> BitCastsToForget;
154 SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;
David Neto22f144c2017-06-12 14:26:21 -0400155
156 for (auto U : F.users()) {
157 if (auto CI = dyn_cast<CallInst>(U)) {
158 assert(isa<BitCastInst>(CI->getArgOperand(0)));
159 auto Dst = dyn_cast<BitCastInst>(CI->getArgOperand(0))->getOperand(0);
160
161 assert(isa<BitCastInst>(CI->getArgOperand(1)));
162 auto Src = dyn_cast<BitCastInst>(CI->getArgOperand(1))->getOperand(0);
163
164 // The original type of Dst we get from the argument to the bitcast
165 // instruction.
166 auto DstTy = Dst->getType();
167 assert(DstTy->isPointerTy());
168
169 // The original type of Src we get from the argument to the bitcast
170 // instruction.
171 auto SrcTy = Src->getType();
172 assert(SrcTy->isPointerTy());
173
David Netob84ba342017-06-19 17:55:37 -0400174 auto DstElemTy = DstTy->getPointerElementType();
175 auto SrcElemTy = SrcTy->getPointerElementType();
176 unsigned NumDstUnpackings = 0;
177 unsigned NumSrcUnpackings = 0;
178 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
179 &NumSrcUnpackings);
180
David Neto22f144c2017-06-12 14:26:21 -0400181 // Check that the pointee types match.
David Netob84ba342017-06-19 17:55:37 -0400182 assert(DstElemTy == SrcElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400183
184 // Check that the size is a constant integer.
185 assert(isa<ConstantInt>(CI->getArgOperand(2)));
186 auto Size =
187 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
188
David Netob84ba342017-06-19 17:55:37 -0400189 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400190
David Netob84ba342017-06-19 17:55:37 -0400191 // Check that the size is a multiple of the size of the pointee type.
192 assert(Size % DstElemSize == 0);
David Neto22f144c2017-06-12 14:26:21 -0400193
194 // Check that the alignment is a constant integer.
195 assert(isa<ConstantInt>(CI->getArgOperand(3)));
196 auto Alignment =
197 dyn_cast<ConstantInt>(CI->getArgOperand(3))->getZExtValue();
198
David Netob84ba342017-06-19 17:55:37 -0400199 auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
David Neto22f144c2017-06-12 14:26:21 -0400200
201 // Check that the alignment is at least the alignment of the pointee
202 // type.
203 assert(Alignment >= TypeAlignment);
204
205 // Check that the alignment is a multiple of the alignment of the
206 // pointee type.
207 assert(0 == (Alignment % TypeAlignment));
208
209 // Check that volatile is a constant.
210 assert(isa<ConstantInt>(CI->getArgOperand(4)));
211
David Netob84ba342017-06-19 17:55:37 -0400212 CallsToReplaceWithSpirvCopyMemory.push_back(CI);
David Neto22f144c2017-06-12 14:26:21 -0400213 }
214 }
215
David Netob84ba342017-06-19 17:55:37 -0400216 for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
David Neto22f144c2017-06-12 14:26:21 -0400217 auto Arg0 = dyn_cast<BitCastInst>(CI->getArgOperand(0));
218 auto Arg1 = dyn_cast<BitCastInst>(CI->getArgOperand(1));
David Neto22f144c2017-06-12 14:26:21 -0400219 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
220 auto Arg4 = dyn_cast<ConstantInt>(CI->getArgOperand(4));
221
222 auto I32Ty = Type::getInt32Ty(M.getContext());
David Netob84ba342017-06-19 17:55:37 -0400223 auto Alignment = ConstantInt::get(I32Ty, Arg3->getZExtValue());
224 auto Volatile = ConstantInt::get(I32Ty, Arg4->getZExtValue());
David Neto22f144c2017-06-12 14:26:21 -0400225
David Netob84ba342017-06-19 17:55:37 -0400226 auto Dst = dyn_cast<BitCastInst>(Arg0)->getOperand(0);
227 auto Src = dyn_cast<BitCastInst>(Arg1)->getOperand(0);
228
229 auto DstElemTy = Dst->getType()->getPointerElementType();
230 auto SrcElemTy = Src->getType()->getPointerElementType();
231 unsigned NumDstUnpackings = 0;
232 unsigned NumSrcUnpackings = 0;
233 match_types(*CI, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
234 &NumSrcUnpackings);
235
236 assert(NumDstUnpackings < 2 && "Need to generalize dst unpacking case");
237 assert(NumSrcUnpackings < 2 && "Need to generalize src unpacking case");
238 assert((NumDstUnpackings == 0 || NumSrcUnpackings == 0) &&
239 "Need to generalize unpackings in both dimensions");
David Neto22f144c2017-06-12 14:26:21 -0400240
241 auto SPIRVIntrinsic = "spirv.copy_memory";
242
David Netob84ba342017-06-19 17:55:37 -0400243 auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
David Neto22f144c2017-06-12 14:26:21 -0400244
David Netob84ba342017-06-19 17:55:37 -0400245 auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
David Neto22f144c2017-06-12 14:26:21 -0400246
David Netob84ba342017-06-19 17:55:37 -0400247 IRBuilder<> Builder(CI);
248
249 if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
250 auto NewFType = FunctionType::get(
251 F.getReturnType(), {Dst->getType(), Src->getType(), I32Ty, I32Ty},
252 false);
253 auto NewF =
254 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
255 Builder.CreateCall(NewF, {Dst, Src, Alignment, Volatile}, "");
256 } else {
257 auto Zero = ConstantInt::get(I32Ty, 0);
258 SmallVector<Value *, 3> SrcIndices;
259 SmallVector<Value *, 3> DstIndices;
260 // Make unpacking indices.
261 for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
262 ++unpacking) {
263 SrcIndices.push_back(Zero);
264 }
265 for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
266 ++unpacking) {
267 DstIndices.push_back(Zero);
268 }
269 // Add a placeholder for the final index.
270 SrcIndices.push_back(Zero);
271 DstIndices.push_back(Zero);
272
273 // Build the function and function type only once.
274 FunctionType* NewFType = nullptr;
275 Function* NewF = nullptr;
276
277 IRBuilder<> Builder(CI);
278 for (unsigned i = 0; i < Size / DstElemSize; ++i) {
279 auto Index = ConstantInt::get(I32Ty, i);
280 SrcIndices.back() = Index;
281 DstIndices.back() = Index;
282
283 auto SrcElemPtr = Builder.CreateGEP(Src, SrcIndices);
284 auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
285 NewFType =
286 NewFType != nullptr
287 ? NewFType
288 : FunctionType::get(F.getReturnType(),
289 {DstElemPtr->getType(),
290 SrcElemPtr->getType(), I32Ty, I32Ty},
291 false);
292 NewF = NewF != nullptr ? NewF
293 : Function::Create(NewFType, F.getLinkage(),
294 SPIRVIntrinsic, &M);
295 Builder.CreateCall(
296 NewF, {DstElemPtr, SrcElemPtr, Alignment, Volatile}, "");
297 }
298 }
299
300 // Erase the call.
David Neto22f144c2017-06-12 14:26:21 -0400301 CI->eraseFromParent();
302
David Netob84ba342017-06-19 17:55:37 -0400303 // Erase the bitcasts. A particular bitcast might be used
304 // in more than one memcpy, so defer actual deleting until later.
305 BitCastsToForget.insert(Arg0);
306 BitCastsToForget.insert(Arg1);
307 }
308 for (auto* Inst : BitCastsToForget) {
309 Inst->eraseFromParent();
David Neto22f144c2017-06-12 14:26:21 -0400310 }
311 }
312 }
313
314 return Changed;
315}