// Copyright 2018 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <climits>
16#include <map>
17#include <set>
18#include <utility>
19#include <vector>
20
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/UniqueVector.h"
24#include "llvm/IR/Constants.h"
25#include "llvm/IR/DerivedTypes.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/IR/Module.h"
30#include "llvm/Pass.h"
31#include "llvm/Support/raw_ostream.h"
32
33#include "clspv/Option.h"
34#include "clspv/Passes.h"
35
36#include "ArgKind.h"
37
38using namespace llvm;
39
40#define DEBUG_TYPE "directresourceaccess"
41
42namespace {
43
44cl::opt<bool> ShowDRA("show-dra", cl::init(false), cl::Hidden,
45 cl::desc("Show direct resource access details"));
46
47using SamplerMapType = llvm::ArrayRef<std::pair<unsigned, std::string>>;
48
49class DirectResourceAccessPass final : public ModulePass {
50public:
51 static char ID;
52 DirectResourceAccessPass() : ModulePass(ID) {}
53 bool runOnModule(Module &M) override;
54
55private:
56 // Return the functions reachable from entry point functions, where
57 // callers appear before callees. OpenCL C does not permit recursion
58 // or function or pointers, so this is always well defined. The ordering
59 // should be reproducible from one run to the next.
60 UniqueVector<Function *> CallGraphOrderedFunctions(Module &);
61
62 // For each kernel argument that will map to a resource variable (descriptor),
63 // try to rewrite the uses of the argument as a direct access of the resource.
64 // We can only do this if all the callees of the function use the same
65 // resource access value for that argument. Returns true if the module
66 // changed.
67 bool RewriteResourceAccesses(Function *fn);
68
69 // Rewrite uses of this resrouce-based arg if all the callers pass in the
70 // same resource access. Returns true if the module changed.
71 bool RewriteAccessesForArg(Function *fn, int arg_index, Argument &arg);
72};
73} // namespace
74
75char DirectResourceAccessPass::ID = 0;
76static RegisterPass<DirectResourceAccessPass> X("DirectResourceAccessPass",
77 "Direct resource access");
78
79namespace clspv {
80ModulePass *createDirectResourceAccessPass() {
81 return new DirectResourceAccessPass();
82}
83} // namespace clspv
84
85namespace {
86bool DirectResourceAccessPass::runOnModule(Module &M) {
87 bool Changed = false;
88
89 if (clspv::Option::DirectResourceAccess()) {
90 auto ordered_functions = CallGraphOrderedFunctions(M);
91 for (auto *fn : ordered_functions) {
92 Changed |= RewriteResourceAccesses(fn);
93 }
94 }
95
96 return Changed;
97}
98
99UniqueVector<Function *>
100DirectResourceAccessPass::CallGraphOrderedFunctions(Module &M) {
101 // Use a topological sort.
102
103 // Make an ordered list of all functions having bodies, with kernel entry
104 // points listed first.
105 UniqueVector<Function *> functions;
106 SmallVector<Function *, 10> entry_points;
107 for (Function &F : M) {
108 if (F.isDeclaration()) {
109 continue;
110 }
111 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
112 functions.insert(&F);
113 entry_points.push_back(&F);
114 }
115 }
116 // Add the remaining functions.
117 for (Function &F : M) {
118 if (F.isDeclaration()) {
119 continue;
120 }
121 if (F.getCallingConv() != CallingConv::SPIR_KERNEL) {
122 functions.insert(&F);
123 }
124 }
125
126 // This will be a complete set of reveresed edges, i.e. with all pairs
127 // of (callee, caller).
128 using Edge = std::pair<unsigned, unsigned>;
129 auto make_edge = [&functions](Function *callee, Function *caller) {
130 return std::pair<unsigned, unsigned>{functions.idFor(callee),
131 functions.idFor(caller)};
132 };
133 std::set<Edge> reverse_edges;
134 // Map each function to the functions it calls, and populate |reverse_edges|.
135 std::map<Function *, SmallVector<Function *, 3>> calls_functions;
136 for (Function *callee : functions) {
137 for (auto &use : callee->uses()) {
138 if (auto *call = dyn_cast<CallInst>(use.getUser())) {
139 Function *caller = call->getParent()->getParent();
140 calls_functions[caller].push_back(callee);
141 reverse_edges.insert(make_edge(callee, caller));
142 }
143 }
144 }
145 // Sort the callees in module-order. This helps us produce a deterministic
146 // result.
147 for (auto &pair : calls_functions) {
148 auto &callees = pair.second;
149 std::sort(callees.begin(), callees.end(),
150 [&functions](Function *lhs, Function *rhs) {
151 return functions.idFor(lhs) < functions.idFor(rhs);
152 });
153 }
154
155 // Use Kahn's algorithm for topoological sort.
156 UniqueVector<Function *> result;
157 SmallVector<Function *, 10> work_list(entry_points.begin(),
158 entry_points.end());
159 while (!work_list.empty()) {
160 Function *caller = work_list.back();
161 work_list.pop_back();
162 result.insert(caller);
163 auto &callees = calls_functions[caller];
164 for (auto *callee : callees) {
165 reverse_edges.erase(make_edge(callee, caller));
166 auto lower_bound = reverse_edges.lower_bound(make_edge(callee, nullptr));
167 if (lower_bound == reverse_edges.end() ||
168 lower_bound->first != functions.idFor(callee)) {
169 // Callee has no other unvisited callers.
170 work_list.push_back(callee);
171 }
172 }
173 }
174 // If reverse_edges is not empty then there was a cycle. But we don't care
175 // about that erroneous case.
176
177 if (ShowDRA) {
178 outs() << "DRA: Ordered functions:\n";
179 for (Function *fun : result) {
180 outs() << "DRA: " << fun->getName() << "\n";
181 }
182 }
183 return result;
184}
185
186bool DirectResourceAccessPass::RewriteResourceAccesses(Function *fn) {
187 bool Changed = false;
188 int arg_index = 0;
189 for (Argument &arg : fn->args()) {
190 switch (clspv::GetArgKindForType(arg.getType())) {
191 case clspv::ArgKind::Buffer:
192 case clspv::ArgKind::ReadOnlyImage:
193 case clspv::ArgKind::WriteOnlyImage:
194 case clspv::ArgKind::Sampler:
195 Changed |= RewriteAccessesForArg(fn, arg_index, arg);
196 break;
David Neto3a0df832018-08-03 14:35:42 -0400197 default:
198 // Should not happen
199 break;
David Netoc5fb5242018-07-30 13:28:31 -0400200 }
201 arg_index++;
202 }
203 return Changed;
204}
205
206bool DirectResourceAccessPass::RewriteAccessesForArg(Function *fn,
207 int arg_index,
208 Argument &arg) {
209 bool Changed = false;
210 if (fn->use_empty()) {
211 return false;
212 }
213
214 // We can convert a parameter to a direct resource access if it is
215 // either a direct call to a clspv.resource.var.* or if it a GEP of
216 // such a thing (where the GEP can only have zero indices).
217 struct ParamInfo {
218 // The resource-access builtin function. (@clspv.resource.var.*)
219 Function *var_fn;
220 // The descriptor set.
221 uint32_t set;
222 // The binding.
223 uint32_t binding;
224 // If the parameter is a GEP, then this is the number of zero-indices
225 // the GEP used.
226 unsigned num_gep_zeroes;
227 // An example call fitting
228 CallInst *sample_call;
229 };
230 // The common valid parameter info across all the callers seen soo far.
231
232 bool seen_one = false;
233 ParamInfo common;
234 // Tries to merge the given parameter info into |common|. If it is the first
235 // time we've tried, then save it. Returns true if there is no conflict.
236 auto merge_param_info = [&seen_one, &common](const ParamInfo &pi) {
237 if (!seen_one) {
238 common = pi;
239 seen_one = true;
240 return true;
241 }
242 return pi.var_fn == common.var_fn && pi.set == common.set &&
243 pi.binding == common.binding &&
244 pi.num_gep_zeroes == common.num_gep_zeroes;
245 };
246
247 for (auto &use : fn->uses()) {
248 if (auto *caller = dyn_cast<CallInst>(use.getUser())) {
249 Value *value = caller->getArgOperand(arg_index);
250 // We care about two cases:
251 // - a direct call to clspv.resource.var.*
252 // - a GEP with only zero indices, where the base pointer is
253
254 // Unpack GEPs with zeros, if we can. Rewrite |value| as we go along.
255 unsigned num_gep_zeroes = 0;
David Neto2f450002018-08-01 16:13:03 -0400256 bool first_gep = true;
David Netoc5fb5242018-07-30 13:28:31 -0400257 for (auto *gep = dyn_cast<GetElementPtrInst>(value); gep;
258 gep = dyn_cast<GetElementPtrInst>(value)) {
259 if (!gep->hasAllZeroIndices()) {
260 return false;
261 }
David Neto2f450002018-08-01 16:13:03 -0400262 // If not the first GEP, then ignore the "element" index (which I call
263 // "slide") since that will be combined with the last index of the
264 // previous GEP.
265 num_gep_zeroes += gep->getNumIndices() + (first_gep ? 0 : -1);
David Netoc5fb5242018-07-30 13:28:31 -0400266 value = gep->getPointerOperand();
David Neto2f450002018-08-01 16:13:03 -0400267 first_gep = false;
David Netoc5fb5242018-07-30 13:28:31 -0400268 }
269 if (auto *call = dyn_cast<CallInst>(value)) {
270 // If the call is a call to a @clspv.resource.var.* function, then try
271 // to merge it, assuming the given number of GEP zero-indices so far.
272 if (call->getCalledFunction()->getName().startswith(
273 "clspv.resource.var.")) {
274 const auto set = uint32_t(
275 dyn_cast<ConstantInt>(call->getOperand(0))->getZExtValue());
276 const auto binding = uint32_t(
277 dyn_cast<ConstantInt>(call->getOperand(1))->getZExtValue());
278 if (!merge_param_info({call->getCalledFunction(), set, binding,
279 num_gep_zeroes, call})) {
280 return false;
281 }
282 } else {
283 // A call but not to a resource access builtin function.
284 return false;
285 }
286 } else {
287 // Not a call.
288 return false;
289 }
290 } else {
291 // There isn't enough commonality. Bail out without changing anything.
292 return false;
293 }
294 }
295 if (ShowDRA) {
296 if (seen_one) {
297 outs() << "DRA: Rewrite " << fn->getName() << " arg " << arg_index << " "
298 << arg.getName() << ": " << common.var_fn->getName() << " ("
299 << common.set << "," << common.binding
David Neto2f450002018-08-01 16:13:03 -0400300 << ") zeroes: " << common.num_gep_zeroes << " sample-call "
301 << *(common.sample_call) << "\n";
David Netoc5fb5242018-07-30 13:28:31 -0400302 }
303 }
304
305 // Now rewrite the argument, using the info in |common|.
306
307 Changed = true;
308 IRBuilder<> Builder(fn->getParent()->getContext());
309 auto *zero = Builder.getInt32(0);
310 Builder.SetInsertPoint(fn->getEntryBlock().getFirstNonPHI());
311
312 // Create the call.
313 SmallVector<Value *, 8> args(common.sample_call->arg_begin(),
314 common.sample_call->arg_end());
315 Value *replacement = Builder.CreateCall(common.var_fn, args);
316 if (ShowDRA) {
317 outs() << "DRA: Replace: call " << *replacement << "\n";
318 }
319 if (common.num_gep_zeroes) {
320 SmallVector<Value *, 3> zeroes;
321 for (unsigned i = 0; i < common.num_gep_zeroes; i++) {
322 zeroes.push_back(zero);
323 }
324 replacement = Builder.CreateGEP(replacement, zeroes);
325 if (ShowDRA) {
326 outs() << "DRA: Replace: gep " << *replacement << "\n";
327 }
328 }
329 arg.replaceAllUsesWith(replacement);
330
331 return Changed;
332}
333
334} // namespace