// Copyright 2019 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
#include <climits>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/UniqueVector.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "clspv/Passes.h"

#include "ArgKind.h"
#include "CallGraphOrderedFunctions.h"
#include "Constants.h"
38
39using namespace llvm;
40
41namespace {
42
43class MultiVersionUBOFunctionsPass final : public ModulePass {
44public:
45 static char ID;
46 MultiVersionUBOFunctionsPass() : ModulePass(ID) {}
47 bool runOnModule(Module &M) override;
48
49private:
50 // Struct for tracking specialization information.
51 struct ResourceInfo {
52 // The specific argument.
53 Argument *arg;
54 // The resource var base call.
55 CallInst *base;
56 // Series of GEPs that operate on |base|.
57 std::vector<GetElementPtrInst *> indices;
58 };
59
60 // Analyzes the call, |user|, to |fn| in terms of its UBO arguments. Returns
61 // true if |user| can be transformed into a specialized function.
62 //
63 // Currently, this function is only successful in analyzing GEP chains to a
64 // resource variable.
65 bool AnalyzeCall(Function *fn, CallInst *user,
66 std::vector<ResourceInfo> *resources);
67
68 // Inlines |call|.
69 void InlineCallSite(CallInst *call);
70
71 // Transforms the call to |fn| into a specialized call based on |resources|.
72 // Replaces |call| with a call to the specialized version.
73 void SpecializeCall(Function *fn, CallInst *call,
74 const std::vector<ResourceInfo> &resources, size_t id);
75
76 // Adds extra arguments to |fn| by rebuilding the entire function.
77 Function *AddExtraArguments(Function *fn,
78 const std::vector<Value *> &extra_args);
79};
80
81} // namespace
82
83char MultiVersionUBOFunctionsPass::ID = 0;
84static RegisterPass<MultiVersionUBOFunctionsPass>
85 X("MultiVersionUBOFunctionsPass",
86 "Multi-version functions with UBO params");
87
88namespace clspv {
89ModulePass *createMultiVersionUBOFunctionsPass() {
90 return new MultiVersionUBOFunctionsPass();
91}
92} // namespace clspv
93
94bool MultiVersionUBOFunctionsPass::runOnModule(Module &M) {
95 bool changed = false;
96 UniqueVector<Function *> ordered_functions =
97 clspv::CallGraphOrderedFunctions(M);
98
99 for (auto fn : ordered_functions) {
100 // Kernels don't need modified.
101 if (fn->isDeclaration() || fn->getCallingConv() == CallingConv::SPIR_KERNEL)
102 continue;
103
104 bool local_changed = false;
105 size_t count = 0;
alan-baker973ba8d2019-07-16 19:16:19 -0400106 SmallVector<User *, 8> users(fn->users());
107 for (auto user : users) {
alan-bakera71f1932019-04-11 11:04:34 -0400108 if (auto call = dyn_cast<CallInst>(user)) {
109 std::vector<ResourceInfo> resources;
110 if (AnalyzeCall(fn, call, &resources)) {
111 if (!resources.empty()) {
112 local_changed = true;
113 SpecializeCall(fn, call, resources, count++);
114 }
115 } else {
116 local_changed = true;
117 InlineCallSite(call);
118 }
119 }
120 }
121
122 fn->removeDeadConstantUsers();
123 if (local_changed) {
124 // All calls to this function were either specialized or inlined.
125 fn->eraseFromParent();
126 }
127 changed |= local_changed;
128 }
129
130 return changed;
131}
132
133bool MultiVersionUBOFunctionsPass::AnalyzeCall(
134 Function *fn, CallInst *user, std::vector<ResourceInfo> *resources) {
135 for (auto &arg : fn->args()) {
alan-bakerc4579bb2020-04-29 14:15:50 -0400136 if (clspv::GetArgKind(arg) != clspv::ArgKind::BufferUBO)
alan-bakera71f1932019-04-11 11:04:34 -0400137 continue;
138
139 Value *arg_operand = user->getOperand(arg.getArgNo());
140 ResourceInfo info;
141 info.arg = &arg;
142
143 DenseSet<Value *> visited;
144 std::vector<Value *> stack;
145 stack.push_back(arg_operand);
146
147 while (!stack.empty()) {
148 Value *value = stack.back();
149 stack.pop_back();
150
151 if (!visited.insert(value).second)
152 continue;
153
154 if (CallInst *call = dyn_cast<CallInst>(value)) {
155 if (call->getCalledFunction()->getName().startswith(
156 clspv::ResourceAccessorFunction())) {
157 info.base = call;
158 } else {
159 // Unknown function call returning a constant pointer requires
160 // inlining.
161 return false;
162 }
163 } else if (auto gep = dyn_cast<GetElementPtrInst>(value)) {
164 info.indices.push_back(gep);
165 stack.push_back(gep->getOperand(0));
166 } else {
167 // Unhandled instruction requires inlining.
168 return false;
169 }
170 }
171
172 resources->push_back(std::move(info));
173 }
174
175 return true;
176}
177
178void MultiVersionUBOFunctionsPass::InlineCallSite(CallInst *call) {
179 InlineFunctionInfo IFI;
alan-baker741fd1f2020-04-14 17:38:15 -0400180 InlineFunction(*call, IFI, nullptr, false);
alan-bakera71f1932019-04-11 11:04:34 -0400181}
182
183void MultiVersionUBOFunctionsPass::SpecializeCall(
184 Function *fn, CallInst *call, const std::vector<ResourceInfo> &resources,
185 size_t id) {
186
187 // The basis of the specialization is a clone of |fn|, however, the clone may
188 // need rebuilt in order to receive extra arguments.
189 ValueToValueMapTy remapped;
190 auto *clone = CloneFunction(fn, remapped);
191 std::string name;
192 raw_string_ostream str(name);
193 str << fn->getName() << "_clspv_" << id;
194 clone->setName(str.str());
195
196 std::vector<Value *> extra_args;
197 for (auto info : resources) {
198 // Must traverse the GEPs in reverse order to match how the code will be
199 // generated below so that the iterator for the extra arguments is
200 // consistent.
201 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
202 ++iter) {
203 // Skip pointer operand.
204 auto *idx = *iter;
205 for (size_t i = 1; i < idx->getNumOperands(); ++i) {
206 Value *operand = idx->getOperand(i);
207 if (!isa<Constant>(operand)) {
208 extra_args.push_back(operand);
209 }
210 }
211 }
212 }
213
214 if (!extra_args.empty()) {
215 // Need to add extra arguments to this function.
216 clone = AddExtraArguments(clone, extra_args);
217 }
218
219 auto where = clone->begin()->begin();
220 while (isa<AllocaInst>(where)) {
221 ++where;
222 }
223
224 IRBuilder<> builder(&*where);
225 auto new_arg_iter = clone->arg_begin();
alan-baker4a757f62020-04-22 08:17:49 -0400226 for (size_t i = 0; i < fn->arg_size(); ++i) {
alan-bakera71f1932019-04-11 11:04:34 -0400227 ++new_arg_iter;
228 }
229 for (auto info : resources) {
230 // Create the resource var function.
231 SmallVector<Value *, 8> operands;
232 for (size_t i = 0; i < info.base->getNumOperands() - 1; ++i)
233 operands.push_back(info.base->getOperand(i));
234 CallInst *resource_fn =
235 builder.CreateCall(info.base->getCalledFunction(), operands);
236
237 // Create the chain of GEPs. Traversed in reverse order because we added
238 // them from use to def.
239 Value *ptr = resource_fn;
240 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
241 ++iter) {
242 SmallVector<Value *, 8> indices;
243 for (size_t i = 1; i != (*iter)->getNumOperands(); ++i) {
244 Value *operand = (*iter)->getOperand(i);
245 if (isa<Constant>(operand)) {
246 indices.push_back(operand);
247 } else {
248 // Each extra argument is unique so the iterator is "consumed".
249 indices.push_back(&*new_arg_iter);
250 ++new_arg_iter;
251 }
252 }
253 ptr = builder.CreateGEP(ptr, indices);
254 }
255
256 // Now replace the use of the argument with the result GEP.
257 Value *remapped_arg = remapped.lookup(info.arg);
258 remapped_arg->replaceAllUsesWith(ptr);
259 }
260
261 // Replace the call with a call to the newly specialized function.
262 SmallVector<Value *, 16> new_args;
263 for (size_t i = 0; i < call->getNumOperands() - 1; ++i) {
264 new_args.push_back(call->getOperand(i));
265 }
266 for (auto extra : extra_args) {
267 new_args.push_back(extra);
268 }
269 auto *replacement = CallInst::Create(clone, new_args, "", call);
270 call->replaceAllUsesWith(replacement);
271 call->eraseFromParent();
272}
273
274Function *MultiVersionUBOFunctionsPass::AddExtraArguments(
275 Function *fn, const std::vector<Value *> &extra_args) {
276 // Generate the new function type.
277 SmallVector<Type *, 8> arg_types;
278 for (auto &arg : fn->args()) {
279 arg_types.push_back(arg.getType());
280 }
281 for (auto v : extra_args) {
282 arg_types.push_back(v->getType());
283 }
284 FunctionType *new_type =
285 FunctionType::get(fn->getReturnType(), arg_types, fn->isVarArg());
286
287 // Insert the new function and copy calling conv, attributes and metadata.
288 auto *module = fn->getParent();
289 fn->removeFromParent();
290 auto pair =
291 module->getOrInsertFunction(fn->getName(), new_type, fn->getAttributes());
292 Function *new_function = cast<Function>(pair.getCallee());
293 new_function->setCallingConv(fn->getCallingConv());
294 new_function->copyMetadata(fn, 0);
295
296 // Move the basic blocks into the new function
297 if (!fn->isDeclaration()) {
298 std::vector<BasicBlock *> blocks;
299 for (auto &BB : *fn) {
300 blocks.push_back(&BB);
301 }
302 for (auto *BB : blocks) {
303 BB->removeFromParent();
304 BB->insertInto(new_function);
305 }
306 }
307
308 // Replace arg uses.
309 for (auto old_arg_iter = fn->arg_begin(),
310 new_arg_iter = new_function->arg_begin();
311 old_arg_iter != fn->arg_end(); ++old_arg_iter, ++new_arg_iter) {
312 old_arg_iter->replaceAllUsesWith(&*new_arg_iter);
313 }
314
315 // There are no calls to |fn| yet so we don't need to worry about updating
316 // calls.
317
318 delete fn;
319 return new_function;
320}