blob: 158baaed8cb1d2b677507dca3691548c8e3ee434 [file] [log] [blame]
alan-bakere711c762020-05-20 17:56:59 -04001// Copyright 2018 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <utility>
16
17#include "llvm/ADT/UniqueVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/IR/Module.h"
21#include "llvm/Pass.h"
22#include "llvm/Support/raw_ostream.h"
23
24#include "Passes.h"
25
26using namespace llvm;
27
28#define DEBUG_TYPE "UndoTruncateToOddInteger"
29
30namespace {
31struct UndoTruncateToOddIntegerPass : public ModulePass {
32 static char ID;
33 UndoTruncateToOddIntegerPass() : ModulePass(ID) {}
34
35 bool runOnModule(Module &M) override;
36
37private:
38 // Maps a value to its zero-extended value. This is the memoization table for
39 // ZeroExtend.
40 DenseMap<Value *, Value *> extended_value_;
41
42 // Returns a 32-bit zero-extended version of the given argument.
43 // Candidates for erasure are added to |zombies_|, before their feeding
44 // values are created.
45 // TODO(dneto): Handle 64 bit case as well, but separately.
46 Value *ZeroExtend(Value *v, uint32_t desired_bit_width) {
Marco Antognini7e338402021-03-15 12:48:37 +000047 unsigned bit_width = 0;
alan-bakere711c762020-05-20 17:56:59 -040048 if (v->getType()->isIntegerTy())
49 bit_width = v->getType()->getIntegerBitWidth();
50 if (bit_width > 32) {
51 errs() << "Unhandled bit width for " << *v << "\n";
52 llvm_unreachable("Unhandled bit width");
53 }
54
55 auto where = extended_value_.find(v);
56 if (where != extended_value_.end()) {
57 return where->second;
58 }
59
60 // This base case makes for easier recursion.
61 if (bit_width == desired_bit_width && !isa<ZExtInst>(v))
62 return v;
63
64 auto desired_int_ty = IntegerType::get(v->getContext(), desired_bit_width);
65 if (auto *ci = dyn_cast<ConstantInt>(v)) {
66 return ConstantInt::get(desired_int_ty, uint32_t(ci->getZExtValue()));
67 }
68 Value *result = nullptr;
69 if (auto *trunc = dyn_cast<TruncInst>(v)) {
70 Value *tmp = nullptr;
71 auto input = trunc->getOperand(0);
72 uint32_t input_bit_width = input->getType()->getIntegerBitWidth();
73 if (input_bit_width > desired_bit_width) {
74 tmp = new TruncInst(input, desired_int_ty, "", trunc);
75 } else if (input_bit_width == desired_bit_width) {
76 tmp = input;
77 } else if (input_bit_width > bit_width) {
78 tmp = new ZExtInst(input, desired_int_ty, "", trunc);
79 } else {
80 tmp = ZeroExtend(input, desired_bit_width);
81 }
82
83 // Now, and the extended version to keep the range of the output
84 // restricted to the original bit width.
85 result = BinaryOperator::Create(
86 Instruction::And, tmp,
87 ConstantInt::get(
88 desired_int_ty,
89 (uint32_t)APInt::getAllOnesValue(bit_width).getZExtValue()),
90 "", trunc);
91 } else if (auto *zext = dyn_cast<ZExtInst>(v)) {
92 auto tmp = ZeroExtend(zext->getOperand(0), desired_bit_width);
93 uint32_t zext_width = zext->getType()->getIntegerBitWidth();
94 //
95 if (zext_width < desired_bit_width) {
96 result = new TruncInst(tmp, zext->getType(), "", zext);
97 } else if (zext_width > desired_bit_width) {
98 zext->setOperand(0, tmp);
99 result = zext;
100 } else {
101 result = tmp;
102 }
103 } else if (auto *phi = dyn_cast<PHINode>(v)) {
104 const auto num_branches = phi->getNumIncomingValues();
105 PHINode *new_phi = PHINode::Create(desired_int_ty, num_branches, "", phi);
106 for (unsigned i = 0; i < num_branches; i++) {
107 new_phi->addIncoming(
108 ZeroExtend(phi->getIncomingValue(i), desired_bit_width),
109 phi->getIncomingBlock(i));
110 }
111 result = new_phi;
112 } else if (auto *sel = dyn_cast<SelectInst>(v)) {
113 auto *ext_true = ZeroExtend(sel->getTrueValue(), desired_bit_width);
114 auto *ext_false = ZeroExtend(sel->getFalseValue(), desired_bit_width);
115 result =
116 SelectInst::Create(sel->getCondition(), ext_true, ext_false, "", sel);
117 } else if (auto *binop = dyn_cast<BinaryOperator>(v)) {
118 // White-list binary operators that are ok to transform.
119 if (binop->getOpcode() == Instruction::Add ||
120 binop->getOpcode() == Instruction::Sub ||
121 binop->getOpcode() == Instruction::Mul ||
122 binop->getOpcode() == Instruction::And ||
123 binop->getOpcode() == Instruction::Or ||
124 binop->getOpcode() == Instruction::Xor) {
125 auto *op1 = ZeroExtend(binop->getOperand(0), desired_bit_width);
126 auto *op2 = ZeroExtend(binop->getOperand(1), desired_bit_width);
127 result =
128 BinaryOperator::Create(binop->getOpcode(), op1, op2, "", binop);
129 if (binop->getOpcode() == Instruction::Add ||
130 binop->getOpcode() == Instruction::Sub ||
131 binop->getOpcode() == Instruction::Mul) {
132 // Add an extra masking for add and sub in case of integer wrapping.
133 result = BinaryOperator::Create(
134 Instruction::And, result,
135 ConstantInt::get(
136 desired_int_ty,
137 (uint32_t)APInt::getAllOnesValue(bit_width).getZExtValue()),
138 "", binop);
139 }
140 } else {
141 errs() << "Unhandled instruction feeding switch " << *v << "\n";
142 llvm_unreachable("Unhandled instruction feeding switch!");
143 }
144 } else if (auto SI = dyn_cast<SwitchInst>(v)) {
145 auto extended_cond = ZeroExtend(SI->getCondition(), desired_bit_width);
146 if (extended_cond && extended_cond != SI->getCondition()) {
147 SI->setCondition(extended_cond);
148 for (auto Cases : SI->cases()) {
149 // The original value of the case.
150 auto V = Cases.getCaseValue()->getZExtValue();
151
152 // A new value for the case with the correct type.
153 auto CI = dyn_cast<ConstantInt>(ConstantInt::get(desired_int_ty, V));
154
155 // And we replace the old value.
156 Cases.setValue(CI);
157 }
158 }
159 } else if (auto inst = dyn_cast<Instruction>(v)) {
160 for (unsigned i = 0; i < inst->getNumOperands(); ++i) {
161 auto extended_op = ZeroExtend(inst->getOperand(i), desired_bit_width);
162 if (extended_op && extended_op != inst->getOperand(i))
163 inst->setOperand(i, extended_op);
164 }
165 } else {
166 errs() << "Unhandled instruction " << *v << "\n";
167 llvm_unreachable("Unhandled instruction!");
168 }
169
170 // If the instruction was replaced, mark it as a zombie.
171 if (auto *inst = dyn_cast<Instruction>(v)) {
172 if (result && result != inst)
173 zombies_.insert(inst);
174 }
175
176 if (result)
177 extended_value_[v] = result;
178 return result;
179 }
180
181 // The list of things that might be dead.
182 UniqueVector<Instruction *> zombies_;
183};
184} // namespace
185
// Pass identity and legacy pass-manager registration. The pass is registered
// under the command-line name "UndoTruncateToOddInteger". Note: despite its
// description string, the pass rewrites every scalar trunc to an odd integer
// width, not only those feeding switch conditions (see runOnModule).
char UndoTruncateToOddIntegerPass::ID = 0;
INITIALIZE_PASS(UndoTruncateToOddIntegerPass, "UndoTruncateToOddInteger",
                "Undo Truncated Switch Condition Pass", false, false)
