blob: 09df8cbecd225555665ffa3ebcd91371babdd931 [file] [log] [blame]
David Netoab03f432017-11-03 17:00:44 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <string>
David Neto8e39bd12017-11-14 21:08:12 -050016#include <utility>
David Netoab03f432017-11-03 17:00:44 -040017
David Neto118188e2018-08-24 11:27:54 -040018#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/UniqueVector.h"
21#include "llvm/IR/Attributes.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
David Neto118188e2018-08-24 11:27:54 -040024#include "llvm/IR/Function.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/Module.h"
27#include "llvm/Pass.h"
28#include "llvm/Support/raw_ostream.h"
David Netoab03f432017-11-03 17:00:44 -040029
David Neto482550a2018-03-24 05:21:07 -070030#include "clspv/Option.h"
31
SJW61531372020-06-09 07:31:08 -050032#include "Constants.h"
Diego Novilloa4c44fa2019-04-11 10:56:15 -040033#include "Passes.h"
34
David Netoab03f432017-11-03 17:00:44 -040035using namespace llvm;
36using std::string;
37
David Neto8e39bd12017-11-14 21:08:12 -050038#define DEBUG_TYPE "rewriteinserts"
David Netoab03f432017-11-03 17:00:44 -040039
David Netoab03f432017-11-03 17:00:44 -040040namespace {
41
David Netoab03f432017-11-03 17:00:44 -040042class RewriteInsertsPass : public ModulePass {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040043public:
David Netoab03f432017-11-03 17:00:44 -040044 static char ID;
45 RewriteInsertsPass() : ModulePass(ID) {}
46
47 bool runOnModule(Module &M) override;
48
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040049private:
50 using InsertionVector = SmallVector<Instruction *, 4>;
David Netoab03f432017-11-03 17:00:44 -040051
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040052 // Replaces chains of insertions that cover the entire value.
53 // Such a change always reduces the number of instructions, so
54 // we always perform these. Returns true if the module was modified.
55 bool ReplaceCompleteInsertionChains(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050056
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040057 // Replaces all InsertValue instructions, even if they aren't part
58 // of a complete insetion chain. Returns true if the module was modified.
59 bool ReplacePartialInsertions(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050060
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040061 // Load |values| and |chain| with the members of the struct value produced
62 // by a chain of InsertValue instructions ending with |iv|, and following
63 // the aggregate operand. Return the start of the chain: the aggregate
64 // value which is not an InsertValue instruction, or an InsertValue
65 // instruction which inserts a component that is replaced later in the
66 // chain. The |values| vector will match the order of struct members and
67 // is initialized to all nullptr members. The |chain| vector will list
68 // the chain of InsertValue instructions, listed in the order we discover
69 // them, e.g. begining with |iv|.
70 Value *LoadValuesEndingWithInsertion(InsertValueInst *iv,
71 std::vector<Value *> *values,
72 InsertionVector *chain) {
73 auto *structTy = dyn_cast<StructType>(iv->getType());
74 assert(structTy);
alan-baker4a757f62020-04-22 08:17:49 -040075 if (!structTy)
76 return nullptr;
David Neto8e39bd12017-11-14 21:08:12 -050077
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040078 // Walk backward from the tail to an instruction we don't want to
79 // replace.
80 Value *frontier = iv;
81 while (auto *insertion = dyn_cast<InsertValueInst>(frontier)) {
82 chain->push_back(insertion);
83 // Only handle single-index insertions.
84 if (insertion->getNumIndices() == 1) {
85 // Try to replace this one.
David Neto8e39bd12017-11-14 21:08:12 -050086
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040087 unsigned index = insertion->getIndices()[0];
alan-baker4a757f62020-04-22 08:17:49 -040088 assert(index < structTy->getNumElements());
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040089 if ((*values)[index] != nullptr) {
90 // We already have a value for this slot. Stop now.
91 break;
92 }
93 (*values)[index] = insertion->getInsertedValueOperand();
94 frontier = insertion->getAggregateOperand();
95 } else {
96 break;
97 }
98 }
99 return frontier;
100 }
David Neto8e39bd12017-11-14 21:08:12 -0500101
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400102 // Returns the number of elements in the struct or array.
103 unsigned GetNumElements(Type *type) {
104 // CompositeType doesn't implement getNumElements(), but its inheritors
105 // do.
106 if (auto *struct_ty = dyn_cast<StructType>(type)) {
107 return struct_ty->getNumElements();
alan-baker8eb435a2020-04-08 00:42:06 -0400108 } else if (auto *array_ty = dyn_cast<ArrayType>(type)) {
109 return array_ty->getNumElements();
110 } else if (auto *vec_ty = dyn_cast<VectorType>(type)) {
alan-baker5a8c3be2020-09-09 13:44:26 -0400111 return vec_ty->getElementCount().getKnownMinValue();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400112 }
113 return 0;
114 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400115
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400116 // If this is the tail of a chain of InsertValueInst instructions
117 // that covers the entire composite, then return a small vector
118 // containing the insertion instructions, in member order.
119 // Otherwise returns nullptr.
120 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
121 if (iv->getNumIndices() == 1) {
122 auto numElems = GetNumElements(iv->getType());
123 if (numElems != 0) {
124 // Only handle single-index insertions.
125 unsigned index = iv->getIndices()[0];
126 if (index + 1u != numElems) {
127 // Not the last in the chain.
128 return nullptr;
129 }
130 InsertionVector candidates(numElems, nullptr);
131 for (unsigned i = index;
132 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
133 // iv inserts the i'th member
134 candidates[i] = iv;
Alan Bakerabc935e2018-09-06 11:33:53 -0400135
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400136 if (i == 0) {
137 // We're done!
138 return new InsertionVector(candidates);
139 }
David Netoab03f432017-11-03 17:00:44 -0400140
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400141 if (InsertValueInst *agg =
142 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
143 iv = agg;
144 } else {
145 // The chain is broken.
146 break;
147 }
148 }
149 }
150 }
151 return nullptr;
152 }
David Netoab03f432017-11-03 17:00:44 -0400153
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400154 // If this is the tail of a chain of InsertElementInst instructions
155 // that covers the entire vector, then return a small vector
156 // containing the insertion instructions, in member order.
157 // Otherwise returns nullptr. Only handle insertions into vectors.
158 InsertionVector *CompleteInsertionChain(InsertElementInst *ie) {
159 // Don't handle i8 vectors. Only <4 x i8> is supported and it is
160 // translated as i32. Only handle single-index insertions.
161 if (auto vec_ty = dyn_cast<VectorType>(ie->getType())) {
James Pricecf53df42020-04-20 14:41:24 -0400162 if (vec_ty->getElementType() == Type::getInt8Ty(ie->getContext())) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400163 return nullptr;
164 }
165 }
David Netoab03f432017-11-03 17:00:44 -0400166
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400167 // Only handle single-index insertions.
168 if (ie->getNumOperands() == 3) {
169 auto numElems = GetNumElements(ie->getType());
170 if (numElems != 0) {
171 if (auto *const_value = dyn_cast<ConstantInt>(ie->getOperand(2))) {
172 uint64_t index = const_value->getZExtValue();
173 if (index + 1u != numElems) {
174 // Not the last in the chain.
175 return nullptr;
176 }
177 InsertionVector candidates(numElems, nullptr);
178 Value *value = ie;
179 uint64_t i = index;
180 while (auto *insert = dyn_cast<InsertElementInst>(value)) {
181 if (insert->getNumOperands() != 3)
182 break;
183 if (auto *index_const =
184 dyn_cast<ConstantInt>(insert->getOperand(2))) {
185 if (i != index_const->getZExtValue())
186 break;
Alan Bakerabc935e2018-09-06 11:33:53 -0400187
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400188 candidates[i] = insert;
189 if (i == 0) {
190 // We're done!
191 return new InsertionVector(candidates);
192 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400193
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400194 value = insert->getOperand(0);
195 --i;
196 } else {
197 break;
198 }
199 }
200 } else {
201 return nullptr;
202 }
203 }
204 }
205 return nullptr;
206 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400207
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400208 // Return the name for the wrap function for the given type.
209 string &WrapFunctionNameForType(Type *type) {
210 auto where = function_for_type_.find(type);
211 if (where == function_for_type_.end()) {
212 // Insert it.
213 auto &result = function_for_type_[type] =
SJW61531372020-06-09 07:31:08 -0500214 clspv::CompositeConstructFunction() + "." +
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400215 std::to_string(function_for_type_.size());
216 return result;
217 } else {
218 return where->second;
219 }
220 }
David Neto8e39bd12017-11-14 21:08:12 -0500221
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400222 // Get or create the composite construct function definition.
alan-baker077517b2020-03-19 13:52:12 -0400223 Function *GetConstructFunction(Module &M, Type *constructed_type) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400224 // Get or create the composite construct function definition.
225 const string &fn_name = WrapFunctionNameForType(constructed_type);
226 Function *fn = M.getFunction(fn_name);
227 if (!fn) {
228 // Make the function.
229 SmallVector<Type *, 16> elements;
230 unsigned num_elements = GetNumElements(constructed_type);
alan-baker077517b2020-03-19 13:52:12 -0400231 if (auto struct_ty = dyn_cast<StructType>(constructed_type)) {
232 for (unsigned i = 0; i != num_elements; ++i)
233 elements.push_back(struct_ty->getTypeAtIndex(i));
alan-baker8eb435a2020-04-08 00:42:06 -0400234 } else if (isa<ArrayType>(constructed_type)) {
235 elements.resize(num_elements, constructed_type->getArrayElementType());
236 } else if (isa<VectorType>(constructed_type)) {
James Pricecf53df42020-04-20 14:41:24 -0400237 elements.resize(num_elements,
238 cast<VectorType>(constructed_type)->getElementType());
alan-baker077517b2020-03-19 13:52:12 -0400239 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400240 FunctionType *fnTy = FunctionType::get(constructed_type, elements, false);
241 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
242 fn = cast<Function>(fn_constant.getCallee());
243 fn->addFnAttr(Attribute::ReadOnly);
244 }
245 return fn;
246 }
David Netoab03f432017-11-03 17:00:44 -0400247
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400248 // Maps a loaded type to the name of the wrap function for that type.
249 DenseMap<Type *, string> function_for_type_;
David Netoab03f432017-11-03 17:00:44 -0400250};
251} // namespace
252
253char RewriteInsertsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400254INITIALIZE_PASS(RewriteInsertsPass, "RewriteInserts",
255 "Rewrite chains of insertvalue to as composite-construction",
256 false, false)
David Netoab03f432017-11-03 17:00:44 -0400257
258namespace clspv {
259llvm::ModulePass *createRewriteInsertsPass() {
260 return new RewriteInsertsPass();
261}
262} // namespace clspv
263
264bool RewriteInsertsPass::runOnModule(Module &M) {
David Neto8e39bd12017-11-14 21:08:12 -0500265 bool Changed = ReplaceCompleteInsertionChains(M);
266
David Neto482550a2018-03-24 05:21:07 -0700267 if (clspv::Option::HackInserts()) {
David Neto8e39bd12017-11-14 21:08:12 -0500268 Changed |= ReplacePartialInsertions(M);
269 }
270
271 return Changed;
272}
273
274bool RewriteInsertsPass::ReplaceCompleteInsertionChains(Module &M) {
David Netoab03f432017-11-03 17:00:44 -0400275 bool Changed = false;
276
277 SmallVector<InsertionVector *, 16> WorkList;
278 for (Function &F : M) {
279 for (BasicBlock &BB : F) {
280 for (Instruction &I : BB) {
281 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
282 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
283 WorkList.push_back(insertions);
284 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400285 } else if (InsertElementInst *ie = dyn_cast<InsertElementInst>(&I)) {
286 if (InsertionVector *insertions = CompleteInsertionChain(ie)) {
287 WorkList.push_back(insertions);
288 }
David Netoab03f432017-11-03 17:00:44 -0400289 }
290 }
291 }
292 }
293
294 if (WorkList.size() == 0) {
295 return Changed;
296 }
297
298 for (InsertionVector *insertions : WorkList) {
299 Changed = true;
300
301 // Gather the member values and types.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400302 SmallVector<Value *, 4> values;
303 for (Instruction *inst : *insertions) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400304 if (auto *insert_value = dyn_cast<InsertValueInst>(inst)) {
305 values.push_back(insert_value->getInsertedValueOperand());
306 } else if (auto *insert_element = dyn_cast<InsertElementInst>(inst)) {
307 values.push_back(insert_element->getOperand(1));
308 } else {
309 llvm_unreachable("Unhandled insertion instruction");
310 }
David Netoab03f432017-11-03 17:00:44 -0400311 }
312
alan-baker077517b2020-03-19 13:52:12 -0400313 auto *resultTy = insertions->back()->getType();
Alan Bakerabc935e2018-09-06 11:33:53 -0400314 Function *fn = GetConstructFunction(M, resultTy);
David Netoab03f432017-11-03 17:00:44 -0400315
316 // Replace the chain.
317 auto call = CallInst::Create(fn, values);
318 call->insertAfter(insertions->back());
319 insertions->back()->replaceAllUsesWith(call);
320
321 // Remove the insertions if we can. Go from the tail back to
322 // the head, since the tail uses the previous insertion, etc.
323 for (auto iter = insertions->rbegin(), end = insertions->rend();
324 iter != end; ++iter) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400325 Instruction *insertion = *iter;
David Netoab03f432017-11-03 17:00:44 -0400326 if (!insertion->hasNUsesOrMore(1)) {
327 insertion->eraseFromParent();
328 }
329 }
330
331 delete insertions;
332 }
333
334 return Changed;
335}
David Neto8e39bd12017-11-14 21:08:12 -0500336
337bool RewriteInsertsPass::ReplacePartialInsertions(Module &M) {
338 bool Changed = false;
339
340 // First find candidates. Collect all InsertValue instructions
341 // into struct type, but track their interdependencies. To minimize
342 // the number of new instructions, generate a construction for each
343 // tail of an insertion chain.
344
345 UniqueVector<InsertValueInst *> insertions;
346 for (Function &F : M) {
347 for (BasicBlock &BB : F) {
348 for (Instruction &I : BB) {
349 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
350 if (iv->getType()->isStructTy()) {
351 insertions.insert(iv);
352 }
353 }
354 }
355 }
356 }
357
358 // Now count how many times each InsertValue is used by another InsertValue.
359 // The |num_uses| vector is indexed by the unique id that |insertions|
360 // assigns to it.
361 std::vector<unsigned> num_uses(insertions.size() + 1);
362 // Count from the user's perspective.
363 for (InsertValueInst *insertion : insertions) {
364 if (auto *agg =
365 dyn_cast<InsertValueInst>(insertion->getAggregateOperand())) {
366 ++(num_uses[insertions.idFor(agg)]);
367 }
368 }
369
370 // Proceed in rounds. Each round rewrites any chains ending with an
371 // insertion that is not used by another insertion.
372
373 // Get the first list of insertion tails.
374 InsertionVector WorkList;
375 for (InsertValueInst *insertion : insertions) {
376 if (num_uses[insertions.idFor(insertion)] == 0) {
377 WorkList.push_back(insertion);
378 }
379 }
380
381 // This records insertions in the order they should be removed.
382 // In this list, an insertion preceds any insertions it uses.
383 // (This is post-dominance order.)
384 InsertionVector ordered_candidates_for_removal;
385
386 // Proceed in rounds.
387 while (WorkList.size()) {
388 Changed = true;
389
390 // Record the list of tails for the next round.
391 InsertionVector NextRoundWorkList;
392
Alan Bakerabc935e2018-09-06 11:33:53 -0400393 for (Instruction *inst : WorkList) {
394 InsertValueInst *insertion = cast<InsertValueInst>(inst);
David Neto8e39bd12017-11-14 21:08:12 -0500395 // Rewrite |insertion|.
396
397 StructType *resultTy = cast<StructType>(insertion->getType());
398
399 const unsigned num_members = resultTy->getNumElements();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400400 std::vector<Value *> members(num_members, nullptr);
David Neto8e39bd12017-11-14 21:08:12 -0500401 InsertionVector chain;
402 // Gather the member values. Walk backward from the insertion.
403 Value *base = LoadValuesEndingWithInsertion(insertion, &members, &chain);
404
405 // Populate remaining entries in |values| by extracting elements
406 // from |base|. Only make a new extractvalue instruction if we can't
407 // make a constant or undef. New instructions are inserted before
408 // the insertion we plan to remove.
409 for (unsigned i = 0; i < num_members; ++i) {
410 if (!members[i]) {
411 Type *memberTy = resultTy->getTypeAtIndex(i);
412 if (isa<UndefValue>(base)) {
413 members[i] = UndefValue::get(memberTy);
414 } else if (const auto *caz = dyn_cast<ConstantAggregateZero>(base)) {
415 members[i] = caz->getElementValue(i);
416 } else if (const auto *ca = dyn_cast<ConstantAggregate>(base)) {
417 members[i] = ca->getOperand(i);
418 } else {
419 members[i] = ExtractValueInst::Create(base, {i}, "", insertion);
420 }
421 }
422 }
423
424 // Create the call. It's dominated by any extractions we've just
425 // created.
Alan Bakerabc935e2018-09-06 11:33:53 -0400426 Function *construct_fn = GetConstructFunction(M, resultTy);
David Neto8e39bd12017-11-14 21:08:12 -0500427 auto *call = CallInst::Create(construct_fn, members, "", insertion);
428
429 // Disconnect this insertion. We'll remove it later.
430 insertion->replaceAllUsesWith(call);
431
432 // Trace backward through the chain, removing uses and deleting where
433 // we can. Stop at the first element that has a remaining use.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400434 for (auto *chainElem : chain) {
David Neto8e39bd12017-11-14 21:08:12 -0500435 if (chainElem->hasNUsesOrMore(1)) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400436 unsigned &use_count =
437 num_uses[insertions.idFor(cast<InsertValueInst>(chainElem))];
David Neto8e39bd12017-11-14 21:08:12 -0500438 assert(use_count > 0);
439 --use_count;
440 if (use_count == 0) {
441 NextRoundWorkList.push_back(chainElem);
442 }
443 break;
444 } else {
445 chainElem->eraseFromParent();
446 }
447 }
448 }
449 WorkList = std::move(NextRoundWorkList);
450 }
451
452 return Changed;
453}