blob: a21fefc9bbd9e6d92e95a600a9e366d03ddf91b0 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <string>
David Neto8e39bd12017-11-14 21:08:12 -050016#include <utility>
David Netoab03f432017-11-03 17:00:44 -040017
David Neto118188e2018-08-24 11:27:54 -040018#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/UniqueVector.h"
21#include "llvm/IR/Attributes.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/IR/Function.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/Module.h"
27#include "llvm/Pass.h"
28#include "llvm/Support/raw_ostream.h"
David Netoab03f432017-11-03 17:00:44 -040029
David Neto482550a2018-03-24 05:21:07 -070030#include "clspv/Option.h"
31
David Netoab03f432017-11-03 17:00:44 -040032using namespace llvm;
33using std::string;
34
David Neto8e39bd12017-11-14 21:08:12 -050035#define DEBUG_TYPE "rewriteinserts"
David Netoab03f432017-11-03 17:00:44 -040036
David Netoab03f432017-11-03 17:00:44 -040037namespace {
38
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040039const char *kCompositeConstructFunctionPrefix = "clspv.composite_construct.";
David Netoab03f432017-11-03 17:00:44 -040040
41class RewriteInsertsPass : public ModulePass {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042public:
David Netoab03f432017-11-03 17:00:44 -040043 static char ID;
44 RewriteInsertsPass() : ModulePass(ID) {}
45
46 bool runOnModule(Module &M) override;
47
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040048private:
49 using InsertionVector = SmallVector<Instruction *, 4>;
David Netoab03f432017-11-03 17:00:44 -040050
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040051 // Replaces chains of insertions that cover the entire value.
52 // Such a change always reduces the number of instructions, so
53 // we always perform these. Returns true if the module was modified.
54 bool ReplaceCompleteInsertionChains(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050055
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040056 // Replaces all InsertValue instructions, even if they aren't part
57 // of a complete insetion chain. Returns true if the module was modified.
58 bool ReplacePartialInsertions(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050059
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040060 // Load |values| and |chain| with the members of the struct value produced
61 // by a chain of InsertValue instructions ending with |iv|, and following
62 // the aggregate operand. Return the start of the chain: the aggregate
63 // value which is not an InsertValue instruction, or an InsertValue
64 // instruction which inserts a component that is replaced later in the
65 // chain. The |values| vector will match the order of struct members and
66 // is initialized to all nullptr members. The |chain| vector will list
67 // the chain of InsertValue instructions, listed in the order we discover
68 // them, e.g. begining with |iv|.
69 Value *LoadValuesEndingWithInsertion(InsertValueInst *iv,
70 std::vector<Value *> *values,
71 InsertionVector *chain) {
72 auto *structTy = dyn_cast<StructType>(iv->getType());
73 assert(structTy);
74 const auto numElems = structTy->getNumElements();
David Neto8e39bd12017-11-14 21:08:12 -050075
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040076 // Walk backward from the tail to an instruction we don't want to
77 // replace.
78 Value *frontier = iv;
79 while (auto *insertion = dyn_cast<InsertValueInst>(frontier)) {
80 chain->push_back(insertion);
81 // Only handle single-index insertions.
82 if (insertion->getNumIndices() == 1) {
83 // Try to replace this one.
David Neto8e39bd12017-11-14 21:08:12 -050084
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040085 unsigned index = insertion->getIndices()[0];
86 assert(index < numElems);
87 if ((*values)[index] != nullptr) {
88 // We already have a value for this slot. Stop now.
89 break;
90 }
91 (*values)[index] = insertion->getInsertedValueOperand();
92 frontier = insertion->getAggregateOperand();
93 } else {
94 break;
95 }
96 }
97 return frontier;
98 }
David Neto8e39bd12017-11-14 21:08:12 -050099
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400100 // Returns the number of elements in the struct or array.
101 unsigned GetNumElements(Type *type) {
102 // CompositeType doesn't implement getNumElements(), but its inheritors
103 // do.
104 if (auto *struct_ty = dyn_cast<StructType>(type)) {
105 return struct_ty->getNumElements();
106 } else if (auto *seq_ty = dyn_cast<SequentialType>(type)) {
107 return seq_ty->getNumElements();
108 }
109 return 0;
110 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400111
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400112 // If this is the tail of a chain of InsertValueInst instructions
113 // that covers the entire composite, then return a small vector
114 // containing the insertion instructions, in member order.
115 // Otherwise returns nullptr.
116 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
117 if (iv->getNumIndices() == 1) {
118 auto numElems = GetNumElements(iv->getType());
119 if (numElems != 0) {
120 // Only handle single-index insertions.
121 unsigned index = iv->getIndices()[0];
122 if (index + 1u != numElems) {
123 // Not the last in the chain.
124 return nullptr;
125 }
126 InsertionVector candidates(numElems, nullptr);
127 for (unsigned i = index;
128 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
129 // iv inserts the i'th member
130 candidates[i] = iv;
Alan Bakerabc935e2018-09-06 11:33:53 -0400131
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400132 if (i == 0) {
133 // We're done!
134 return new InsertionVector(candidates);
135 }
David Netoab03f432017-11-03 17:00:44 -0400136
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400137 if (InsertValueInst *agg =
138 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
139 iv = agg;
140 } else {
141 // The chain is broken.
142 break;
143 }
144 }
145 }
146 }
147 return nullptr;
148 }
David Netoab03f432017-11-03 17:00:44 -0400149
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400150 // If this is the tail of a chain of InsertElementInst instructions
151 // that covers the entire vector, then return a small vector
152 // containing the insertion instructions, in member order.
153 // Otherwise returns nullptr. Only handle insertions into vectors.
154 InsertionVector *CompleteInsertionChain(InsertElementInst *ie) {
155 // Don't handle i8 vectors. Only <4 x i8> is supported and it is
156 // translated as i32. Only handle single-index insertions.
157 if (auto vec_ty = dyn_cast<VectorType>(ie->getType())) {
158 if (vec_ty->getVectorElementType() == Type::getInt8Ty(ie->getContext())) {
159 return nullptr;
160 }
161 }
David Netoab03f432017-11-03 17:00:44 -0400162
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400163 // Only handle single-index insertions.
164 if (ie->getNumOperands() == 3) {
165 auto numElems = GetNumElements(ie->getType());
166 if (numElems != 0) {
167 if (auto *const_value = dyn_cast<ConstantInt>(ie->getOperand(2))) {
168 uint64_t index = const_value->getZExtValue();
169 if (index + 1u != numElems) {
170 // Not the last in the chain.
171 return nullptr;
172 }
173 InsertionVector candidates(numElems, nullptr);
174 Value *value = ie;
175 uint64_t i = index;
176 while (auto *insert = dyn_cast<InsertElementInst>(value)) {
177 if (insert->getNumOperands() != 3)
178 break;
179 if (auto *index_const =
180 dyn_cast<ConstantInt>(insert->getOperand(2))) {
181 if (i != index_const->getZExtValue())
182 break;
Alan Bakerabc935e2018-09-06 11:33:53 -0400183
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400184 candidates[i] = insert;
185 if (i == 0) {
186 // We're done!
187 return new InsertionVector(candidates);
188 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400189
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400190 value = insert->getOperand(0);
191 --i;
192 } else {
193 break;
194 }
195 }
196 } else {
197 return nullptr;
198 }
199 }
200 }
201 return nullptr;
202 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400203
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400204 // Return the name for the wrap function for the given type.
205 string &WrapFunctionNameForType(Type *type) {
206 auto where = function_for_type_.find(type);
207 if (where == function_for_type_.end()) {
208 // Insert it.
209 auto &result = function_for_type_[type] =
210 string(kCompositeConstructFunctionPrefix) +
211 std::to_string(function_for_type_.size());
212 return result;
213 } else {
214 return where->second;
215 }
216 }
David Neto8e39bd12017-11-14 21:08:12 -0500217
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400218 // Get or create the composite construct function definition.
219 Function *GetConstructFunction(Module &M, CompositeType *constructed_type) {
220 // Get or create the composite construct function definition.
221 const string &fn_name = WrapFunctionNameForType(constructed_type);
222 Function *fn = M.getFunction(fn_name);
223 if (!fn) {
224 // Make the function.
225 SmallVector<Type *, 16> elements;
226 unsigned num_elements = GetNumElements(constructed_type);
227 for (unsigned i = 0; i != num_elements; ++i)
228 elements.push_back(constructed_type->getTypeAtIndex(i));
229 FunctionType *fnTy = FunctionType::get(constructed_type, elements, false);
230 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
231 fn = cast<Function>(fn_constant.getCallee());
232 fn->addFnAttr(Attribute::ReadOnly);
233 }
234 return fn;
235 }
David Netoab03f432017-11-03 17:00:44 -0400236
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400237 // Maps a loaded type to the name of the wrap function for that type.
238 DenseMap<Type *, string> function_for_type_;
David Netoab03f432017-11-03 17:00:44 -0400239};
240} // namespace
241
242char RewriteInsertsPass::ID = 0;
243static RegisterPass<RewriteInsertsPass>
244 X("RewriteInserts",
245 "Rewrite chains of insertvalue to as composite-construction");
246
247namespace clspv {
248llvm::ModulePass *createRewriteInsertsPass() {
249 return new RewriteInsertsPass();
250}
251} // namespace clspv
252
253bool RewriteInsertsPass::runOnModule(Module &M) {
David Neto8e39bd12017-11-14 21:08:12 -0500254 bool Changed = ReplaceCompleteInsertionChains(M);
255
David Neto482550a2018-03-24 05:21:07 -0700256 if (clspv::Option::HackInserts()) {
David Neto8e39bd12017-11-14 21:08:12 -0500257 Changed |= ReplacePartialInsertions(M);
258 }
259
260 return Changed;
261}
262
263bool RewriteInsertsPass::ReplaceCompleteInsertionChains(Module &M) {
David Netoab03f432017-11-03 17:00:44 -0400264 bool Changed = false;
265
266 SmallVector<InsertionVector *, 16> WorkList;
267 for (Function &F : M) {
268 for (BasicBlock &BB : F) {
269 for (Instruction &I : BB) {
270 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
271 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
272 WorkList.push_back(insertions);
273 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400274 } else if (InsertElementInst *ie = dyn_cast<InsertElementInst>(&I)) {
275 if (InsertionVector *insertions = CompleteInsertionChain(ie)) {
276 WorkList.push_back(insertions);
277 }
David Netoab03f432017-11-03 17:00:44 -0400278 }
279 }
280 }
281 }
282
283 if (WorkList.size() == 0) {
284 return Changed;
285 }
286
287 for (InsertionVector *insertions : WorkList) {
288 Changed = true;
289
290 // Gather the member values and types.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400291 SmallVector<Value *, 4> values;
292 for (Instruction *inst : *insertions) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400293 if (auto *insert_value = dyn_cast<InsertValueInst>(inst)) {
294 values.push_back(insert_value->getInsertedValueOperand());
295 } else if (auto *insert_element = dyn_cast<InsertElementInst>(inst)) {
296 values.push_back(insert_element->getOperand(1));
297 } else {
298 llvm_unreachable("Unhandled insertion instruction");
299 }
David Netoab03f432017-11-03 17:00:44 -0400300 }
301
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400302 CompositeType *resultTy =
303 cast<CompositeType>(insertions->back()->getType());
Alan Bakerabc935e2018-09-06 11:33:53 -0400304 Function *fn = GetConstructFunction(M, resultTy);
David Netoab03f432017-11-03 17:00:44 -0400305
306 // Replace the chain.
307 auto call = CallInst::Create(fn, values);
308 call->insertAfter(insertions->back());
309 insertions->back()->replaceAllUsesWith(call);
310
311 // Remove the insertions if we can. Go from the tail back to
312 // the head, since the tail uses the previous insertion, etc.
313 for (auto iter = insertions->rbegin(), end = insertions->rend();
314 iter != end; ++iter) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400315 Instruction *insertion = *iter;
David Netoab03f432017-11-03 17:00:44 -0400316 if (!insertion->hasNUsesOrMore(1)) {
317 insertion->eraseFromParent();
318 }
319 }
320
321 delete insertions;
322 }
323
324 return Changed;
325}
David Neto8e39bd12017-11-14 21:08:12 -0500326
327bool RewriteInsertsPass::ReplacePartialInsertions(Module &M) {
328 bool Changed = false;
329
330 // First find candidates. Collect all InsertValue instructions
331 // into struct type, but track their interdependencies. To minimize
332 // the number of new instructions, generate a construction for each
333 // tail of an insertion chain.
334
335 UniqueVector<InsertValueInst *> insertions;
336 for (Function &F : M) {
337 for (BasicBlock &BB : F) {
338 for (Instruction &I : BB) {
339 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
340 if (iv->getType()->isStructTy()) {
341 insertions.insert(iv);
342 }
343 }
344 }
345 }
346 }
347
348 // Now count how many times each InsertValue is used by another InsertValue.
349 // The |num_uses| vector is indexed by the unique id that |insertions|
350 // assigns to it.
351 std::vector<unsigned> num_uses(insertions.size() + 1);
352 // Count from the user's perspective.
353 for (InsertValueInst *insertion : insertions) {
354 if (auto *agg =
355 dyn_cast<InsertValueInst>(insertion->getAggregateOperand())) {
356 ++(num_uses[insertions.idFor(agg)]);
357 }
358 }
359
360 // Proceed in rounds. Each round rewrites any chains ending with an
361 // insertion that is not used by another insertion.
362
363 // Get the first list of insertion tails.
364 InsertionVector WorkList;
365 for (InsertValueInst *insertion : insertions) {
366 if (num_uses[insertions.idFor(insertion)] == 0) {
367 WorkList.push_back(insertion);
368 }
369 }
370
371 // This records insertions in the order they should be removed.
372 // In this list, an insertion preceds any insertions it uses.
373 // (This is post-dominance order.)
374 InsertionVector ordered_candidates_for_removal;
375
376 // Proceed in rounds.
377 while (WorkList.size()) {
378 Changed = true;
379
380 // Record the list of tails for the next round.
381 InsertionVector NextRoundWorkList;
382
Alan Bakerabc935e2018-09-06 11:33:53 -0400383 for (Instruction *inst : WorkList) {
384 InsertValueInst *insertion = cast<InsertValueInst>(inst);
David Neto8e39bd12017-11-14 21:08:12 -0500385 // Rewrite |insertion|.
386
387 StructType *resultTy = cast<StructType>(insertion->getType());
388
389 const unsigned num_members = resultTy->getNumElements();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400390 std::vector<Value *> members(num_members, nullptr);
David Neto8e39bd12017-11-14 21:08:12 -0500391 InsertionVector chain;
392 // Gather the member values. Walk backward from the insertion.
393 Value *base = LoadValuesEndingWithInsertion(insertion, &members, &chain);
394
395 // Populate remaining entries in |values| by extracting elements
396 // from |base|. Only make a new extractvalue instruction if we can't
397 // make a constant or undef. New instructions are inserted before
398 // the insertion we plan to remove.
399 for (unsigned i = 0; i < num_members; ++i) {
400 if (!members[i]) {
401 Type *memberTy = resultTy->getTypeAtIndex(i);
402 if (isa<UndefValue>(base)) {
403 members[i] = UndefValue::get(memberTy);
404 } else if (const auto *caz = dyn_cast<ConstantAggregateZero>(base)) {
405 members[i] = caz->getElementValue(i);
406 } else if (const auto *ca = dyn_cast<ConstantAggregate>(base)) {
407 members[i] = ca->getOperand(i);
408 } else {
409 members[i] = ExtractValueInst::Create(base, {i}, "", insertion);
410 }
411 }
412 }
413
414 // Create the call. It's dominated by any extractions we've just
415 // created.
Alan Bakerabc935e2018-09-06 11:33:53 -0400416 Function *construct_fn = GetConstructFunction(M, resultTy);
David Neto8e39bd12017-11-14 21:08:12 -0500417 auto *call = CallInst::Create(construct_fn, members, "", insertion);
418
419 // Disconnect this insertion. We'll remove it later.
420 insertion->replaceAllUsesWith(call);
421
422 // Trace backward through the chain, removing uses and deleting where
423 // we can. Stop at the first element that has a remaining use.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400424 for (auto *chainElem : chain) {
David Neto8e39bd12017-11-14 21:08:12 -0500425 if (chainElem->hasNUsesOrMore(1)) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400426 unsigned &use_count =
427 num_uses[insertions.idFor(cast<InsertValueInst>(chainElem))];
David Neto8e39bd12017-11-14 21:08:12 -0500428 assert(use_count > 0);
429 --use_count;
430 if (use_count == 0) {
431 NextRoundWorkList.push_back(chainElem);
432 }
433 break;
434 } else {
435 chainElem->eraseFromParent();
436 }
437 }
438 }
439 WorkList = std::move(NextRoundWorkList);
440 }
441
442 return Changed;
443}