189
190namespace clspv {
191ModulePass *createUndoTruncateToOddIntegerPass() {
192 return new UndoTruncateToOddIntegerPass();
193}
194} // namespace clspv
195
196bool UndoTruncateToOddIntegerPass::runOnModule(Module &M) {
197 bool Changed = false;
198
199 SmallVector<std::pair<Instruction *, uint32_t>, 8> WorkList;
200 for (Function &F : M) {
201 for (BasicBlock &BB : F) {
202 for (Instruction &I : BB) {
203 if (auto trunc = dyn_cast<TruncInst>(&I)) {
204 if (trunc->getType()->isVectorTy())
205 continue;
206 auto desired_bit_width =
207 trunc->getOperand(0)->getType()->getIntegerBitWidth();
208 switch (trunc->getType()->getIntegerBitWidth()) {
209 default:
210 WorkList.push_back(std::make_pair(
211 trunc, static_cast<uint32_t>(PowerOf2Ceil(desired_bit_width))));
212 break;
213 case 1: // i1 is a bool.
214 case 8:
215 case 16:
216 case 32:
217 case 64:
218 break;
219 }
220 }
221 }
222 }
223 }
224
225 zombies_.reset();
226
227 while (!WorkList.empty()) {
228 auto inst = WorkList.back().first;
229 auto desired_bit_width = WorkList.back().second;
230 WorkList.pop_back();
231
232 auto extended = ZeroExtend(inst, desired_bit_width);
233 if (extended && extended != inst) {
234 Changed = true;
235
236 for (auto user : inst->users()) {
237 if (auto user_inst = dyn_cast<Instruction>(user)) {
238 WorkList.push_back(std::make_pair(user_inst, desired_bit_width));
239 }
240 }
241 }
242 }
243
244 // Remove the zombies if we can. We expect to. We've ordered zombies in
245 // reverse.
246 for (int i = zombies_.size(); i >= 1; --i) {
247 auto zombie = zombies_[i];
248 if (!zombie->hasNUsesOrMore(1))
249 zombie->eraseFromParent();
250 }
251
252 return Changed;
253}