blob: 6f1265ceaa5c7338bd34a2b728b54d8b114a6e77 [file] [log] [blame]
// Copyright 2019 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <climits>
16#include <map>
17#include <set>
18#include <utility>
19#include <vector>
20
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/UniqueVector.h"
25#include "llvm/IR/CallingConv.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/raw_ostream.h"
31#include "llvm/Transforms/Utils/Cloning.h"
32
33#include "clspv/Passes.h"
34
35#include "ArgKind.h"
SJW61531372020-06-09 07:31:08 -050036#include "Builtins.h"
alan-bakera71f1932019-04-11 11:04:34 -040037#include "CallGraphOrderedFunctions.h"
38#include "Constants.h"
39
40using namespace llvm;
41
42namespace {
43
44class MultiVersionUBOFunctionsPass final : public ModulePass {
45public:
46 static char ID;
47 MultiVersionUBOFunctionsPass() : ModulePass(ID) {}
48 bool runOnModule(Module &M) override;
49
50private:
51 // Struct for tracking specialization information.
52 struct ResourceInfo {
53 // The specific argument.
54 Argument *arg;
55 // The resource var base call.
56 CallInst *base;
57 // Series of GEPs that operate on |base|.
58 std::vector<GetElementPtrInst *> indices;
59 };
60
61 // Analyzes the call, |user|, to |fn| in terms of its UBO arguments. Returns
62 // true if |user| can be transformed into a specialized function.
63 //
64 // Currently, this function is only successful in analyzing GEP chains to a
65 // resource variable.
66 bool AnalyzeCall(Function *fn, CallInst *user,
67 std::vector<ResourceInfo> *resources);
68
69 // Inlines |call|.
70 void InlineCallSite(CallInst *call);
71
72 // Transforms the call to |fn| into a specialized call based on |resources|.
73 // Replaces |call| with a call to the specialized version.
74 void SpecializeCall(Function *fn, CallInst *call,
75 const std::vector<ResourceInfo> &resources, size_t id);
76
77 // Adds extra arguments to |fn| by rebuilding the entire function.
78 Function *AddExtraArguments(Function *fn,
79 const std::vector<Value *> &extra_args);
80};
81
82} // namespace
83
84char MultiVersionUBOFunctionsPass::ID = 0;
85static RegisterPass<MultiVersionUBOFunctionsPass>
86 X("MultiVersionUBOFunctionsPass",
87 "Multi-version functions with UBO params");
88
89namespace clspv {
90ModulePass *createMultiVersionUBOFunctionsPass() {
91 return new MultiVersionUBOFunctionsPass();
92}
93} // namespace clspv
94
95bool MultiVersionUBOFunctionsPass::runOnModule(Module &M) {
96 bool changed = false;
97 UniqueVector<Function *> ordered_functions =
98 clspv::CallGraphOrderedFunctions(M);
99
100 for (auto fn : ordered_functions) {
101 // Kernels don't need modified.
102 if (fn->isDeclaration() || fn->getCallingConv() == CallingConv::SPIR_KERNEL)
103 continue;
104
105 bool local_changed = false;
106 size_t count = 0;
alan-baker973ba8d2019-07-16 19:16:19 -0400107 SmallVector<User *, 8> users(fn->users());
108 for (auto user : users) {
alan-bakera71f1932019-04-11 11:04:34 -0400109 if (auto call = dyn_cast<CallInst>(user)) {
110 std::vector<ResourceInfo> resources;
111 if (AnalyzeCall(fn, call, &resources)) {
112 if (!resources.empty()) {
113 local_changed = true;
114 SpecializeCall(fn, call, resources, count++);
115 }
116 } else {
117 local_changed = true;
118 InlineCallSite(call);
119 }
120 }
121 }
122
123 fn->removeDeadConstantUsers();
124 if (local_changed) {
125 // All calls to this function were either specialized or inlined.
126 fn->eraseFromParent();
127 }
128 changed |= local_changed;
129 }
130
131 return changed;
132}
133
134bool MultiVersionUBOFunctionsPass::AnalyzeCall(
135 Function *fn, CallInst *user, std::vector<ResourceInfo> *resources) {
136 for (auto &arg : fn->args()) {
alan-bakerc4579bb2020-04-29 14:15:50 -0400137 if (clspv::GetArgKind(arg) != clspv::ArgKind::BufferUBO)
alan-bakera71f1932019-04-11 11:04:34 -0400138 continue;
139
140 Value *arg_operand = user->getOperand(arg.getArgNo());
141 ResourceInfo info;
142 info.arg = &arg;
143
144 DenseSet<Value *> visited;
145 std::vector<Value *> stack;
146 stack.push_back(arg_operand);
147
148 while (!stack.empty()) {
149 Value *value = stack.back();
150 stack.pop_back();
151
152 if (!visited.insert(value).second)
153 continue;
154
155 if (CallInst *call = dyn_cast<CallInst>(value)) {
SJW61531372020-06-09 07:31:08 -0500156 auto &func_info = clspv::Builtins::Lookup(call->getCalledFunction());
157 if (func_info.getType() == clspv::Builtins::kClspvResource) {
alan-bakera71f1932019-04-11 11:04:34 -0400158 info.base = call;
159 } else {
160 // Unknown function call returning a constant pointer requires
161 // inlining.
162 return false;
163 }
164 } else if (auto gep = dyn_cast<GetElementPtrInst>(value)) {
165 info.indices.push_back(gep);
166 stack.push_back(gep->getOperand(0));
167 } else {
168 // Unhandled instruction requires inlining.
169 return false;
170 }
171 }
172
173 resources->push_back(std::move(info));
174 }
175
176 return true;
177}
178
179void MultiVersionUBOFunctionsPass::InlineCallSite(CallInst *call) {
180 InlineFunctionInfo IFI;
alan-baker741fd1f2020-04-14 17:38:15 -0400181 InlineFunction(*call, IFI, nullptr, false);
alan-bakera71f1932019-04-11 11:04:34 -0400182}
183
184void MultiVersionUBOFunctionsPass::SpecializeCall(
185 Function *fn, CallInst *call, const std::vector<ResourceInfo> &resources,
186 size_t id) {
187
188 // The basis of the specialization is a clone of |fn|, however, the clone may
189 // need rebuilt in order to receive extra arguments.
190 ValueToValueMapTy remapped;
191 auto *clone = CloneFunction(fn, remapped);
192 std::string name;
193 raw_string_ostream str(name);
194 str << fn->getName() << "_clspv_" << id;
195 clone->setName(str.str());
196
197 std::vector<Value *> extra_args;
198 for (auto info : resources) {
199 // Must traverse the GEPs in reverse order to match how the code will be
200 // generated below so that the iterator for the extra arguments is
201 // consistent.
202 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
203 ++iter) {
204 // Skip pointer operand.
205 auto *idx = *iter;
206 for (size_t i = 1; i < idx->getNumOperands(); ++i) {
207 Value *operand = idx->getOperand(i);
208 if (!isa<Constant>(operand)) {
209 extra_args.push_back(operand);
210 }
211 }
212 }
213 }
214
215 if (!extra_args.empty()) {
216 // Need to add extra arguments to this function.
217 clone = AddExtraArguments(clone, extra_args);
218 }
219
220 auto where = clone->begin()->begin();
221 while (isa<AllocaInst>(where)) {
222 ++where;
223 }
224
225 IRBuilder<> builder(&*where);
226 auto new_arg_iter = clone->arg_begin();
alan-baker4a757f62020-04-22 08:17:49 -0400227 for (size_t i = 0; i < fn->arg_size(); ++i) {
alan-bakera71f1932019-04-11 11:04:34 -0400228 ++new_arg_iter;
229 }
230 for (auto info : resources) {
231 // Create the resource var function.
232 SmallVector<Value *, 8> operands;
233 for (size_t i = 0; i < info.base->getNumOperands() - 1; ++i)
234 operands.push_back(info.base->getOperand(i));
235 CallInst *resource_fn =
236 builder.CreateCall(info.base->getCalledFunction(), operands);
237
238 // Create the chain of GEPs. Traversed in reverse order because we added
239 // them from use to def.
240 Value *ptr = resource_fn;
241 for (auto iter = info.indices.rbegin(); iter != info.indices.rend();
242 ++iter) {
243 SmallVector<Value *, 8> indices;
244 for (size_t i = 1; i != (*iter)->getNumOperands(); ++i) {
245 Value *operand = (*iter)->getOperand(i);
246 if (isa<Constant>(operand)) {
247 indices.push_back(operand);
248 } else {
249 // Each extra argument is unique so the iterator is "consumed".
250 indices.push_back(&*new_arg_iter);
251 ++new_arg_iter;
252 }
253 }
254 ptr = builder.CreateGEP(ptr, indices);
255 }
256
257 // Now replace the use of the argument with the result GEP.
258 Value *remapped_arg = remapped.lookup(info.arg);
259 remapped_arg->replaceAllUsesWith(ptr);
260 }
261
262 // Replace the call with a call to the newly specialized function.
263 SmallVector<Value *, 16> new_args;
264 for (size_t i = 0; i < call->getNumOperands() - 1; ++i) {
265 new_args.push_back(call->getOperand(i));
266 }
267 for (auto extra : extra_args) {
268 new_args.push_back(extra);
269 }
270 auto *replacement = CallInst::Create(clone, new_args, "", call);
271 call->replaceAllUsesWith(replacement);
272 call->eraseFromParent();
273}
274
275Function *MultiVersionUBOFunctionsPass::AddExtraArguments(
276 Function *fn, const std::vector<Value *> &extra_args) {
277 // Generate the new function type.
278 SmallVector<Type *, 8> arg_types;
279 for (auto &arg : fn->args()) {
280 arg_types.push_back(arg.getType());
281 }
282 for (auto v : extra_args) {
283 arg_types.push_back(v->getType());
284 }
285 FunctionType *new_type =
286 FunctionType::get(fn->getReturnType(), arg_types, fn->isVarArg());
287
288 // Insert the new function and copy calling conv, attributes and metadata.
289 auto *module = fn->getParent();
290 fn->removeFromParent();
291 auto pair =
292 module->getOrInsertFunction(fn->getName(), new_type, fn->getAttributes());
293 Function *new_function = cast<Function>(pair.getCallee());
294 new_function->setCallingConv(fn->getCallingConv());
295 new_function->copyMetadata(fn, 0);
296
297 // Move the basic blocks into the new function
298 if (!fn->isDeclaration()) {
299 std::vector<BasicBlock *> blocks;
300 for (auto &BB : *fn) {
301 blocks.push_back(&BB);
302 }
303 for (auto *BB : blocks) {
304 BB->removeFromParent();
305 BB->insertInto(new_function);
306 }
307 }
308
309 // Replace arg uses.
310 for (auto old_arg_iter = fn->arg_begin(),
311 new_arg_iter = new_function->arg_begin();
312 old_arg_iter != fn->arg_end(); ++old_arg_iter, ++new_arg_iter) {
313 old_arg_iter->replaceAllUsesWith(&*new_arg_iter);
314 }
315
316 // There are no calls to |fn| yet so we don't need to worry about updating
317 // calls.
318
319 delete fn;
320 return new_function;
321}