blob: 9d1ff9ee6157e3f8485cc23dba03d190fdc89b8b [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>

using namespace llvm;
21
22#define DEBUG_TYPE "SimplifyPointerBitcast"
23
24namespace {
25struct SimplifyPointerBitcastPass : public ModulePass {
26 static char ID;
27 SimplifyPointerBitcastPass() : ModulePass(ID) {}
28
29 bool runOnModule(Module &M) override;
30
31 bool runOnBitcastFromBitcast(Module &M) const;
32 bool runOnBitcastFromGEP(Module &M) const;
33 bool runOnGEPFromGEP(Module &M) const;
34};
35}
36
37char SimplifyPointerBitcastPass::ID = 0;
38static RegisterPass<SimplifyPointerBitcastPass>
39 X("SimplifyPointerBitcast", "Simplify Pointer Bitcast Pass");
40
41namespace clspv {
42llvm::ModulePass *createSimplifyPointerBitcastPass() {
43 return new SimplifyPointerBitcastPass();
44}
45}
46
47bool SimplifyPointerBitcastPass::runOnModule(Module &M) {
48 bool Changed = false;
49
50 // Loop through our individual simplification passes until they stop changing
51 // things.
52 for (bool localChanged = true; localChanged; Changed |= localChanged) {
53 localChanged = false;
54
55 localChanged |= runOnBitcastFromGEP(M);
56 localChanged |= runOnBitcastFromBitcast(M);
57 localChanged |= runOnGEPFromGEP(M);
58 }
59
60 return Changed;
61}
62
63bool SimplifyPointerBitcastPass::runOnBitcastFromGEP(Module &M) const {
64 SmallVector<BitCastInst *, 16> WorkList;
65 for (Function &F : M) {
66 for (BasicBlock &BB : F) {
67 for (Instruction &I : BB) {
68 // If we have a bitcast instruction...
69 if (auto Bitcast = dyn_cast<BitCastInst>(&I)) {
70 // ... whose source is a GEP instruction...
71 if (auto GEP = dyn_cast<GetElementPtrInst>(Bitcast->getOperand(0))) {
72 // ... where the GEP is retrieving an element of the same type...
73 if (GEP->getSourceElementType() == GEP->getResultElementType()) {
74 auto GEPTy = GEP->getResultElementType();
75 auto BitcastTy = Bitcast->getType()->getPointerElementType();
76 // ... and the types have a known compile time size...
77 if ((0 != GEPTy->getPrimitiveSizeInBits()) &&
78 (0 != BitcastTy->getPrimitiveSizeInBits())) {
79 // ... record the bitcast as something we need to process.
80 WorkList.push_back(Bitcast);
81 }
82 }
83 }
84 }
85 }
86 }
87 }
88
89 const bool Changed = !WorkList.empty();
90
91 for (auto Bitcast : WorkList) {
92 auto BitcastTy = Bitcast->getType();
93 auto BitcastElementTy = BitcastTy->getPointerElementType();
94
95 auto GEP = cast<GetElementPtrInst>(Bitcast->getOperand(0));
96
97 auto SrcTySize = GEP->getResultElementType()->getPrimitiveSizeInBits();
98 auto DstTySize = BitcastElementTy->getPrimitiveSizeInBits();
99
100 SmallVector<Value *, 4> GEPArgs(GEP->idx_begin(), GEP->idx_end());
101
102 // If the source type is smaller than the destination type...
103 if (SrcTySize < DstTySize) {
104 // ... we need to divide the last index of the GEP by the size difference.
105 auto LastIndex = GEPArgs.back();
106 GEPArgs.back() = BinaryOperator::Create(
107 Instruction::SDiv, LastIndex,
108 ConstantInt::get(LastIndex->getType(), DstTySize / SrcTySize), "",
109 Bitcast);
110 } else if (SrcTySize > DstTySize) {
111 // ... we need to multiply the last index of the GEP by the size
112 // difference.
113 auto LastIndex = GEPArgs.back();
114 GEPArgs.back() = BinaryOperator::Create(
115 Instruction::Mul, LastIndex,
116 ConstantInt::get(LastIndex->getType(), SrcTySize / DstTySize), "",
117 Bitcast);
118 } else {
119 // ... the arguments are the same size, nothing to do!
120 }
121
122 // Create a new bitcast from the GEP argument to the bitcast type.
123 auto NewBitcast = CastInst::CreatePointerCast(GEP->getPointerOperand(),
124 BitcastTy, "", Bitcast);
125
126 // Create a new GEP from the (maybe modified) GEPArgs.
127 auto NewGEP = GetElementPtrInst::Create(BitcastElementTy, NewBitcast,
128 GEPArgs, "", Bitcast);
129
130 // And replace the original bitcast with our replacement GEP.
131 Bitcast->replaceAllUsesWith(NewGEP);
132
133 // Remove the bitcast as it has no users now.
134 Bitcast->eraseFromParent();
135
136 // Check if the old GEP had no other users...
137 if (0 == GEP->getNumUses()) {
138 // ... and remove it if we were its only user.
139 GEP->eraseFromParent();
140 }
141 }
142
143 return Changed;
144}
145
146bool SimplifyPointerBitcastPass::runOnBitcastFromBitcast(Module &M) const {
147 SmallVector<BitCastInst *, 16> WorkList;
148 for (Function &F : M) {
149 for (BasicBlock &BB : F) {
150 for (Instruction &I : BB) {
151 // If we have a bitcast instruction...
152 if (auto Bitcast = dyn_cast<BitCastInst>(&I)) {
153 // ... whose source is a bitcast instruction...
154 if (isa<BitCastInst>(Bitcast->getOperand(0))) {
155 // ... record the bitcast as something we need to process.
156 WorkList.push_back(Bitcast);
157 }
158 }
159 }
160 }
161 }
162
163 const bool Changed = !WorkList.empty();
164
165 for (auto Bitcast : WorkList) {
166 auto OtherBitcast = cast<BitCastInst>(Bitcast->getOperand(0));
167
168 // Create a new bitcast from the other bitcasts argument to our type.
169 auto NewBitcast = CastInst::CreatePointerCast(
170 OtherBitcast->getOperand(0), Bitcast->getType(), "", Bitcast);
171
172 // And replace the original bitcast with our replacement bitcast.
173 Bitcast->replaceAllUsesWith(NewBitcast);
174
175 // Remove the bitcast as it has no users now.
176 Bitcast->eraseFromParent();
177
178 // Check if the other bitcast had no other users...
179 if (0 == OtherBitcast->getNumUses()) {
180 // ... and remove it if we were its only user.
181 OtherBitcast->eraseFromParent();
182 }
183 }
184
185 return Changed;
186}
187
188bool SimplifyPointerBitcastPass::runOnGEPFromGEP(Module &M) const {
189 SmallVector<GetElementPtrInst *, 16> WorkList;
190 for (Function &F : M) {
191 for (BasicBlock &BB : F) {
192 for (Instruction &I : BB) {
193 // If we have a GEP instruction...
194 if (auto GEP = dyn_cast<GetElementPtrInst>(&I)) {
195 // ... whose operand is also a GEP instruction...
196 if (isa<GetElementPtrInst>(GEP->getPointerOperand())) {
197 // ... record the GEP as something we need to process.
198 WorkList.push_back(GEP);
199 }
200 }
201 }
202 }
203 }
204
205 const bool Changed = !WorkList.empty();
206
207 for (GetElementPtrInst *GEP : WorkList) {
208 IRBuilder<> Builder(GEP);
209
210 auto OtherGEP = cast<GetElementPtrInst>(GEP->getPointerOperand());
211
212 SmallVector<Value *, 8> Idxs;
213
214 Value *SrcLastIdxOp = OtherGEP->getOperand(OtherGEP->getNumOperands() - 1);
215 Value *GEPIdxOp = GEP->getOperand(1);
216 Value *MergedIdx = Builder.CreateAdd(SrcLastIdxOp, GEPIdxOp);
217
218 Idxs.append(OtherGEP->op_begin() + 1, OtherGEP->op_end() - 1);
219 Idxs.push_back(MergedIdx);
220 Idxs.append(GEP->op_begin() + 2, GEP->op_end());
221
222 Value *NewGEP = nullptr;
223 if (GEP->isInBounds() && OtherGEP->isInBounds()) {
224 NewGEP = Builder.CreateInBoundsGEP(OtherGEP->getPointerOperand(), Idxs);
225 } else {
226 NewGEP = Builder.CreateGEP(OtherGEP->getPointerOperand(), Idxs);
227 }
228
229 // And replace the original GEP with our replacement GEP.
230 GEP->replaceAllUsesWith(NewGEP);
231
232 // Remove the GEP as it has no users now.
233 GEP->eraseFromParent();
234
235 // Check if the other GEP had no other users...
236 if (0 == OtherGEP->getNumUses()) {
237 // ... and remove it if we were its only user.
238 OtherGEP->eraseFromParent();
239 }
240 }
241
242 return Changed;
243}