blob: c7df9de5ba648b90c8b48d497a554ab36306e7ef [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <llvm/IR/Constants.h>
16#include <llvm/IR/Instructions.h>
17#include <llvm/IR/Module.h>
18#include <llvm/Pass.h>
19#include <llvm/Support/raw_ostream.h>
20#include <llvm/Transforms/Utils/Cloning.h>
21
22#include <spirv/1.0/spirv.hpp>
23
24using namespace llvm;
25
26#define DEBUG_TYPE "ReplaceLLVMIntrinsics"
27
28namespace {
29struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
30 static char ID;
31 ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}
32
33 bool runOnModule(Module &M) override;
34 bool replaceMemset(Module &M);
35 bool replaceMemcpy(Module &M);
36};
37}
38
39char ReplaceLLVMIntrinsicsPass::ID = 0;
40static RegisterPass<ReplaceLLVMIntrinsicsPass>
41 X("ReplaceLLVMIntrinsics", "Replace LLVM intrinsics Pass");
42
43namespace clspv {
44ModulePass *createReplaceLLVMIntrinsicsPass() {
45 return new ReplaceLLVMIntrinsicsPass();
46}
47}
48
49bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
50 bool Changed = false;
51
52 Changed |= replaceMemset(M);
53 Changed |= replaceMemcpy(M);
54
55 return Changed;
56}
57
58bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
59 bool Changed = false;
60
61 for (auto &F : M) {
62 if (F.getName().startswith("llvm.memset")) {
63 SmallVector<CallInst *, 8> CallsToReplace;
64
65 for (auto U : F.users()) {
66 if (auto CI = dyn_cast<CallInst>(U)) {
67 auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));
68
69 // We only handle cases where the initializer is a constant int that
70 // is 0.
71 if (!Initializer || (0 != Initializer->getZExtValue())) {
72 Initializer->print(errs());
73 llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
74 "non-0 initializer!");
75 }
76
77 CallsToReplace.push_back(CI);
78 }
79 }
80
81 for (auto CI : CallsToReplace) {
82 auto NewArg = CI->getArgOperand(0);
83
84 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
85 NewArg = Bitcast->getOperand(0);
86 }
87
88 auto Ty = NewArg->getType();
89 auto PointeeTy = Ty->getPointerElementType();
90
91 auto NewFType =
92 FunctionType::get(F.getReturnType(), {Ty, PointeeTy}, false);
93
94 // Create our fake intrinsic to initialize it to 0.
95 auto SPIRVIntrinsic = "spirv.store_null";
96
97 auto NewF =
98 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
99
100 auto Zero = Constant::getNullValue(PointeeTy);
101
102 auto NewCI = CallInst::Create(NewF, {NewArg, Zero}, "", CI);
103
104 CI->replaceAllUsesWith(NewCI);
105 CI->eraseFromParent();
106
107 if (auto Bitcast = dyn_cast<BitCastInst>(NewArg)) {
108 Bitcast->eraseFromParent();
109 }
110 }
111 }
112 }
113
114 return Changed;
115}
116
117bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
118 bool Changed = false;
119
120 for (auto &F : M) {
121 if (F.getName().startswith("llvm.memcpy")) {
122 SmallVector<CallInst *, 8> CallsToReplace;
123
124 for (auto U : F.users()) {
125 if (auto CI = dyn_cast<CallInst>(U)) {
126 assert(isa<BitCastInst>(CI->getArgOperand(0)));
127 auto Dst = dyn_cast<BitCastInst>(CI->getArgOperand(0))->getOperand(0);
128
129 assert(isa<BitCastInst>(CI->getArgOperand(1)));
130 auto Src = dyn_cast<BitCastInst>(CI->getArgOperand(1))->getOperand(0);
131
132 // The original type of Dst we get from the argument to the bitcast
133 // instruction.
134 auto DstTy = Dst->getType();
135 assert(DstTy->isPointerTy());
136
137 // The original type of Src we get from the argument to the bitcast
138 // instruction.
139 auto SrcTy = Src->getType();
140 assert(SrcTy->isPointerTy());
141
142 // Check that the pointee types match.
143 assert(DstTy->getPointerElementType() ==
144 SrcTy->getPointerElementType());
145
146 // Check that the size is a constant integer.
147 assert(isa<ConstantInt>(CI->getArgOperand(2)));
148 auto Size =
149 dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
150
151 auto TypeSize = M.getDataLayout().getTypeSizeInBits(
152 DstTy->getPointerElementType()) /
153 8;
154
155 // Check that the size is equal to the alignment of the pointee type.
156 assert(Size == TypeSize);
157
158 // Check that the alignment is a constant integer.
159 assert(isa<ConstantInt>(CI->getArgOperand(3)));
160 auto Alignment =
161 dyn_cast<ConstantInt>(CI->getArgOperand(3))->getZExtValue();
162
163 auto TypeAlignment = M.getDataLayout().getABITypeAlignment(
164 DstTy->getPointerElementType());
165
166 // Check that the alignment is at least the alignment of the pointee
167 // type.
168 assert(Alignment >= TypeAlignment);
169
170 // Check that the alignment is a multiple of the alignment of the
171 // pointee type.
172 assert(0 == (Alignment % TypeAlignment));
173
174 // Check that volatile is a constant.
175 assert(isa<ConstantInt>(CI->getArgOperand(4)));
176
177 CallsToReplace.push_back(CI);
178 }
179 }
180
181 for (auto CI : CallsToReplace) {
182 auto Arg0 = dyn_cast<BitCastInst>(CI->getArgOperand(0));
183 auto Arg1 = dyn_cast<BitCastInst>(CI->getArgOperand(1));
184
185 auto Dst = dyn_cast<BitCastInst>(Arg0)->getOperand(0);
186 auto Src = dyn_cast<BitCastInst>(Arg1)->getOperand(0);
187
188 auto DstTy = Dst->getType();
189 auto SrcTy = Src->getType();
190
191 auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));
192 auto Arg4 = dyn_cast<ConstantInt>(CI->getArgOperand(4));
193
194 auto I32Ty = Type::getInt32Ty(M.getContext());
195
196 auto NewFType = FunctionType::get(F.getReturnType(),
197 {DstTy, SrcTy, I32Ty, I32Ty}, false);
198
199 auto SPIRVIntrinsic = "spirv.copy_memory";
200
201 auto NewF =
202 Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
203
204 auto NewCI = CallInst::Create(
205 NewF, {Dst, Src, ConstantInt::get(I32Ty, Arg3->getZExtValue()),
206 ConstantInt::get(I32Ty, Arg4->getZExtValue())},
207 "", CI);
208
209 CI->replaceAllUsesWith(NewCI);
210 CI->eraseFromParent();
211
212 Arg0->eraseFromParent();
213 Arg1->eraseFromParent();
214 }
215 }
216 }
217
218 return Changed;
219}