// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/UniqueVector.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include "clspv/Option.h"

#include "Passes.h"

David Netoab03f432017-11-03 17:00:44 -040034using namespace llvm;
35using std::string;
36
David Neto8e39bd12017-11-14 21:08:12 -050037#define DEBUG_TYPE "rewriteinserts"
David Netoab03f432017-11-03 17:00:44 -040038
David Netoab03f432017-11-03 17:00:44 -040039namespace {
40
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040041const char *kCompositeConstructFunctionPrefix = "clspv.composite_construct.";
David Netoab03f432017-11-03 17:00:44 -040042
43class RewriteInsertsPass : public ModulePass {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040044public:
David Netoab03f432017-11-03 17:00:44 -040045 static char ID;
46 RewriteInsertsPass() : ModulePass(ID) {}
47
48 bool runOnModule(Module &M) override;
49
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040050private:
51 using InsertionVector = SmallVector<Instruction *, 4>;
David Netoab03f432017-11-03 17:00:44 -040052
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040053 // Replaces chains of insertions that cover the entire value.
54 // Such a change always reduces the number of instructions, so
55 // we always perform these. Returns true if the module was modified.
56 bool ReplaceCompleteInsertionChains(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050057
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040058 // Replaces all InsertValue instructions, even if they aren't part
59 // of a complete insetion chain. Returns true if the module was modified.
60 bool ReplacePartialInsertions(Module &M);
David Neto8e39bd12017-11-14 21:08:12 -050061
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040062 // Load |values| and |chain| with the members of the struct value produced
63 // by a chain of InsertValue instructions ending with |iv|, and following
64 // the aggregate operand. Return the start of the chain: the aggregate
65 // value which is not an InsertValue instruction, or an InsertValue
66 // instruction which inserts a component that is replaced later in the
67 // chain. The |values| vector will match the order of struct members and
68 // is initialized to all nullptr members. The |chain| vector will list
69 // the chain of InsertValue instructions, listed in the order we discover
70 // them, e.g. begining with |iv|.
71 Value *LoadValuesEndingWithInsertion(InsertValueInst *iv,
72 std::vector<Value *> *values,
73 InsertionVector *chain) {
74 auto *structTy = dyn_cast<StructType>(iv->getType());
75 assert(structTy);
76 const auto numElems = structTy->getNumElements();
David Neto8e39bd12017-11-14 21:08:12 -050077
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040078 // Walk backward from the tail to an instruction we don't want to
79 // replace.
80 Value *frontier = iv;
81 while (auto *insertion = dyn_cast<InsertValueInst>(frontier)) {
82 chain->push_back(insertion);
83 // Only handle single-index insertions.
84 if (insertion->getNumIndices() == 1) {
85 // Try to replace this one.
David Neto8e39bd12017-11-14 21:08:12 -050086
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040087 unsigned index = insertion->getIndices()[0];
88 assert(index < numElems);
89 if ((*values)[index] != nullptr) {
90 // We already have a value for this slot. Stop now.
91 break;
92 }
93 (*values)[index] = insertion->getInsertedValueOperand();
94 frontier = insertion->getAggregateOperand();
95 } else {
96 break;
97 }
98 }
99 return frontier;
100 }
David Neto8e39bd12017-11-14 21:08:12 -0500101
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400102 // Returns the number of elements in the struct or array.
103 unsigned GetNumElements(Type *type) {
104 // CompositeType doesn't implement getNumElements(), but its inheritors
105 // do.
106 if (auto *struct_ty = dyn_cast<StructType>(type)) {
107 return struct_ty->getNumElements();
108 } else if (auto *seq_ty = dyn_cast<SequentialType>(type)) {
109 return seq_ty->getNumElements();
110 }
111 return 0;
112 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400113
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400114 // If this is the tail of a chain of InsertValueInst instructions
115 // that covers the entire composite, then return a small vector
116 // containing the insertion instructions, in member order.
117 // Otherwise returns nullptr.
118 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
119 if (iv->getNumIndices() == 1) {
120 auto numElems = GetNumElements(iv->getType());
121 if (numElems != 0) {
122 // Only handle single-index insertions.
123 unsigned index = iv->getIndices()[0];
124 if (index + 1u != numElems) {
125 // Not the last in the chain.
126 return nullptr;
127 }
128 InsertionVector candidates(numElems, nullptr);
129 for (unsigned i = index;
130 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
131 // iv inserts the i'th member
132 candidates[i] = iv;
Alan Bakerabc935e2018-09-06 11:33:53 -0400133
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400134 if (i == 0) {
135 // We're done!
136 return new InsertionVector(candidates);
137 }
David Netoab03f432017-11-03 17:00:44 -0400138
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400139 if (InsertValueInst *agg =
140 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
141 iv = agg;
142 } else {
143 // The chain is broken.
144 break;
145 }
146 }
147 }
148 }
149 return nullptr;
150 }
David Netoab03f432017-11-03 17:00:44 -0400151
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400152 // If this is the tail of a chain of InsertElementInst instructions
153 // that covers the entire vector, then return a small vector
154 // containing the insertion instructions, in member order.
155 // Otherwise returns nullptr. Only handle insertions into vectors.
156 InsertionVector *CompleteInsertionChain(InsertElementInst *ie) {
157 // Don't handle i8 vectors. Only <4 x i8> is supported and it is
158 // translated as i32. Only handle single-index insertions.
159 if (auto vec_ty = dyn_cast<VectorType>(ie->getType())) {
160 if (vec_ty->getVectorElementType() == Type::getInt8Ty(ie->getContext())) {
161 return nullptr;
162 }
163 }
David Netoab03f432017-11-03 17:00:44 -0400164
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400165 // Only handle single-index insertions.
166 if (ie->getNumOperands() == 3) {
167 auto numElems = GetNumElements(ie->getType());
168 if (numElems != 0) {
169 if (auto *const_value = dyn_cast<ConstantInt>(ie->getOperand(2))) {
170 uint64_t index = const_value->getZExtValue();
171 if (index + 1u != numElems) {
172 // Not the last in the chain.
173 return nullptr;
174 }
175 InsertionVector candidates(numElems, nullptr);
176 Value *value = ie;
177 uint64_t i = index;
178 while (auto *insert = dyn_cast<InsertElementInst>(value)) {
179 if (insert->getNumOperands() != 3)
180 break;
181 if (auto *index_const =
182 dyn_cast<ConstantInt>(insert->getOperand(2))) {
183 if (i != index_const->getZExtValue())
184 break;
Alan Bakerabc935e2018-09-06 11:33:53 -0400185
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400186 candidates[i] = insert;
187 if (i == 0) {
188 // We're done!
189 return new InsertionVector(candidates);
190 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400191
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400192 value = insert->getOperand(0);
193 --i;
194 } else {
195 break;
196 }
197 }
198 } else {
199 return nullptr;
200 }
201 }
202 }
203 return nullptr;
204 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400205
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400206 // Return the name for the wrap function for the given type.
207 string &WrapFunctionNameForType(Type *type) {
208 auto where = function_for_type_.find(type);
209 if (where == function_for_type_.end()) {
210 // Insert it.
211 auto &result = function_for_type_[type] =
212 string(kCompositeConstructFunctionPrefix) +
213 std::to_string(function_for_type_.size());
214 return result;
215 } else {
216 return where->second;
217 }
218 }
David Neto8e39bd12017-11-14 21:08:12 -0500219
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400220 // Get or create the composite construct function definition.
alan-baker077517b2020-03-19 13:52:12 -0400221 Function *GetConstructFunction(Module &M, Type *constructed_type) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400222 // Get or create the composite construct function definition.
223 const string &fn_name = WrapFunctionNameForType(constructed_type);
224 Function *fn = M.getFunction(fn_name);
225 if (!fn) {
226 // Make the function.
227 SmallVector<Type *, 16> elements;
228 unsigned num_elements = GetNumElements(constructed_type);
alan-baker077517b2020-03-19 13:52:12 -0400229 if (auto struct_ty = dyn_cast<StructType>(constructed_type)) {
230 for (unsigned i = 0; i != num_elements; ++i)
231 elements.push_back(struct_ty->getTypeAtIndex(i));
232 } else {
233 elements.resize(
234 num_elements,
235 cast<SequentialType>(constructed_type)->getElementType());
236 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400237 FunctionType *fnTy = FunctionType::get(constructed_type, elements, false);
238 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
239 fn = cast<Function>(fn_constant.getCallee());
240 fn->addFnAttr(Attribute::ReadOnly);
241 }
242 return fn;
243 }
David Netoab03f432017-11-03 17:00:44 -0400244
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400245 // Maps a loaded type to the name of the wrap function for that type.
246 DenseMap<Type *, string> function_for_type_;
David Netoab03f432017-11-03 17:00:44 -0400247};
248} // namespace
249
250char RewriteInsertsPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -0400251INITIALIZE_PASS(RewriteInsertsPass, "RewriteInserts",
252 "Rewrite chains of insertvalue to as composite-construction",
253 false, false)
David Netoab03f432017-11-03 17:00:44 -0400254
255namespace clspv {
256llvm::ModulePass *createRewriteInsertsPass() {
257 return new RewriteInsertsPass();
258}
259} // namespace clspv
260
261bool RewriteInsertsPass::runOnModule(Module &M) {
David Neto8e39bd12017-11-14 21:08:12 -0500262 bool Changed = ReplaceCompleteInsertionChains(M);
263
David Neto482550a2018-03-24 05:21:07 -0700264 if (clspv::Option::HackInserts()) {
David Neto8e39bd12017-11-14 21:08:12 -0500265 Changed |= ReplacePartialInsertions(M);
266 }
267
268 return Changed;
269}
270
271bool RewriteInsertsPass::ReplaceCompleteInsertionChains(Module &M) {
David Netoab03f432017-11-03 17:00:44 -0400272 bool Changed = false;
273
274 SmallVector<InsertionVector *, 16> WorkList;
275 for (Function &F : M) {
276 for (BasicBlock &BB : F) {
277 for (Instruction &I : BB) {
278 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
279 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
280 WorkList.push_back(insertions);
281 }
Alan Bakerabc935e2018-09-06 11:33:53 -0400282 } else if (InsertElementInst *ie = dyn_cast<InsertElementInst>(&I)) {
283 if (InsertionVector *insertions = CompleteInsertionChain(ie)) {
284 WorkList.push_back(insertions);
285 }
David Netoab03f432017-11-03 17:00:44 -0400286 }
287 }
288 }
289 }
290
291 if (WorkList.size() == 0) {
292 return Changed;
293 }
294
295 for (InsertionVector *insertions : WorkList) {
296 Changed = true;
297
298 // Gather the member values and types.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400299 SmallVector<Value *, 4> values;
300 for (Instruction *inst : *insertions) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400301 if (auto *insert_value = dyn_cast<InsertValueInst>(inst)) {
302 values.push_back(insert_value->getInsertedValueOperand());
303 } else if (auto *insert_element = dyn_cast<InsertElementInst>(inst)) {
304 values.push_back(insert_element->getOperand(1));
305 } else {
306 llvm_unreachable("Unhandled insertion instruction");
307 }
David Netoab03f432017-11-03 17:00:44 -0400308 }
309
alan-baker077517b2020-03-19 13:52:12 -0400310 auto *resultTy = insertions->back()->getType();
Alan Bakerabc935e2018-09-06 11:33:53 -0400311 Function *fn = GetConstructFunction(M, resultTy);
David Netoab03f432017-11-03 17:00:44 -0400312
313 // Replace the chain.
314 auto call = CallInst::Create(fn, values);
315 call->insertAfter(insertions->back());
316 insertions->back()->replaceAllUsesWith(call);
317
318 // Remove the insertions if we can. Go from the tail back to
319 // the head, since the tail uses the previous insertion, etc.
320 for (auto iter = insertions->rbegin(), end = insertions->rend();
321 iter != end; ++iter) {
Alan Bakerabc935e2018-09-06 11:33:53 -0400322 Instruction *insertion = *iter;
David Netoab03f432017-11-03 17:00:44 -0400323 if (!insertion->hasNUsesOrMore(1)) {
324 insertion->eraseFromParent();
325 }
326 }
327
328 delete insertions;
329 }
330
331 return Changed;
332}
David Neto8e39bd12017-11-14 21:08:12 -0500333
334bool RewriteInsertsPass::ReplacePartialInsertions(Module &M) {
335 bool Changed = false;
336
337 // First find candidates. Collect all InsertValue instructions
338 // into struct type, but track their interdependencies. To minimize
339 // the number of new instructions, generate a construction for each
340 // tail of an insertion chain.
341
342 UniqueVector<InsertValueInst *> insertions;
343 for (Function &F : M) {
344 for (BasicBlock &BB : F) {
345 for (Instruction &I : BB) {
346 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
347 if (iv->getType()->isStructTy()) {
348 insertions.insert(iv);
349 }
350 }
351 }
352 }
353 }
354
355 // Now count how many times each InsertValue is used by another InsertValue.
356 // The |num_uses| vector is indexed by the unique id that |insertions|
357 // assigns to it.
358 std::vector<unsigned> num_uses(insertions.size() + 1);
359 // Count from the user's perspective.
360 for (InsertValueInst *insertion : insertions) {
361 if (auto *agg =
362 dyn_cast<InsertValueInst>(insertion->getAggregateOperand())) {
363 ++(num_uses[insertions.idFor(agg)]);
364 }
365 }
366
367 // Proceed in rounds. Each round rewrites any chains ending with an
368 // insertion that is not used by another insertion.
369
370 // Get the first list of insertion tails.
371 InsertionVector WorkList;
372 for (InsertValueInst *insertion : insertions) {
373 if (num_uses[insertions.idFor(insertion)] == 0) {
374 WorkList.push_back(insertion);
375 }
376 }
377
378 // This records insertions in the order they should be removed.
379 // In this list, an insertion preceds any insertions it uses.
380 // (This is post-dominance order.)
381 InsertionVector ordered_candidates_for_removal;
382
383 // Proceed in rounds.
384 while (WorkList.size()) {
385 Changed = true;
386
387 // Record the list of tails for the next round.
388 InsertionVector NextRoundWorkList;
389
Alan Bakerabc935e2018-09-06 11:33:53 -0400390 for (Instruction *inst : WorkList) {
391 InsertValueInst *insertion = cast<InsertValueInst>(inst);
David Neto8e39bd12017-11-14 21:08:12 -0500392 // Rewrite |insertion|.
393
394 StructType *resultTy = cast<StructType>(insertion->getType());
395
396 const unsigned num_members = resultTy->getNumElements();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400397 std::vector<Value *> members(num_members, nullptr);
David Neto8e39bd12017-11-14 21:08:12 -0500398 InsertionVector chain;
399 // Gather the member values. Walk backward from the insertion.
400 Value *base = LoadValuesEndingWithInsertion(insertion, &members, &chain);
401
402 // Populate remaining entries in |values| by extracting elements
403 // from |base|. Only make a new extractvalue instruction if we can't
404 // make a constant or undef. New instructions are inserted before
405 // the insertion we plan to remove.
406 for (unsigned i = 0; i < num_members; ++i) {
407 if (!members[i]) {
408 Type *memberTy = resultTy->getTypeAtIndex(i);
409 if (isa<UndefValue>(base)) {
410 members[i] = UndefValue::get(memberTy);
411 } else if (const auto *caz = dyn_cast<ConstantAggregateZero>(base)) {
412 members[i] = caz->getElementValue(i);
413 } else if (const auto *ca = dyn_cast<ConstantAggregate>(base)) {
414 members[i] = ca->getOperand(i);
415 } else {
416 members[i] = ExtractValueInst::Create(base, {i}, "", insertion);
417 }
418 }
419 }
420
421 // Create the call. It's dominated by any extractions we've just
422 // created.
Alan Bakerabc935e2018-09-06 11:33:53 -0400423 Function *construct_fn = GetConstructFunction(M, resultTy);
David Neto8e39bd12017-11-14 21:08:12 -0500424 auto *call = CallInst::Create(construct_fn, members, "", insertion);
425
426 // Disconnect this insertion. We'll remove it later.
427 insertion->replaceAllUsesWith(call);
428
429 // Trace backward through the chain, removing uses and deleting where
430 // we can. Stop at the first element that has a remaining use.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400431 for (auto *chainElem : chain) {
David Neto8e39bd12017-11-14 21:08:12 -0500432 if (chainElem->hasNUsesOrMore(1)) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400433 unsigned &use_count =
434 num_uses[insertions.idFor(cast<InsertValueInst>(chainElem))];
David Neto8e39bd12017-11-14 21:08:12 -0500435 assert(use_count > 0);
436 --use_count;
437 if (use_count == 0) {
438 NextRoundWorkList.push_back(chainElem);
439 }
440 break;
441 } else {
442 chainElem->eraseFromParent();
443 }
444 }
445 }
446 WorkList = std::move(NextRoundWorkList);
447 }
448
449 return Changed;
450}