blob: 136853c7b2a5ee61a2ccbc82f033328b3ce61389 [file] [log] [blame]
David Netoab03f432017-11-03 17:00:44 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <memory>
#include <string>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/raw_ostream.h>
26
27using namespace llvm;
28using std::string;
29
30#define DEBUG_TYPE "collapsecompositeinserts"
31
32
33namespace {
34
// Name prefix for the generated composite-construct functions. A numeric
// suffix is appended per distinct result type. Declared as a constant array
// so neither the pointer nor the characters can be modified at runtime.
constexpr char kCompositeConstructFunctionPrefix[] = "clspv.composite_construct.";
36
37class RewriteInsertsPass : public ModulePass {
38 public:
39 static char ID;
40 RewriteInsertsPass() : ModulePass(ID) {}
41
42 bool runOnModule(Module &M) override;
43
44 private:
45 using InsertionVector = SmallVector<InsertValueInst *, 4>;
46
47 // If this is the tail of a chain of InsertValueInst instructions
48 // that covers the entire composite, then return a small vector
49 // containing the insertion instructions, in member order.
50 // Otherwise returns nullptr. Only handle insertions into structs,
51 // not into arrays.
52 InsertionVector *CompleteInsertionChain(InsertValueInst *iv) {
53 if (iv->getNumIndices() == 1) {
54 if (auto *structTy = dyn_cast<StructType>(iv->getType())) {
55 auto numElems = structTy->getNumElements();
56 // Only handle single-index insertions.
57 unsigned index = iv->getIndices()[0];
58 if (index + 1u != numElems) {
59 // Not the last in the chain.
60 return nullptr;
61 }
62 InsertionVector candidates(numElems, nullptr);
63 for (unsigned i = index;
64 iv->getNumIndices() == 1 && i == iv->getIndices()[0]; --i) {
65 // iv inserts the i'th member
66 candidates[i] = iv;
67
68 if (i == 0) {
69 // We're done!
70 return new InsertionVector(candidates);
71 }
72
73 if (InsertValueInst *agg =
74 dyn_cast<InsertValueInst>(iv->getAggregateOperand())) {
75 iv = agg;
76 } else {
77 // The chain is broken.
78 break;
79 }
80 }
81 }
82 }
83 return nullptr;
84 }
85
86 // Return the name for the wrap function for the given type.
87 string &WrapFunctionNameForType(Type *type) {
88 auto where = function_for_type_.find(type);
89 if (where == function_for_type_.end()) {
90 // Insert it.
91 auto &result = function_for_type_[type] =
92 string(kCompositeConstructFunctionPrefix) +
93 std::to_string(function_for_type_.size());
94 return result;
95 } else {
96 return where->second;
97 }
98 }
99
100 // Maps a loaded type to the name of the wrap function for that type.
101 DenseMap<Type *, string> function_for_type_;
102};
103} // namespace
104
105char RewriteInsertsPass::ID = 0;
106static RegisterPass<RewriteInsertsPass>
107 X("RewriteInserts",
108 "Rewrite chains of insertvalue to as composite-construction");
109
110namespace clspv {
111llvm::ModulePass *createRewriteInsertsPass() {
112 return new RewriteInsertsPass();
113}
114} // namespace clspv
115
116bool RewriteInsertsPass::runOnModule(Module &M) {
117 bool Changed = false;
118
119 SmallVector<InsertionVector *, 16> WorkList;
120 for (Function &F : M) {
121 for (BasicBlock &BB : F) {
122 for (Instruction &I : BB) {
123 if (InsertValueInst *iv = dyn_cast<InsertValueInst>(&I)) {
124 if (InsertionVector *insertions = CompleteInsertionChain(iv)) {
125 WorkList.push_back(insertions);
126 }
127 }
128 }
129 }
130 }
131
132 if (WorkList.size() == 0) {
133 return Changed;
134 }
135
136 for (InsertionVector *insertions : WorkList) {
137 Changed = true;
138
139 // Gather the member values and types.
140 SmallVector<Value*, 4> values;
141 SmallVector<Type*, 4> types;
142 for (InsertValueInst* insert : *insertions) {
143 Value* value = insert->getInsertedValueOperand();
144 values.push_back(value);
145 types.push_back(value->getType());
146 }
147
148 Type* resultTy = insertions->back()->getType();
149
150 // Get or create the composite construct function definition.
151 const string& fn_name = WrapFunctionNameForType(resultTy);
152 Function* fn = M.getFunction(fn_name);
153 if (!fn) {
154 // Make the function.
155 FunctionType* fnTy = FunctionType::get(resultTy, types, false);
156 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
157 fn = cast<Function>(fn_constant);
158 fn->addFnAttr(Attribute::ReadOnly);
159 fn->addFnAttr(Attribute::ReadNone);
160 }
161
162 // Replace the chain.
163 auto call = CallInst::Create(fn, values);
164 call->insertAfter(insertions->back());
165 insertions->back()->replaceAllUsesWith(call);
166
167 // Remove the insertions if we can. Go from the tail back to
168 // the head, since the tail uses the previous insertion, etc.
169 for (auto iter = insertions->rbegin(), end = insertions->rend();
170 iter != end; ++iter) {
171 InsertValueInst *insertion = *iter;
172 if (!insertion->hasNUsesOrMore(1)) {
173 insertion->eraseFromParent();
174 }
175 }
176
177 delete insertions;
178 }
179
180 return Changed;
181}