blob: ee26f32a3c6723f39ac826af279776b5125371ea [file] [log] [blame]
alan-bakera71f1932019-04-11 11:04:34 -04001// Copyright 2019 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <climits>
16#include <map>
17#include <set>
18#include <utility>
19#include <vector>
20
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/UniqueVector.h"
25#include "llvm/IR/CallingConv.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/raw_ostream.h"
31#include "llvm/Transforms/Utils/Cloning.h"
32
33#include "clspv/Passes.h"
34
35#include "ArgKind.h"
36#include "CallGraphOrderedFunctions.h"
37#include "Constants.h"
38
39using namespace llvm;
40
41namespace {
42
43class MultiVersionUBOFunctionsPass final : public ModulePass {
44public:
45 static char ID;
46 MultiVersionUBOFunctionsPass() : ModulePass(ID) {}
47 bool runOnModule(Module &M) override;
48
49private:
50 // Struct for tracking specialization information.
51 struct ResourceInfo {
52 // The specific argument.
53 Argument *arg;
54 // The resource var base call.
55 CallInst *base;
56 // Series of GEPs that operate on |base|.
57 std::vector<GetElementPtrInst *> indices;
58 };
59
60 // Analyzes the call, |user|, to |fn| in terms of its UBO arguments. Returns
61 // true if |user| can be transformed into a specialized function.
62 //
63 // Currently, this function is only successful in analyzing GEP chains to a
64 // resource variable.
65 bool AnalyzeCall(Function *fn, CallInst *user,
66 std::vector<ResourceInfo> *resources);
67
68 // Inlines |call|.
69 void InlineCallSite(CallInst *call);
70
71 // Transforms the call to |fn| into a specialized call based on |resources|.
72 // Replaces |call| with a call to the specialized version.
73 void SpecializeCall(Function *fn, CallInst *call,
74 const std::vector<ResourceInfo> &resources, size_t id);
75
76 // Adds extra arguments to |fn| by rebuilding the entire function.
77 Function *AddExtraArguments(Function *fn,
78 const std::vector<Value *> &extra_args);
79};
80
81} // namespace
82
83char MultiVersionUBOFunctionsPass::ID = 0;
84static RegisterPass<MultiVersionUBOFunctionsPass>
85 X("MultiVersionUBOFunctionsPass",
86 "Multi-version functions with UBO params");
87
88namespace clspv {
89ModulePass *createMultiVersionUBOFunctionsPass() {
90 return new MultiVersionUBOFunctionsPass();
91}
92} // namespace clspv
93
94bool MultiVersionUBOFunctionsPass::runOnModule(Module &M) {
95 bool changed = false;
96 UniqueVector<Function *> ordered_functions =
97 clspv::CallGraphOrderedFunctions(M);
98
99 for (auto fn : ordered_functions) {
100 // Kernels don't need modified.
101 if (fn->isDeclaration() || fn->getCallingConv() == CallingConv::SPIR_KERNEL)
102 continue;
103
104 bool local_changed = false;
105 size_t count = 0;
alan-baker973ba8d2019-07-16 19:16:19 -0400106 SmallVector<User *, 8> users(fn->users());
107 for (auto user : users) {
alan-bakera71f1932019-04-11 11:04:34 -0400108 if (auto call = dyn_cast<CallInst>(user)) {
109 std::vector<ResourceInfo> resources;
110 if (AnalyzeCall(fn, call, &resources)) {
111 if (!resources.empty()) {
112 local_changed = true;
113 SpecializeCall(fn, call, resources, count++);
114 }
115 } else {
116 local_changed = true;
117 InlineCallSite(call);
118 }
119 }
120 }
121
122 fn->removeDeadConstantUsers();
123 if (local_changed) {
124 // All calls to this function were either specialized or inlined.
125 fn->eraseFromParent();
126 }
127 changed |= local_changed;
128 }
129
130 return changed;
131}
132
133bool MultiVersionUBOFunctionsPass::AnalyzeCall(
134 Function *fn, CallInst *user, std::vector<ResourceInfo> *resources) {
135 for (auto &arg : fn->args()) {
136 if (clspv::GetArgKindForType(arg.getType()) != clspv::ArgKind::BufferUBO)
137 continue;
138
139 Value *arg_operand = user->getOperand(arg.getArgNo());
140 ResourceInfo info;
141 info.arg = &arg;
142
143 DenseSet<Value *> visited;
144 std::vector<Value *> stack;
145 stack.push_back(arg_operand);
146
147 while (!stack.empty()) {
148 Value *value = stack.back();
149 stack.pop_back();
150
151 if (!visited.insert(value).second)
152 continue;
153
154 if (CallInst *call = dyn_cast<CallInst>(value)) {
155 if (call->getCalledFunction()->getName().startswith(
156 clspv::ResourceAccessorFunction())) {
157 info.base = call;
158 } else {
159 // Unknown function call returning a constant pointer requires
160 // inlining.
161 return false;
162 }
163 } else if (auto gep = dyn_cast<GetElementPtrInst>(value)) {
164 info.indices.push_back(gep);
165 stack.push_back(gep->getOperand(0));
166 } else {
167 // Unhandled instruction requires inlining.
168 return false;
169 }
170 }
171
172 resources->push_back(std::move(info));
173 }
174
175 return true;
176}
177
178void MultiVersionUBOFunctionsPass::InlineCallSite(CallInst *call) {
179 InlineFunctionInfo IFI;
180 CallSite CS(call);
181 InlineFunction(CS, IFI, nullptr, false);
182}
183
184void MultiVersionUBOFunctionsPass::SpecializeCall(
185 Function *fn, CallInst *call, const std::vector<ResourceInfo> &resources,
186 size_t id) {
187
188 // The basis of the specialization is a clone of |fn|, however, the clone may
189 // need rebuilt in order to receive extra arguments.
190 ValueToValueMapTy remapped;
191 auto *clone = CloneFunction(fn, remapped);
192 std::string name;
193 raw_string_ostream str(name);
194 str << fn->getName() << "_clspv_" << id;
195 clone->setName(str.str());
196
197 std::vector<Value *> extra_args;
198 for (auto info : resources) {
199 // Must traverse the GEPs in reverse order to match how the code will be
200 // generated below so that the iterator for the extra arguments is
201 // consistent.
202 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
203 ++iter) {
204 // Skip pointer operand.
205 auto *idx = *iter;
206 for (size_t i = 1; i < idx->getNumOperands(); ++i) {
207 Value *operand = idx->getOperand(i);
208 if (!isa<Constant>(operand)) {
209 extra_args.push_back(operand);
210 }
211 }
212 }
213 }
214
215 if (!extra_args.empty()) {
216 // Need to add extra arguments to this function.
217 clone = AddExtraArguments(clone, extra_args);
218 }
219
220 auto where = clone->begin()->begin();
221 while (isa<AllocaInst>(where)) {
222 ++where;
223 }
224
225 IRBuilder<> builder(&*where);
226 auto new_arg_iter = clone->arg_begin();
227 for (auto &arg : fn->args()) {
228 ++new_arg_iter;
229 }
230 for (auto info : resources) {
231 // Create the resource var function.
232 SmallVector<Value *, 8> operands;
233 for (size_t i = 0; i < info.base->getNumOperands() - 1; ++i)
234 operands.push_back(info.base->getOperand(i));
235 CallInst *resource_fn =
236 builder.CreateCall(info.base->getCalledFunction(), operands);
237
238 // Create the chain of GEPs. Traversed in reverse order because we added
239 // them from use to def.
240 Value *ptr = resource_fn;
241 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
242 ++iter) {
243 SmallVector<Value *, 8> indices;
244 for (size_t i = 1; i != (*iter)->getNumOperands(); ++i) {
245 Value *operand = (*iter)->getOperand(i);
246 if (isa<Constant>(operand)) {
247 indices.push_back(operand);
248 } else {
249 // Each extra argument is unique so the iterator is "consumed".
250 indices.push_back(&*new_arg_iter);
251 ++new_arg_iter;
252 }
253 }
254 ptr = builder.CreateGEP(ptr, indices);
255 }
256
257 // Now replace the use of the argument with the result GEP.
258 Value *remapped_arg = remapped.lookup(info.arg);
259 remapped_arg->replaceAllUsesWith(ptr);
260 }
261
262 // Replace the call with a call to the newly specialized function.
263 SmallVector<Value *, 16> new_args;
264 for (size_t i = 0; i < call->getNumOperands() - 1; ++i) {
265 new_args.push_back(call->getOperand(i));
266 }
267 for (auto extra : extra_args) {
268 new_args.push_back(extra);
269 }
270 auto *replacement = CallInst::Create(clone, new_args, "", call);
271 call->replaceAllUsesWith(replacement);
272 call->eraseFromParent();
273}
274
275Function *MultiVersionUBOFunctionsPass::AddExtraArguments(
276 Function *fn, const std::vector<Value *> &extra_args) {
277 // Generate the new function type.
278 SmallVector<Type *, 8> arg_types;
279 for (auto &arg : fn->args()) {
280 arg_types.push_back(arg.getType());
281 }
282 for (auto v : extra_args) {
283 arg_types.push_back(v->getType());
284 }
285 FunctionType *new_type =
286 FunctionType::get(fn->getReturnType(), arg_types, fn->isVarArg());
287
288 // Insert the new function and copy calling conv, attributes and metadata.
289 auto *module = fn->getParent();
290 fn->removeFromParent();
291 auto pair =
292 module->getOrInsertFunction(fn->getName(), new_type, fn->getAttributes());
293 Function *new_function = cast<Function>(pair.getCallee());
294 new_function->setCallingConv(fn->getCallingConv());
295 new_function->copyMetadata(fn, 0);
296
297 // Move the basic blocks into the new function
298 if (!fn->isDeclaration()) {
299 std::vector<BasicBlock *> blocks;
300 for (auto &BB : *fn) {
301 blocks.push_back(&BB);
302 }
303 for (auto *BB : blocks) {
304 BB->removeFromParent();
305 BB->insertInto(new_function);
306 }
307 }
308
309 // Replace arg uses.
310 for (auto old_arg_iter = fn->arg_begin(),
311 new_arg_iter = new_function->arg_begin();
312 old_arg_iter != fn->arg_end(); ++old_arg_iter, ++new_arg_iter) {
313 old_arg_iter->replaceAllUsesWith(&*new_arg_iter);
314 }
315
316 // There are no calls to |fn| yet so we don't need to worry about updating
317 // calls.
318
319 delete fn;
320 return new_function;
321}