blob: 97c000a6c9d961c38187fa02bedd9fbbd7063ae7 [file] [log] [blame]
David Netoab03f432017-11-03 17:00:44 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/UniqueVector.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/raw_ostream.h>

#include "clspv/Option.h"
32
David Netoab03f432017-11-03 17:00:44 -040033using namespace llvm;
34using std::string;
35
David Neto8e39bd12017-11-14 21:08:12 -050036#define DEBUG_TYPE "rewriteinserts"
David Netoab03f432017-11-03 17:00:44 -040037
David Netoab03f432017-11-03 17:00:44 -040038namespace {
39
40const char* kCompositeConstructFunctionPrefix = "clspv.composite_construct.";
41
42class RewriteInsertsPass : public ModulePass {
43 public:
44 static char ID;
45 RewriteInsertsPass() : ModulePass(ID) {}
46
47 bool runOnModule(Module &M) override;
48
49 private:
50 using InsertionVector = SmallVector<InsertValueInst *, 4>;
51
David Neto8e39bd12017-11-14 21:08:12 -050052 // Replaces chains of insertions that cover the entire value.
53 // Such a change always reduces the number of instructions, so
54 // we always perform these. Returns true if the module was modified.
55 bool ReplaceCompleteInsertionChains(Module &M);
56
57 // Replaces all InsertValue instructions, even if they aren't part
58 // of a complete insetion chain. Returns true if the module was modified.
59 bool ReplacePartialInsertions(Module &M);
60
61 // Load |values| and |chain| with the members of the struct value produced
62 // by a chain of InsertValue instructions ending with |iv|, and following
63 // the aggregate operand. Return the start of the chain: the aggregate
64 // value which is not an InsertValue instruction, or an InsertValue
65 // instruction which inserts a component that is replaced later in the
66 // chain. The |values| vector will match the order of struct members and
67 // is initialized to all nullptr members. The |chain| vector will list
68 // the chain of InsertValue instructions, listed in the order we discover
69 // them, e.g. begining with |iv|.
70 Value *LoadValuesEndingWithInsertion(InsertValueInst *iv,
71 std::vector<Value *> *values,
72 InsertionVector *chain) {
73 auto *structTy = dyn_cast<StructType>(iv->getType());
74 assert(structTy);
75 const auto numElems = structTy->getNumElements();
76
77 // Walk backward from the tail to an instruction we don't want to
78 // replace.
79 Value *frontier = iv;
80 while (auto *insertion = dyn_cast<InsertValueInst>(frontier)) {
81 chain->push_back(insertion);
82 // Only handle single-index insertions.
83 if (insertion->getNumIndices() == 1) {
84 // Try to replace this one.
85
86 unsigned index = insertion->getIndices()[0];
87 assert(index < numElems);
88 if ((*values)[index] != nullptr) {
89 // We already have a value for this slot. Stop now.
90 break;
91 }
92 (*values)[index] = insertion->getInsertedValueOperand();
93 frontier = insertion->getAggregateOperand();
94 } else {
95 break;
96 }
97 }
98 return frontier;
99 }
100
David Netoab03f432017-11-03 17:00:44 -0400101 // If this is the tail of a chain of InsertValueInst instructions
102 // that covers the entire composite, then return a small vector
103 // containing the insertion instructions, in member order.
104 // Otherwise returns nullptr. Only handle insertions into structs,
105 // not into arrays.
106 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
107 if (iv->getNumIndices() == 1) {
108 if (auto *structTy = dyn_cast<StructType>(iv->getType())) {
109 auto numElems = structTy->getNumElements();
110 // Only handle single-index insertions.
111 unsigned index = iv->getIndices()[0];
112 if (index + 1u != numElems) {
113 // Not the last in the chain.
114 return nullptr;
115 }
116 InsertionVector candidates(numElems, nullptr);
117 for (unsigned i = index;
118 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
119 // iv inserts the i'th member
120 candidates[i] = iv;
121
122 if (i == 0) {
123 // We're done!
124 return new InsertionVector(candidates);
125 }
126
127 if (InsertValueInst *agg =
128 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
129 iv = agg;
130 } else {
131 // The chain is broken.
132 break;
133 }
134 }
135 }
136 }
137 return nullptr;
138 }
139
David Neto8e39bd12017-11-14 21:08:12 -0500140
David Netoab03f432017-11-03 17:00:44 -0400141 // Return the name for the wrap function for the given type.
142 string &WrapFunctionNameForType(Type *type) {
143 auto where = function_for_type_.find(type);
144 if (where == function_for_type_.end()) {
145 // Insert it.
146 auto &result = function_for_type_[type] =
147 string(kCompositeConstructFunctionPrefix) +
148 std::to_string(function_for_type_.size());
149 return result;
150 } else {
151 return where->second;
152 }
153 }
154
David Neto8e39bd12017-11-14 21:08:12 -0500155 // Get or create the composite construct function definition.
156 Function *GetConstructFunction(Module &M, StructType *constructed_type) {
157 // Get or create the composite construct function definition.
158 const string &fn_name = WrapFunctionNameForType(constructed_type);
159 Function *fn = M.getFunction(fn_name);
160 if (!fn) {
161 // Make the function.
162 FunctionType *fnTy = FunctionType::get(
163 constructed_type, constructed_type->elements(), false);
164 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
165 fn = cast<Function>(fn_constant);
166 fn->addFnAttr(Attribute::ReadOnly);
167 fn->addFnAttr(Attribute::ReadNone);
168 }
169 return fn;
170 }
171
David Netoab03f432017-11-03 17:00:44 -0400172 // Maps a loaded type to the name of the wrap function for that type.
173 DenseMap<Type *, string> function_for_type_;
174};
175} // namespace
176
177char RewriteInsertsPass::ID = 0;
178static RegisterPass<RewriteInsertsPass>
179 X("RewriteInserts",
180 "Rewrite chains of insertvalue to as composite-construction");
181
182namespace clspv {
183llvm::ModulePass *createRewriteInsertsPass() {
184 return new RewriteInsertsPass();
185}
186} // namespace clspv
187
188bool RewriteInsertsPass::runOnModule(Module &M) {
David Neto8e39bd12017-11-14 21:08:12 -0500189 bool Changed = ReplaceCompleteInsertionChains(M);
190
David Neto482550a2018-03-24 05:21:07 -0700191 if (clspv::Option::HackInserts()) {
David Neto8e39bd12017-11-14 21:08:12 -0500192 Changed |= ReplacePartialInsertions(M);
193 }
194
195 return Changed;
196}
197
198bool RewriteInsertsPass::ReplaceCompleteInsertionChains(Module &M) {
David Netoab03f432017-11-03 17:00:44 -0400199 bool Changed = false;
200
201 SmallVector<InsertionVector *, 16> WorkList;
202 for (Function &F : M) {
203 for (BasicBlock &BB : F) {
204 for (Instruction &I : BB) {
205 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
206 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
207 WorkList.push_back(insertions);
208 }
209 }
210 }
211 }
212 }
213
214 if (WorkList.size() == 0) {
215 return Changed;
216 }
217
218 for (InsertionVector *insertions : WorkList) {
219 Changed = true;
220
221 // Gather the member values and types.
222 SmallVector<Value*, 4> values;
223 SmallVector<Type*, 4> types;
224 for (InsertValueInst* insert : *insertions) {
225 Value* value = insert->getInsertedValueOperand();
226 values.push_back(value);
227 types.push_back(value->getType());
228 }
229
David Neto8e39bd12017-11-14 21:08:12 -0500230 StructType *resultTy = cast<StructType>(insertions->back()->getType());
231 Function *fn = GetConstructFunction(M, resultTy);
David Netoab03f432017-11-03 17:00:44 -0400232
233 // Replace the chain.
234 auto call = CallInst::Create(fn, values);
235 call->insertAfter(insertions->back());
236 insertions->back()->replaceAllUsesWith(call);
237
238 // Remove the insertions if we can. Go from the tail back to
239 // the head, since the tail uses the previous insertion, etc.
240 for (auto iter = insertions->rbegin(), end = insertions->rend();
241 iter != end; ++iter) {
242 InsertValueInst *insertion = *iter;
243 if (!insertion->hasNUsesOrMore(1)) {
244 insertion->eraseFromParent();
245 }
246 }
247
248 delete insertions;
249 }
250
251 return Changed;
252}
David Neto8e39bd12017-11-14 21:08:12 -0500253
254bool RewriteInsertsPass::ReplacePartialInsertions(Module &M) {
255 bool Changed = false;
256
257 // First find candidates. Collect all InsertValue instructions
258 // into struct type, but track their interdependencies. To minimize
259 // the number of new instructions, generate a construction for each
260 // tail of an insertion chain.
261
262 UniqueVector<InsertValueInst *> insertions;
263 for (Function &F : M) {
264 for (BasicBlock &BB : F) {
265 for (Instruction &I : BB) {
266 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
267 if (iv->getType()->isStructTy()) {
268 insertions.insert(iv);
269 }
270 }
271 }
272 }
273 }
274
275 // Now count how many times each InsertValue is used by another InsertValue.
276 // The |num_uses| vector is indexed by the unique id that |insertions|
277 // assigns to it.
278 std::vector<unsigned> num_uses(insertions.size() + 1);
279 // Count from the user's perspective.
280 for (InsertValueInst *insertion : insertions) {
281 if (auto *agg =
282 dyn_cast<InsertValueInst>(insertion->getAggregateOperand())) {
283 ++(num_uses[insertions.idFor(agg)]);
284 }
285 }
286
287 // Proceed in rounds. Each round rewrites any chains ending with an
288 // insertion that is not used by another insertion.
289
290 // Get the first list of insertion tails.
291 InsertionVector WorkList;
292 for (InsertValueInst *insertion : insertions) {
293 if (num_uses[insertions.idFor(insertion)] == 0) {
294 WorkList.push_back(insertion);
295 }
296 }
297
298 // This records insertions in the order they should be removed.
299 // In this list, an insertion preceds any insertions it uses.
300 // (This is post-dominance order.)
301 InsertionVector ordered_candidates_for_removal;
302
303 // Proceed in rounds.
304 while (WorkList.size()) {
305 Changed = true;
306
307 // Record the list of tails for the next round.
308 InsertionVector NextRoundWorkList;
309
310 for (InsertValueInst *insertion : WorkList) {
311 // Rewrite |insertion|.
312
313 StructType *resultTy = cast<StructType>(insertion->getType());
314
315 const unsigned num_members = resultTy->getNumElements();
316 std::vector<Value*> members(num_members, nullptr);
317 InsertionVector chain;
318 // Gather the member values. Walk backward from the insertion.
319 Value *base = LoadValuesEndingWithInsertion(insertion, &members, &chain);
320
321 // Populate remaining entries in |values| by extracting elements
322 // from |base|. Only make a new extractvalue instruction if we can't
323 // make a constant or undef. New instructions are inserted before
324 // the insertion we plan to remove.
325 for (unsigned i = 0; i < num_members; ++i) {
326 if (!members[i]) {
327 Type *memberTy = resultTy->getTypeAtIndex(i);
328 if (isa<UndefValue>(base)) {
329 members[i] = UndefValue::get(memberTy);
330 } else if (const auto *caz = dyn_cast<ConstantAggregateZero>(base)) {
331 members[i] = caz->getElementValue(i);
332 } else if (const auto *ca = dyn_cast<ConstantAggregate>(base)) {
333 members[i] = ca->getOperand(i);
334 } else {
335 members[i] = ExtractValueInst::Create(base, {i}, "", insertion);
336 }
337 }
338 }
339
340 // Create the call. It's dominated by any extractions we've just
341 // created.
342 Function *construct_fn = GetConstructFunction(M, resultTy);
343 auto *call = CallInst::Create(construct_fn, members, "", insertion);
344
345 // Disconnect this insertion. We'll remove it later.
346 insertion->replaceAllUsesWith(call);
347
348 // Trace backward through the chain, removing uses and deleting where
349 // we can. Stop at the first element that has a remaining use.
350 for (auto* chainElem : chain) {
351 if (chainElem->hasNUsesOrMore(1)) {
352 unsigned &use_count = num_uses[insertions.idFor(chainElem)];
353 assert(use_count > 0);
354 --use_count;
355 if (use_count == 0) {
356 NextRoundWorkList.push_back(chainElem);
357 }
358 break;
359 } else {
360 chainElem->eraseFromParent();
361 }
362 }
363 }
364 WorkList = std::move(NextRoundWorkList);
365 }
366
367 return Changed;
368}