blob: 052aff0d401f435811f10467ea0955cdaff162a3 [file] [log] [blame]
David Netoab03f432017-11-03 17:00:44 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <memory>
#include <string>
#include <utility>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/UniqueVector.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/raw_ostream.h>
30
31using namespace llvm;
32using std::string;
33
David Neto8e39bd12017-11-14 21:08:12 -050034#define DEBUG_TYPE "rewriteinserts"
David Netoab03f432017-11-03 17:00:44 -040035
David Neto8e39bd12017-11-14 21:08:12 -050036static llvm::cl::opt<bool> hack_inserts(
37 "hack-inserts", llvm::cl::init(false),
38 llvm::cl::desc(
39 "Avoid all single-index OpCompositInsert instructions "
40 "into struct types by using complete composite construction and "
41 "extractions"));
David Netoab03f432017-11-03 17:00:44 -040042
43namespace {
44
// Prefix of the generated composite-construct function names; a per-type
// number is appended to it. Declared constexpr so it cannot be reassigned.
constexpr char kCompositeConstructFunctionPrefix[] = "clspv.composite_construct.";
46
47class RewriteInsertsPass : public ModulePass {
48 public:
49 static char ID;
50 RewriteInsertsPass() : ModulePass(ID) {}
51
52 bool runOnModule(Module &M) override;
53
54 private:
55 using InsertionVector = SmallVector<InsertValueInst *, 4>;
56
David Neto8e39bd12017-11-14 21:08:12 -050057 // Replaces chains of insertions that cover the entire value.
58 // Such a change always reduces the number of instructions, so
59 // we always perform these. Returns true if the module was modified.
60 bool ReplaceCompleteInsertionChains(Module &M);
61
62 // Replaces all InsertValue instructions, even if they aren't part
63 // of a complete insetion chain. Returns true if the module was modified.
64 bool ReplacePartialInsertions(Module &M);
65
66 // Load |values| and |chain| with the members of the struct value produced
67 // by a chain of InsertValue instructions ending with |iv|, and following
68 // the aggregate operand. Return the start of the chain: the aggregate
69 // value which is not an InsertValue instruction, or an InsertValue
70 // instruction which inserts a component that is replaced later in the
71 // chain. The |values| vector will match the order of struct members and
72 // is initialized to all nullptr members. The |chain| vector will list
73 // the chain of InsertValue instructions, listed in the order we discover
74 // them, e.g. begining with |iv|.
75 Value *LoadValuesEndingWithInsertion(InsertValueInst *iv,
76 std::vector<Value *> *values,
77 InsertionVector *chain) {
78 auto *structTy = dyn_cast<StructType>(iv->getType());
79 assert(structTy);
80 const auto numElems = structTy->getNumElements();
81
82 // Walk backward from the tail to an instruction we don't want to
83 // replace.
84 Value *frontier = iv;
85 while (auto *insertion = dyn_cast<InsertValueInst>(frontier)) {
86 chain->push_back(insertion);
87 // Only handle single-index insertions.
88 if (insertion->getNumIndices() == 1) {
89 // Try to replace this one.
90
91 unsigned index = insertion->getIndices()[0];
92 assert(index < numElems);
93 if ((*values)[index] != nullptr) {
94 // We already have a value for this slot. Stop now.
95 break;
96 }
97 (*values)[index] = insertion->getInsertedValueOperand();
98 frontier = insertion->getAggregateOperand();
99 } else {
100 break;
101 }
102 }
103 return frontier;
104 }
105
David Netoab03f432017-11-03 17:00:44 -0400106 // If this is the tail of a chain of InsertValueInst instructions
107 // that covers the entire composite, then return a small vector
108 // containing the insertion instructions, in member order.
109 // Otherwise returns nullptr. Only handle insertions into structs,
110 // not into arrays.
111 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
112 if (iv->getNumIndices() == 1) {
113 if (auto *structTy = dyn_cast<StructType>(iv->getType())) {
114 auto numElems = structTy->getNumElements();
115 // Only handle single-index insertions.
116 unsigned index = iv->getIndices()[0];
117 if (index + 1u != numElems) {
118 // Not the last in the chain.
119 return nullptr;
120 }
121 InsertionVector candidates(numElems, nullptr);
122 for (unsigned i = index;
123 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
124 // iv inserts the i'th member
125 candidates[i] = iv;
126
127 if (i == 0) {
128 // We're done!
129 return new InsertionVector(candidates);
130 }
131
132 if (InsertValueInst *agg =
133 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
134 iv = agg;
135 } else {
136 // The chain is broken.
137 break;
138 }
139 }
140 }
141 }
142 return nullptr;
143 }
144
David Neto8e39bd12017-11-14 21:08:12 -0500145
David Netoab03f432017-11-03 17:00:44 -0400146 // Return the name for the wrap function for the given type.
147 string &WrapFunctionNameForType(Type *type) {
148 auto where = function_for_type_.find(type);
149 if (where == function_for_type_.end()) {
150 // Insert it.
151 auto &result = function_for_type_[type] =
152 string(kCompositeConstructFunctionPrefix) +
153 std::to_string(function_for_type_.size());
154 return result;
155 } else {
156 return where->second;
157 }
158 }
159
David Neto8e39bd12017-11-14 21:08:12 -0500160 // Get or create the composite construct function definition.
161 Function *GetConstructFunction(Module &M, StructType *constructed_type) {
162 // Get or create the composite construct function definition.
163 const string &fn_name = WrapFunctionNameForType(constructed_type);
164 Function *fn = M.getFunction(fn_name);
165 if (!fn) {
166 // Make the function.
167 FunctionType *fnTy = FunctionType::get(
168 constructed_type, constructed_type->elements(), false);
169 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
170 fn = cast<Function>(fn_constant);
171 fn->addFnAttr(Attribute::ReadOnly);
172 fn->addFnAttr(Attribute::ReadNone);
173 }
174 return fn;
175 }
176
David Netoab03f432017-11-03 17:00:44 -0400177 // Maps a loaded type to the name of the wrap function for that type.
178 DenseMap<Type *, string> function_for_type_;
179};
180} // namespace
181
182char RewriteInsertsPass::ID = 0;
183static RegisterPass<RewriteInsertsPass>
184 X("RewriteInserts",
185 "Rewrite chains of insertvalue to as composite-construction");
186
187namespace clspv {
188llvm::ModulePass *createRewriteInsertsPass() {
189 return new RewriteInsertsPass();
190}
191} // namespace clspv
192
193bool RewriteInsertsPass::runOnModule(Module &M) {
David Neto8e39bd12017-11-14 21:08:12 -0500194 bool Changed = ReplaceCompleteInsertionChains(M);
195
196 if (hack_inserts) {
197 Changed |= ReplacePartialInsertions(M);
198 }
199
200 return Changed;
201}
202
203bool RewriteInsertsPass::ReplaceCompleteInsertionChains(Module &M) {
David Netoab03f432017-11-03 17:00:44 -0400204 bool Changed = false;
205
206 SmallVector<InsertionVector *, 16> WorkList;
207 for (Function &F : M) {
208 for (BasicBlock &BB : F) {
209 for (Instruction &I : BB) {
210 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
211 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
212 WorkList.push_back(insertions);
213 }
214 }
215 }
216 }
217 }
218
219 if (WorkList.size() == 0) {
220 return Changed;
221 }
222
223 for (InsertionVector *insertions : WorkList) {
224 Changed = true;
225
226 // Gather the member values and types.
227 SmallVector<Value*, 4> values;
228 SmallVector<Type*, 4> types;
229 for (InsertValueInst* insert : *insertions) {
230 Value* value = insert->getInsertedValueOperand();
231 values.push_back(value);
232 types.push_back(value->getType());
233 }
234
David Neto8e39bd12017-11-14 21:08:12 -0500235 StructType *resultTy = cast<StructType>(insertions->back()->getType());
236 Function *fn = GetConstructFunction(M, resultTy);
David Netoab03f432017-11-03 17:00:44 -0400237
238 // Replace the chain.
239 auto call = CallInst::Create(fn, values);
240 call->insertAfter(insertions->back());
241 insertions->back()->replaceAllUsesWith(call);
242
243 // Remove the insertions if we can. Go from the tail back to
244 // the head, since the tail uses the previous insertion, etc.
245 for (auto iter = insertions->rbegin(), end = insertions->rend();
246 iter != end; ++iter) {
247 InsertValueInst *insertion = *iter;
248 if (!insertion->hasNUsesOrMore(1)) {
249 insertion->eraseFromParent();
250 }
251 }
252
253 delete insertions;
254 }
255
256 return Changed;
257}
David Neto8e39bd12017-11-14 21:08:12 -0500258
259bool RewriteInsertsPass::ReplacePartialInsertions(Module &M) {
260 bool Changed = false;
261
262 // First find candidates. Collect all InsertValue instructions
263 // into struct type, but track their interdependencies. To minimize
264 // the number of new instructions, generate a construction for each
265 // tail of an insertion chain.
266
267 UniqueVector<InsertValueInst *> insertions;
268 for (Function &F : M) {
269 for (BasicBlock &BB : F) {
270 for (Instruction &I : BB) {
271 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
272 if (iv->getType()->isStructTy()) {
273 insertions.insert(iv);
274 }
275 }
276 }
277 }
278 }
279
280 // Now count how many times each InsertValue is used by another InsertValue.
281 // The |num_uses| vector is indexed by the unique id that |insertions|
282 // assigns to it.
283 std::vector<unsigned> num_uses(insertions.size() + 1);
284 // Count from the user's perspective.
285 for (InsertValueInst *insertion : insertions) {
286 if (auto *agg =
287 dyn_cast<InsertValueInst>(insertion->getAggregateOperand())) {
288 ++(num_uses[insertions.idFor(agg)]);
289 }
290 }
291
292 // Proceed in rounds. Each round rewrites any chains ending with an
293 // insertion that is not used by another insertion.
294
295 // Get the first list of insertion tails.
296 InsertionVector WorkList;
297 for (InsertValueInst *insertion : insertions) {
298 if (num_uses[insertions.idFor(insertion)] == 0) {
299 WorkList.push_back(insertion);
300 }
301 }
302
303 // This records insertions in the order they should be removed.
304 // In this list, an insertion preceds any insertions it uses.
305 // (This is post-dominance order.)
306 InsertionVector ordered_candidates_for_removal;
307
308 // Proceed in rounds.
309 while (WorkList.size()) {
310 Changed = true;
311
312 // Record the list of tails for the next round.
313 InsertionVector NextRoundWorkList;
314
315 for (InsertValueInst *insertion : WorkList) {
316 // Rewrite |insertion|.
317
318 StructType *resultTy = cast<StructType>(insertion->getType());
319
320 const unsigned num_members = resultTy->getNumElements();
321 std::vector<Value*> members(num_members, nullptr);
322 InsertionVector chain;
323 // Gather the member values. Walk backward from the insertion.
324 Value *base = LoadValuesEndingWithInsertion(insertion, &members, &chain);
325
326 // Populate remaining entries in |values| by extracting elements
327 // from |base|. Only make a new extractvalue instruction if we can't
328 // make a constant or undef. New instructions are inserted before
329 // the insertion we plan to remove.
330 for (unsigned i = 0; i < num_members; ++i) {
331 if (!members[i]) {
332 Type *memberTy = resultTy->getTypeAtIndex(i);
333 if (isa<UndefValue>(base)) {
334 members[i] = UndefValue::get(memberTy);
335 } else if (const auto *caz = dyn_cast<ConstantAggregateZero>(base)) {
336 members[i] = caz->getElementValue(i);
337 } else if (const auto *ca = dyn_cast<ConstantAggregate>(base)) {
338 members[i] = ca->getOperand(i);
339 } else {
340 members[i] = ExtractValueInst::Create(base, {i}, "", insertion);
341 }
342 }
343 }
344
345 // Create the call. It's dominated by any extractions we've just
346 // created.
347 Function *construct_fn = GetConstructFunction(M, resultTy);
348 auto *call = CallInst::Create(construct_fn, members, "", insertion);
349
350 // Disconnect this insertion. We'll remove it later.
351 insertion->replaceAllUsesWith(call);
352
353 // Trace backward through the chain, removing uses and deleting where
354 // we can. Stop at the first element that has a remaining use.
355 for (auto* chainElem : chain) {
356 if (chainElem->hasNUsesOrMore(1)) {
357 unsigned &use_count = num_uses[insertions.idFor(chainElem)];
358 assert(use_count > 0);
359 --use_count;
360 if (use_count == 0) {
361 NextRoundWorkList.push_back(chainElem);
362 }
363 break;
364 } else {
365 chainElem->eraseFromParent();
366 }
367 }
368 }
369 WorkList = std::move(NextRoundWorkList);
370 }
371
372 return Changed;
373}