blob: 844637c47a272220940284f8e504ae218aa3a3a6 [file] [log] [blame]
// Copyright 2020 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <vector>
16
17#include "llvm/ADT/UniqueVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/Module.h"
alan-baker90749232020-04-07 20:36:20 -040022#include "llvm/IR/Operator.h"
alan-baker13568382020-04-02 17:29:27 -040023#include "llvm/Pass.h"
24
25#include "Passes.h"
26
27#define DEBUG_TYPE "undoinstcombine"
28
29using namespace llvm;
30
31namespace {
32class UndoInstCombinePass : public ModulePass {
33public:
34 static char ID;
35 UndoInstCombinePass() : ModulePass(ID) {}
36
37 bool runOnModule(Module &M) override;
38
39private:
40 bool runOnFunction(Function &F);
41
42 // Undoes wide vector casts that are used in an extract, for example:
43 // %cast = bitcast <4 x i32> %src to <16 x i8>
44 // %extract = extractelement <16 x i8> %cast, i32 4
45 //
46 // With:
47 // %extract = extractelement <4 x i32> %src, i32 1
48 // %trunc = trunc i32 %extract to i8
alan-baker90749232020-04-07 20:36:20 -040049 //
50 // Also handles casts that get loaded, for example:
51 // %cast = bitcast <3 x i32>* %src to <6 x i16>*
52 // %load = load <6 x i16>, <6 x i16>* %cast
53 // %extract = extractelement <6 x i16> %load, i32 0
54 //
55 // With:
56 // %load = load <3 x i32>, <3 x i32>* %src
57 // %extract = extractelement <3 x i32> %load, i32 0
58 // %trunc = trunc i32 %extract to i16
alan-baker13568382020-04-02 17:29:27 -040059 bool UndoWideVectorExtractCast(Instruction *inst);
60
61 // Undoes wide vector casts that are used in a shuffle, for example:
62 // %cast = bitcast <4 x i32> %src to <16 x i8>
63 // %s = shufflevector <16 x i8> %cast, <16 x i8> undef,
64 // <2 x i8> <i32 4, i32 8>
65 //
66 // With:
67 // %extract0 = <4 x i32> %src, i32 1
68 // %trunc0 = trunc i32 %extract0 to i8
69 // %insert0 = insertelement <2 x i8> zeroinitializer, i8 %trunc0, i32 0
70 // %extract1 = <4 x i32> %src, i32 2
71 // %trunc1 = trunc i32 %extract1 to i8
72 // %insert1 = insertelement <2 x i8> %insert0, i8 %trunc1, i32 1
alan-baker90749232020-04-07 20:36:20 -040073 //
74 // Also handles shuffles casted through a load, for example:
75 // %cast = bitcast <3 x i32>* %src to <6 x i16>
76 // %load = load <6 x i16>* %cast
77 // %shuffle = shufflevector <6 x i16> %load, <6 x i16> undef,
78 // <2 x i32> <i32 2, i32 4>
79 //
80 // With:
81 // %load = load <3 x i32>, <3 x i32>* %src
82 // %ex0 = extractelement <3 x i32> %load, i32 1
83 // %trunc0 = trunc i32 %ex0 to i16
84 // %in0 = insertelement <2 x i16> zeroinitializer, i16 %trunc0, i32 0
85 // %ex1 = extractelement <3 x i32> %load, i32 2
86 // %trunc1 = trunc i32 %ex1 to i16
87 // %in1 = insertelement <2 x i16> %in0, i16 %trunc1, i32 1
alan-baker13568382020-04-02 17:29:27 -040088 bool UndoWideVectorShuffleCast(Instruction *inst);
89
alan-baker90749232020-04-07 20:36:20 -040090 UniqueVector<Value *> potentially_dead_;
alan-baker13568382020-04-02 17:29:27 -040091 std::vector<Instruction *> dead_;
92};
93} // namespace
94
95char UndoInstCombinePass::ID = 0;
96INITIALIZE_PASS(UndoInstCombinePass, "UndoInstCombine",
97 "Undo specific harmful instcombine transformations", false,
98 false)
99
100namespace clspv {
101ModulePass *createUndoInstCombinePass() { return new UndoInstCombinePass(); }
102} // namespace clspv
103
104bool UndoInstCombinePass::runOnModule(Module &M) {
105 bool changed = false;
106
107 for (auto &F : M) {
108 changed |= runOnFunction(F);
109 }
110
111 // Cleanup.
112 for (auto inst : dead_)
113 inst->eraseFromParent();
114
alan-baker90749232020-04-07 20:36:20 -0400115 for (auto val : potentially_dead_) {
116 if (auto inst = dyn_cast<Instruction>(val)) {
117 if (inst->user_empty())
118 inst->eraseFromParent();
119 } else if (auto cast = dyn_cast<BitCastOperator>(val)) {
120 if (auto constant = dyn_cast<Constant>(cast->getOperand(0)))
121 constant->removeDeadConstantUsers();
122 }
alan-baker13568382020-04-02 17:29:27 -0400123 }
124
125 return changed;
126}
127
128bool UndoInstCombinePass::runOnFunction(Function &F) {
129 bool changed = false;
130
131 for (auto &BB : F) {
132 for (auto &I : BB) {
133 changed |= UndoWideVectorExtractCast(&I);
134 changed |= UndoWideVectorShuffleCast(&I);
135 }
136 }
137
138 return changed;
139}
140
141bool UndoInstCombinePass::UndoWideVectorExtractCast(Instruction *inst) {
142 auto extract = dyn_cast<ExtractElementInst>(inst);
143 if (!extract)
144 return false;
145
146 auto vec_ty = extract->getVectorOperandType();
147 if (vec_ty->getElementCount().Min <= 4)
148 return false;
149
150 // Instcombine only transforms TruncInst (which operates on integers).
James Pricecf53df42020-04-20 14:41:24 -0400151 if (!vec_ty->getElementType()->isIntegerTy())
alan-baker13568382020-04-02 17:29:27 -0400152 return false;
153
154 auto const_idx = dyn_cast<ConstantInt>(extract->getIndexOperand());
155 if (!const_idx)
156 return false;
157
alan-baker90749232020-04-07 20:36:20 -0400158 auto load = dyn_cast<LoadInst>(extract->getVectorOperand());
159 auto cast = dyn_cast<BitCastOperator>(extract->getVectorOperand());
160 if (load) {
161 // If this is a laod, check for a cast on the pointer operand
162 cast = dyn_cast<BitCastOperator>(load->getPointerOperand());
163 }
164
alan-baker13568382020-04-02 17:29:27 -0400165 if (!cast)
166 return false;
167
168 auto src = cast->getOperand(0);
alan-baker90749232020-04-07 20:36:20 -0400169 VectorType *src_vec_ty = nullptr;
170 if (isa<PointerType>(src->getType()))
171 // In the load cast, go through the pointer first.
172 src_vec_ty = dyn_cast<VectorType>(src->getType()->getPointerElementType());
173 else
174 src_vec_ty = dyn_cast<VectorType>(src->getType());
175
alan-baker13568382020-04-02 17:29:27 -0400176 if (!src_vec_ty)
177 return false;
178
179 uint64_t src_elements = src_vec_ty->getElementCount().Min;
180 uint64_t dst_elements = vec_ty->getElementCount().Min;
181
182 if (dst_elements < src_elements)
183 return false;
184
185 uint64_t idx = const_idx->getZExtValue();
186 uint64_t ratio = dst_elements / src_elements;
187 uint64_t new_idx = idx / ratio;
188
189 // Instcombine should never have generated an odd index, so don't handle
190 // right now.
191 if (idx & 0x1)
192 return false;
193
194 // Create a truncate of an extract element.
195 IRBuilder<> builder(inst);
alan-baker90749232020-04-07 20:36:20 -0400196 Value *new_src = nullptr;
197 if (load) {
198 potentially_dead_.insert(load);
199 new_src = builder.CreateLoad(src);
200 src = new_src;
201 }
202 new_src = builder.CreateExtractElement(src, builder.getInt32(new_idx));
alan-baker13568382020-04-02 17:29:27 -0400203 auto trunc = builder.CreateTrunc(new_src, extract->getType());
204 extract->replaceAllUsesWith(trunc);
205 dead_.push_back(extract);
206 potentially_dead_.insert(cast);
207
208 return true;
209}
210
211bool UndoInstCombinePass::UndoWideVectorShuffleCast(Instruction *inst) {
212 auto shuffle = dyn_cast<ShuffleVectorInst>(inst);
213 if (!shuffle)
214 return false;
215
216 // Instcombine only transforms TruncInst (which operates on integers).
217 auto vec_ty = cast<VectorType>(shuffle->getType());
James Pricecf53df42020-04-20 14:41:24 -0400218 if (!vec_ty->getElementType()->isIntegerTy())
alan-baker13568382020-04-02 17:29:27 -0400219 return false;
220
221 auto in1 = shuffle->getOperand(0);
222 auto in1_vec_ty = cast<VectorType>(in1->getType());
223 if (in1_vec_ty->getElementCount().Min <= 4)
224 return false;
225
alan-baker90749232020-04-07 20:36:20 -0400226 auto in1_load = dyn_cast<LoadInst>(in1);
227 auto in1_cast = dyn_cast<BitCastOperator>(in1);
228 if (in1_load) {
229 // If this is a laod, check for a cast on the pointer operand
230 in1_cast = dyn_cast<BitCastOperator>(in1_load->getPointerOperand());
231 }
232
alan-baker13568382020-04-02 17:29:27 -0400233 if (!in1_cast)
234 return false;
235
236 // Instcombine only produces shuffles with an undef second input, so don't
237 // handle other cases for now.
238 if (!isa<UndefValue>(shuffle->getOperand(1)))
239 return false;
240
241 auto src = in1_cast->getOperand(0);
alan-baker90749232020-04-07 20:36:20 -0400242 VectorType *src_vec_ty = nullptr;
243 if (isa<PointerType>(src->getType()))
244 // In the load cast, go through the pointer first.
245 src_vec_ty = dyn_cast<VectorType>(src->getType()->getPointerElementType());
246 else
247 src_vec_ty = dyn_cast<VectorType>(src->getType());
248
alan-baker13568382020-04-02 17:29:27 -0400249 if (!src_vec_ty)
250 return false;
251
252 uint64_t src_elements = src_vec_ty->getElementCount().Min;
253 uint64_t dst_elements = in1_vec_ty->getElementCount().Min;
254
255 if (dst_elements < src_elements)
256 return false;
257
258 uint64_t ratio = dst_elements / src_elements;
James Pricecf53df42020-04-20 14:41:24 -0400259 auto dst_scalar_type = vec_ty->getElementType();
alan-baker13568382020-04-02 17:29:27 -0400260
261 SmallVector<int, 16> mask;
262 shuffle->getShuffleMask(mask);
263 for (auto i : mask) {
264 // Instcombine should not have generated odd indices, so don't handle them
265 // for now.
266 if ((i != UndefMaskElem) && (i & 0x1))
267 return false;
268 }
269
270 // For each index, create a truncate of an extract element and insert each
271 // into the result vector.
272 IRBuilder<> builder(inst);
273 Value *insert = nullptr;
alan-baker90749232020-04-07 20:36:20 -0400274 if (in1_load) {
275 potentially_dead_.insert(in1_load);
276 src = builder.CreateLoad(src);
277 }
278
alan-baker13568382020-04-02 17:29:27 -0400279 int i = 0;
280 for (auto idx : mask) {
281 if (idx == UndefMaskElem)
282 continue;
283
284 uint64_t new_idx = idx / ratio;
285 auto extract = builder.CreateExtractElement(src, builder.getInt32(new_idx));
286 auto trunc = builder.CreateTrunc(extract, dst_scalar_type);
287 Value *prev = insert ? insert : Constant::getNullValue(vec_ty);
288 insert = builder.CreateInsertElement(prev, trunc, builder.getInt32(i++));
289 }
290 if (!insert) {
291 insert = Constant::getNullValue(vec_ty);
292 }
293 shuffle->replaceAllUsesWith(insert);
294 dead_.push_back(shuffle);
295 potentially_dead_.insert(in1_cast);
296
297 return true;
298}