// Copyright 2020 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#include <vector>
16
17#include "llvm/ADT/UniqueVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/Module.h"
22#include "llvm/Pass.h"
23
24#include "Passes.h"
25
26#define DEBUG_TYPE "undoinstcombine"
27
28using namespace llvm;
29
30namespace {
31class UndoInstCombinePass : public ModulePass {
32public:
33 static char ID;
34 UndoInstCombinePass() : ModulePass(ID) {}
35
36 bool runOnModule(Module &M) override;
37
38private:
39 bool runOnFunction(Function &F);
40
41 // Undoes wide vector casts that are used in an extract, for example:
42 // %cast = bitcast <4 x i32> %src to <16 x i8>
43 // %extract = extractelement <16 x i8> %cast, i32 4
44 //
45 // With:
46 // %extract = extractelement <4 x i32> %src, i32 1
47 // %trunc = trunc i32 %extract to i8
48 bool UndoWideVectorExtractCast(Instruction *inst);
49
50 // Undoes wide vector casts that are used in a shuffle, for example:
51 // %cast = bitcast <4 x i32> %src to <16 x i8>
52 // %s = shufflevector <16 x i8> %cast, <16 x i8> undef,
53 // <2 x i8> <i32 4, i32 8>
54 //
55 // With:
56 // %extract0 = <4 x i32> %src, i32 1
57 // %trunc0 = trunc i32 %extract0 to i8
58 // %insert0 = insertelement <2 x i8> zeroinitializer, i8 %trunc0, i32 0
59 // %extract1 = <4 x i32> %src, i32 2
60 // %trunc1 = trunc i32 %extract1 to i8
61 // %insert1 = insertelement <2 x i8> %insert0, i8 %trunc1, i32 1
62 bool UndoWideVectorShuffleCast(Instruction *inst);
63
64 UniqueVector<Instruction *> potentially_dead_;
65 std::vector<Instruction *> dead_;
66};
67} // namespace
68
69char UndoInstCombinePass::ID = 0;
70INITIALIZE_PASS(UndoInstCombinePass, "UndoInstCombine",
71 "Undo specific harmful instcombine transformations", false,
72 false)
73
74namespace clspv {
75ModulePass *createUndoInstCombinePass() { return new UndoInstCombinePass(); }
76} // namespace clspv
77
78bool UndoInstCombinePass::runOnModule(Module &M) {
79 bool changed = false;
80
81 for (auto &F : M) {
82 changed |= runOnFunction(F);
83 }
84
85 // Cleanup.
86 for (auto inst : dead_)
87 inst->eraseFromParent();
88
89 for (auto inst : potentially_dead_) {
90 if (inst->user_empty())
91 inst->eraseFromParent();
92 }
93
94 return changed;
95}
96
97bool UndoInstCombinePass::runOnFunction(Function &F) {
98 bool changed = false;
99
100 for (auto &BB : F) {
101 for (auto &I : BB) {
102 changed |= UndoWideVectorExtractCast(&I);
103 changed |= UndoWideVectorShuffleCast(&I);
104 }
105 }
106
107 return changed;
108}
109
110bool UndoInstCombinePass::UndoWideVectorExtractCast(Instruction *inst) {
111 auto extract = dyn_cast<ExtractElementInst>(inst);
112 if (!extract)
113 return false;
114
115 auto vec_ty = extract->getVectorOperandType();
116 if (vec_ty->getElementCount().Min <= 4)
117 return false;
118
119 // Instcombine only transforms TruncInst (which operates on integers).
120 if (!vec_ty->getVectorElementType()->isIntegerTy())
121 return false;
122
123 auto const_idx = dyn_cast<ConstantInt>(extract->getIndexOperand());
124 if (!const_idx)
125 return false;
126
127 auto cast = dyn_cast<BitCastInst>(extract->getVectorOperand());
128 if (!cast)
129 return false;
130
131 auto src = cast->getOperand(0);
132 auto src_vec_ty = dyn_cast<VectorType>(src->getType());
133 if (!src_vec_ty)
134 return false;
135
136 uint64_t src_elements = src_vec_ty->getElementCount().Min;
137 uint64_t dst_elements = vec_ty->getElementCount().Min;
138
139 if (dst_elements < src_elements)
140 return false;
141
142 uint64_t idx = const_idx->getZExtValue();
143 uint64_t ratio = dst_elements / src_elements;
144 uint64_t new_idx = idx / ratio;
145
146 // Instcombine should never have generated an odd index, so don't handle
147 // right now.
148 if (idx & 0x1)
149 return false;
150
151 // Create a truncate of an extract element.
152 IRBuilder<> builder(inst);
153 auto new_src = builder.CreateExtractElement(src, builder.getInt32(new_idx));
154 auto trunc = builder.CreateTrunc(new_src, extract->getType());
155 extract->replaceAllUsesWith(trunc);
156 dead_.push_back(extract);
157 potentially_dead_.insert(cast);
158
159 return true;
160}
161
162bool UndoInstCombinePass::UndoWideVectorShuffleCast(Instruction *inst) {
163 auto shuffle = dyn_cast<ShuffleVectorInst>(inst);
164 if (!shuffle)
165 return false;
166
167 // Instcombine only transforms TruncInst (which operates on integers).
168 auto vec_ty = cast<VectorType>(shuffle->getType());
169 if (!vec_ty->getVectorElementType()->isIntegerTy())
170 return false;
171
172 auto in1 = shuffle->getOperand(0);
173 auto in1_vec_ty = cast<VectorType>(in1->getType());
174 if (in1_vec_ty->getElementCount().Min <= 4)
175 return false;
176
177 auto in1_cast = dyn_cast<BitCastInst>(in1);
178 if (!in1_cast)
179 return false;
180
181 // Instcombine only produces shuffles with an undef second input, so don't
182 // handle other cases for now.
183 if (!isa<UndefValue>(shuffle->getOperand(1)))
184 return false;
185
186 auto src = in1_cast->getOperand(0);
187 auto src_vec_ty = dyn_cast<VectorType>(src->getType());
188 if (!src_vec_ty)
189 return false;
190
191 uint64_t src_elements = src_vec_ty->getElementCount().Min;
192 uint64_t dst_elements = in1_vec_ty->getElementCount().Min;
193
194 if (dst_elements < src_elements)
195 return false;
196
197 uint64_t ratio = dst_elements / src_elements;
198 auto dst_scalar_type = vec_ty->getVectorElementType();
199
200 SmallVector<int, 16> mask;
201 shuffle->getShuffleMask(mask);
202 for (auto i : mask) {
203 // Instcombine should not have generated odd indices, so don't handle them
204 // for now.
205 if ((i != UndefMaskElem) && (i & 0x1))
206 return false;
207 }
208
209 // For each index, create a truncate of an extract element and insert each
210 // into the result vector.
211 IRBuilder<> builder(inst);
212 Value *insert = nullptr;
213 int i = 0;
214 for (auto idx : mask) {
215 if (idx == UndefMaskElem)
216 continue;
217
218 uint64_t new_idx = idx / ratio;
219 auto extract = builder.CreateExtractElement(src, builder.getInt32(new_idx));
220 auto trunc = builder.CreateTrunc(extract, dst_scalar_type);
221 Value *prev = insert ? insert : Constant::getNullValue(vec_ty);
222 insert = builder.CreateInsertElement(prev, trunc, builder.getInt32(i++));
223 }
224 if (!insert) {
225 insert = Constant::getNullValue(vec_ty);
226 }
227 shuffle->replaceAllUsesWith(insert);
228 dead_.push_back(shuffle);
229 potentially_dead_.insert(in1_cast);
230
231 return true;
232}