blob: f01cff2d1912f0b851581bc23e0b47c42d7ef74c [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <cmath>
David Neto22f144c2017-06-12 14:26:21 -040021
22using namespace llvm;
23
24#define DEBUG_TYPE "replacepointerbitcast"
25
26namespace {
27struct ReplacePointerBitcastPass : public ModulePass {
28 static char ID;
29 ReplacePointerBitcastPass() : ModulePass(ID) {}
30
David Neto30ae05e2017-09-06 19:58:36 -040031 // Returns the number of chunks of source data required to exactly
32 // cover the destination data, if the source and destination types are
33 // different sizes. Otherwise returns 0.
David Neto22f144c2017-06-12 14:26:21 -040034 unsigned CalculateNumIter(unsigned SrcTyBitWidth, unsigned DstTyBitWidth);
35 Value *CalculateNewGEPIdx(unsigned SrcTyBitWidth, unsigned DstTyBitWidth,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040036 GetElementPtrInst *GEP);
David Neto22f144c2017-06-12 14:26:21 -040037
38 bool runOnModule(Module &M) override;
39};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040040} // namespace
David Neto22f144c2017-06-12 14:26:21 -040041
42char ReplacePointerBitcastPass::ID = 0;
43static RegisterPass<ReplacePointerBitcastPass>
44 X("ReplacePointerBitcast", "Replace Pointer Bitcast Pass");
45
46namespace clspv {
47ModulePass *createReplacePointerBitcastPass() {
48 return new ReplacePointerBitcastPass();
49}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040050} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040051
52unsigned ReplacePointerBitcastPass::CalculateNumIter(unsigned SrcTyBitWidth,
53 unsigned DstTyBitWidth) {
54 unsigned NumIter = 0;
55 if (SrcTyBitWidth > DstTyBitWidth) {
56 if (SrcTyBitWidth % DstTyBitWidth) {
57 llvm_unreachable(
58 "Src type bitwidth should be multiple of Dest type bitwidth");
59 }
60 NumIter = 1;
61 } else if (SrcTyBitWidth < DstTyBitWidth) {
62 if (DstTyBitWidth % SrcTyBitWidth) {
63 llvm_unreachable(
64 "Dest type bitwidth should be multiple of Src type bitwidth");
65 }
66 NumIter = DstTyBitWidth / SrcTyBitWidth;
67 } else {
68 NumIter = 0;
69 }
70
71 return NumIter;
72}
73
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040074Value *ReplacePointerBitcastPass::CalculateNewGEPIdx(unsigned SrcTyBitWidth,
75 unsigned DstTyBitWidth,
76 GetElementPtrInst *GEP) {
David Neto22f144c2017-06-12 14:26:21 -040077 Value *NewGEPIdx = GEP->getOperand(1);
78 IRBuilder<> Builder(GEP);
79
80 if (SrcTyBitWidth > DstTyBitWidth) {
81 if (GEP->getNumOperands() > 2) {
82 GEP->print(errs());
83 llvm_unreachable("Support above GEP on PointerBitcastPass");
84 }
85
86 NewGEPIdx = Builder.CreateLShr(
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040087 NewGEPIdx, Builder.getInt32(std::log2(SrcTyBitWidth / DstTyBitWidth)));
David Neto22f144c2017-06-12 14:26:21 -040088 } else if (DstTyBitWidth > SrcTyBitWidth) {
89 if (GEP->getNumOperands() > 2) {
90 GEP->print(errs());
91 llvm_unreachable("Support above GEP on PointerBitcastPass");
92 }
93
94 NewGEPIdx = Builder.CreateShl(
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040095 NewGEPIdx, Builder.getInt32(std::log2(DstTyBitWidth / SrcTyBitWidth)));
David Neto22f144c2017-06-12 14:26:21 -040096 }
97
98 return NewGEPIdx;
99}
100
101bool ReplacePointerBitcastPass::runOnModule(Module &M) {
102 bool Changed = false;
103
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400104 const DataLayout &DL = M.getDataLayout();
David Neto8e138142018-05-29 10:19:21 -0400105
David Neto22f144c2017-06-12 14:26:21 -0400106 SmallVector<Instruction *, 16> VectorWorkList;
107 SmallVector<Instruction *, 16> ScalarWorkList;
108 for (Function &F : M) {
109 for (BasicBlock &BB : F) {
110 for (Instruction &I : BB) {
111 // Find pointer bitcast instruction.
112 if (isa<BitCastInst>(&I) && isa<PointerType>(I.getType())) {
113 Value *Src = I.getOperand(0);
114 if (isa<PointerType>(Src->getType())) {
115 Type *SrcEleTy =
116 I.getOperand(0)->getType()->getPointerElementType();
117 Type *DstEleTy = I.getType()->getPointerElementType();
118 if (SrcEleTy->isVectorTy() || DstEleTy->isVectorTy()) {
119 // Handle case either operand is vector type like char4* -> int4*.
120 VectorWorkList.push_back(&I);
121 } else {
122 // Handle case all operands are scalar type like char* -> int*.
123 ScalarWorkList.push_back(&I);
124 }
125
126 Changed = true;
127 } else {
128 llvm_unreachable("Unsupported bitcast");
129 }
130 }
131 }
132 }
133 }
134
135 SmallVector<Instruction *, 16> ToBeDeleted;
136 for (Instruction *Inst : VectorWorkList) {
137 Value *Src = Inst->getOperand(0);
138 Type *SrcTy = Src->getType()->getPointerElementType();
139 Type *DstTy = Inst->getType()->getPointerElementType();
140 Type *SrcEleTy =
141 SrcTy->isVectorTy() ? SrcTy->getSequentialElementType() : SrcTy;
142 Type *DstEleTy =
143 DstTy->isVectorTy() ? DstTy->getSequentialElementType() : DstTy;
David Neto30ae05e2017-09-06 19:58:36 -0400144 // These are bit widths of the source and destination types, even
145 // if they are vector types. E.g. bit width of float4 is 64.
David Neto8e138142018-05-29 10:19:21 -0400146 unsigned SrcTyBitWidth = DL.getTypeStoreSizeInBits(SrcTy);
147 unsigned DstTyBitWidth = DL.getTypeStoreSizeInBits(DstTy);
148 unsigned SrcEleTyBitWidth = DL.getTypeStoreSizeInBits(SrcEleTy);
149 unsigned DstEleTyBitWidth = DL.getTypeStoreSizeInBits(DstEleTy);
David Neto22f144c2017-06-12 14:26:21 -0400150 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
151
152 // Investigate pointer bitcast's users.
153 for (User *BitCastUser : Inst->users()) {
154 Value *BitCastSrc = Inst->getOperand(0);
155 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
156
157 // It consist of User* and bool whether user is gep or not.
158 SmallVector<std::pair<User *, bool>, 32> Users;
159
160 GetElementPtrInst *GEP = nullptr;
161 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400162 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400163 OrgGEPIdx = GEP->getOperand(1);
164
165 // Build new src/dst address index.
166 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
167
168 // Record gep's users.
169 for (User *GEPUser : GEP->users()) {
170 Users.push_back(std::make_pair(GEPUser, true));
171 }
172 } else {
173 // Record bitcast's users.
174 Users.push_back(std::make_pair(BitCastUser, false));
175 }
176
177 // Handle users.
178 bool IsGEPUser = false;
179 for (auto UserIter : Users) {
180 User *U = UserIter.first;
181 IsGEPUser = UserIter.second;
182
183 IRBuilder<> Builder(cast<Instruction>(U));
184
185 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
186 if (SrcTyBitWidth < DstTyBitWidth) {
187 //
188 // Consider below case.
189 //
190 // Original IR (float2* --> float4*)
191 // 1. val = load (float4*) src_addr
192 // 2. dst_addr = bitcast float2*, float4*
193 // 3. dst_addr = gep (float4*) dst_addr, idx
194 // 4. store (float4*) dst_addr
195 //
196 // Transformed IR
197 // 1. val(float4) = load (float4*) src_addr
198 // 2. val1(float2) = shufflevector (float4)val, (float4)undef,
199 // (float2)<0, 1>
200 // 3. val2(float2) = shufflevector (float4)val, (float4)undef,
201 // (float2)<2, 3>
202 // 4. dst_addr1(float2*) = gep (float2*)dst_addr, idx * 2
203 // 5. dst_addr2(float2*) = gep (float2*)dst_addr, idx * 2 + 1
204 // 6. store (float2)val1, (float2*)dst_addr1
205 // 7. store (float2)val2, (float2*)dst_addr2
206 //
207
208 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
209 unsigned NumVector = 1;
210 // Vulkan SPIR-V does not support over 4 components for
211 // TypeVector.
212 if (NumElement > 4) {
213 NumVector = NumElement >> 2;
214 NumElement = 4;
215 }
216
217 // Create store values.
218 Type *TmpValTy = SrcTy;
219 if (DstTy->isVectorTy()) {
220 if (SrcEleTyBitWidth == DstEleTyBitWidth) {
221 TmpValTy =
222 VectorType::get(SrcEleTy, DstTy->getVectorNumElements());
223 } else {
224 TmpValTy = VectorType::get(SrcEleTy, NumElement);
225 }
226 }
227
228 Value *STVal = ST->getValueOperand();
229 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
230 Value *TmpSTVal = nullptr;
231 if (NumVector == 1) {
232 TmpSTVal = Builder.CreateBitCast(STVal, TmpValTy);
233 } else {
234 unsigned DstVecTyNumElement =
235 DstTy->getVectorNumElements() / NumVector;
236 SmallVector<uint32_t, 4> Idxs;
237 for (unsigned i = 0; i < DstVecTyNumElement; i++) {
238 Idxs.push_back(i + (DstVecTyNumElement * VIdx));
239 }
240 Value *UndefVal = UndefValue::get(DstTy);
241 TmpSTVal = Builder.CreateShuffleVector(STVal, UndefVal, Idxs);
242 TmpSTVal = Builder.CreateBitCast(TmpSTVal, TmpValTy);
243 }
244
245 SmallVector<Value *, 8> STValues;
246 if (!SrcTy->isVectorTy()) {
247 // Handle scalar type.
248 for (unsigned i = 0; i < NumElement; i++) {
249 Value *TmpVal = Builder.CreateExtractElement(
250 TmpSTVal, Builder.getInt32(i));
251 STValues.push_back(TmpVal);
252 }
253 } else {
254 // Handle vector type.
255 unsigned SrcNumElement = SrcTy->getVectorNumElements();
256 unsigned DstNumElement = DstTy->getVectorNumElements();
257 for (unsigned i = 0; i < NumElement; i++) {
258 SmallVector<uint32_t, 4> Idxs;
259 for (unsigned j = 0; j < SrcNumElement; j++) {
260 Idxs.push_back(i * SrcNumElement + j);
261 }
262
263 VectorType *TmpVecTy =
264 VectorType::get(SrcEleTy, DstNumElement);
265 Value *UndefVal = UndefValue::get(TmpVecTy);
266 Value *TmpVal =
267 Builder.CreateShuffleVector(TmpSTVal, UndefVal, Idxs);
268 STValues.push_back(TmpVal);
269 }
270 }
271
272 // Generate stores.
273 Value *SrcAddrIdx = NewAddrIdx;
274 Value *BaseAddr = BitCastSrc;
275 for (unsigned i = 0; i < NumElement; i++) {
276 // Calculate store address.
277 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
278 Builder.CreateStore(STValues[i], DstAddr);
279
280 if (i + 1 < NumElement) {
281 // Calculate next store address
282 SrcAddrIdx =
283 Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
284 }
285 }
286 }
287 } else if (SrcTyBitWidth > DstTyBitWidth) {
288 //
289 // Consider below case.
290 //
291 // Original IR (float4* --> float2*)
292 // 1. val = load (float2*) src_addr
293 // 2. dst_addr = bitcast float4*, float2*
294 // 3. dst_addr = gep (float2*) dst_addr, idx
295 // 4. store (float2) val, (float2*) dst_addr
296 //
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400297 // Transformed IR: Decompose the source vector into elements, then
298 // write them one at a time.
David Neto22f144c2017-06-12 14:26:21 -0400299 // 1. val = load (float2*) src_addr
300 // 2. val1 = (float)extract_element val, 0
301 // 3. val2 = (float)extract_element val, 1
David Neto30ae05e2017-09-06 19:58:36 -0400302 // // Source component k maps to destination component k * idxscale
303 // 3a. idxscale = sizeof(float4)/sizeof(float2)
304 // 3b. idxbase = idx / idxscale
305 // 3c. newarrayidx = idxbase * idxscale
306 // 4. dst_addr1 = gep (float4*) dst, newarrayidx
307 // 5. dst_addr2 = gep (float4*) dst, newarrayidx + 1
David Neto22f144c2017-06-12 14:26:21 -0400308 // 6. store (float)val1, (float*) dst_addr1
309 // 7. store (float)val2, (float*) dst_addr2
310 //
311
312 if (SrcTyBitWidth <= DstEleTyBitWidth) {
313 SrcTy->print(errs());
314 DstTy->print(errs());
315 llvm_unreachable("Handle above src/dst type.");
316 }
317
318 // Create store values.
319 Value *STVal = ST->getValueOperand();
320
321 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
322 VectorType *TmpVecTy =
323 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
324 STVal = Builder.CreateBitCast(STVal, TmpVecTy);
325 }
326
327 SmallVector<Value *, 8> STValues;
David Neto30ae05e2017-09-06 19:58:36 -0400328 // How many destination writes are required?
David Neto22f144c2017-06-12 14:26:21 -0400329 unsigned DstNumElement = 1;
330 if (!DstTy->isVectorTy() || SrcEleTyBitWidth == DstTyBitWidth) {
331 // Handle scalar type.
332 STValues.push_back(STVal);
333 } else {
334 // Handle vector type.
335 DstNumElement = DstTy->getVectorNumElements();
336 for (unsigned i = 0; i < DstNumElement; i++) {
337 Value *Idx = Builder.getInt32(i);
338 Value *TmpVal = Builder.CreateExtractElement(STVal, Idx);
339 STValues.push_back(TmpVal);
340 }
341 }
342
343 // Generate stores.
344 Value *BaseAddr = BitCastSrc;
345 Value *SubEleIdx = Builder.getInt32(0);
346 if (IsGEPUser) {
David Neto30ae05e2017-09-06 19:58:36 -0400347 // Compute SubNumElement = idxscale
David Neto22f144c2017-06-12 14:26:21 -0400348 unsigned SubNumElement = SrcTy->getVectorNumElements();
349 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
David Neto30ae05e2017-09-06 19:58:36 -0400350 // Same condition under which DstNumElements > 1
David Neto22f144c2017-06-12 14:26:21 -0400351 SubNumElement = SrcTy->getVectorNumElements() /
352 DstTy->getVectorNumElements();
353 }
354
David Neto30ae05e2017-09-06 19:58:36 -0400355 // Compute SubEleIdx = idxbase * idxscale
David Neto22f144c2017-06-12 14:26:21 -0400356 SubEleIdx = Builder.CreateAnd(
357 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
David Neto30ae05e2017-09-06 19:58:36 -0400358 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
359 SubEleIdx = Builder.CreateShl(
360 SubEleIdx, Builder.getInt32(std::log2(SubNumElement)));
361 }
David Neto22f144c2017-06-12 14:26:21 -0400362 }
363
364 for (unsigned i = 0; i < DstNumElement; i++) {
365 // Calculate address.
366 if (i > 0) {
367 SubEleIdx = Builder.CreateAdd(SubEleIdx, Builder.getInt32(i));
368 }
369
370 Value *Idxs[] = {NewAddrIdx, SubEleIdx};
371 Value *DstAddr = Builder.CreateGEP(BaseAddr, Idxs);
372 Type *TmpSrcTy = SrcEleTy;
373 if (TmpSrcTy->isVectorTy()) {
374 TmpSrcTy = TmpSrcTy->getVectorElementType();
375 }
376 Value *TmpVal = Builder.CreateBitCast(STValues[i], TmpSrcTy);
377
378 Builder.CreateStore(TmpVal, DstAddr);
379 }
380 } else {
381 // if SrcTyBitWidth == DstTyBitWidth
382 Type *TmpSrcTy = SrcTy;
383 Value *DstAddr = Src;
384
385 if (IsGEPUser) {
386 SmallVector<Value *, 4> Idxs;
387 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
388 Idxs.push_back(GEP->getOperand(i));
389 }
390 DstAddr = Builder.CreateGEP(BitCastSrc, Idxs);
391
392 if (GEP->getNumOperands() > 2) {
393 TmpSrcTy = SrcEleTy;
394 }
395 }
396
397 Value *TmpVal =
398 Builder.CreateBitCast(ST->getValueOperand(), TmpSrcTy);
399 Builder.CreateStore(TmpVal, DstAddr);
400 }
401 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
402 Value *SrcAddrIdx = Builder.getInt32(0);
403 if (IsGEPUser) {
404 SrcAddrIdx = NewAddrIdx;
405 }
406
407 // Load value from src.
408 SmallVector<Value *, 8> LDValues;
409
410 for (unsigned i = 1; i <= NumIter; i++) {
411 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
412 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
413 LDValues.push_back(SrcVal);
414
415 if (i + 1 <= NumIter) {
416 // Calculate next SrcAddrIdx.
417 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
418 }
419 }
420
421 Value *DstVal = nullptr;
422 if (SrcTyBitWidth > DstTyBitWidth) {
423 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
424
425 if (SrcEleTyBitWidth == DstTyBitWidth) {
426 //
427 // Consider below case.
428 //
429 // Original IR (int4* --> char4*)
430 // 1. src_addr = bitcast int4*, char4*
431 // 2. element_addr = gep (char4*) src_addr, idx
432 // 3. load (char4*) element_addr
433 //
434 // Transformed IR
435 // 1. src_addr = gep (int4*) src, idx / 4
436 // 2. src_val(int4) = load (int4*) src_addr
437 // 3. tmp_val(int4) = extractelement src_val, idx % 4
438 // 4. dst_val(char4) = bitcast tmp_val, (char4)
439 //
440 Value *EleIdx = Builder.getInt32(0);
441 if (IsGEPUser) {
442 EleIdx = Builder.CreateAnd(OrgGEPIdx,
443 Builder.getInt32(NumElement - 1));
444 }
445 Value *TmpVal =
446 Builder.CreateExtractElement(LDValues[0], EleIdx, "tmp_val");
447 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
448 } else if (SrcEleTyBitWidth < DstTyBitWidth) {
449 if (IsGEPUser) {
450 //
451 // Consider below case.
452 //
453 // Original IR (float4* --> float2*)
454 // 1. src_addr = bitcast float4*, float2*
455 // 2. element_addr = gep (float2*) src_addr, idx
456 // 3. load (float2*) element_addr
457 //
458 // Transformed IR
459 // 1. src_addr = gep (float4*) src, idx / 2
460 // 2. src_val(float4) = load (float4*) src_addr
461 // 3. tmp_val1(float) = extractelement (idx % 2) * 2
462 // 4. tmp_val2(float) = extractelement (idx % 2) * 2 + 1
463 // 5. dst_val(float2) = insertelement undef(float2), tmp_val1, 0
464 // 6. dst_val(float2) = insertelement undef(float2), tmp_val2, 1
465 // 7. dst_val(float2) = bitcast dst_val, (float2)
466 // ==> if types are same between src and dst, it will be
467 // igonored
468 //
469 VectorType *TmpVecTy =
470 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
471 DstVal = UndefValue::get(TmpVecTy);
472 Value *EleIdx = Builder.CreateAnd(
473 OrgGEPIdx, Builder.getInt32(NumElement - 1));
474 EleIdx = Builder.CreateShl(
475 EleIdx, Builder.getInt32(
476 std::log2(DstTyBitWidth / SrcEleTyBitWidth)));
477 Value *TmpOrgGEPIdx = EleIdx;
478 for (unsigned i = 0; i < NumElement; i++) {
479 Value *TmpVal = Builder.CreateExtractElement(
480 LDValues[0], TmpOrgGEPIdx, "tmp_val");
481 DstVal = Builder.CreateInsertElement(DstVal, TmpVal,
482 Builder.getInt32(i));
483
484 if (i + 1 < NumElement) {
485 TmpOrgGEPIdx =
486 Builder.CreateAdd(TmpOrgGEPIdx, Builder.getInt32(1));
487 }
488 }
489 } else {
490 //
491 // Consider below case.
492 //
493 // Original IR (float4* --> int2*)
494 // 1. src_addr = bitcast float4*, int2*
495 // 2. load (int2*) src_addr
496 //
497 // Transformed IR
498 // 1. src_val(float4) = load (float4*) src_addr
499 // 2. tmp_val(float2) = shufflevector (float4)src_val,
500 // (float4)undef,
501 // (float2)<0, 1>
502 // 3. dst_val(int2) = bitcast (float2)tmp_val, (int2)
503 //
504 unsigned NumElement = DstTyBitWidth / SrcEleTyBitWidth;
505 Value *Undef = UndefValue::get(SrcTy);
506
507 SmallVector<uint32_t, 4> Idxs;
508 for (unsigned i = 0; i < NumElement; i++) {
509 Idxs.push_back(i);
510 }
511 DstVal = Builder.CreateShuffleVector(LDValues[0], Undef, Idxs);
512
513 DstVal = Builder.CreateBitCast(DstVal, DstTy);
514 }
515
516 DstVal = Builder.CreateBitCast(DstVal, DstTy);
517 } else {
518 if (IsGEPUser) {
519 //
520 // Consider below case.
521 //
522 // Original IR (int4* --> char2*)
523 // 1. src_addr = bitcast int4*, char2*
524 // 2. element_addr = gep (char2*) src_addr, idx
525 // 3. load (char2*) element_addr
526 //
527 // Transformed IR
528 // 1. src_addr = gep (int4*) src, idx / 8
529 // 2. src_val(int4) = load (int4*) src_addr
530 // 3. tmp_val(int) = extractelement idx / 2
531 // 4. tmp_val(<i16 x 2>) = bitcast tmp_val(int), (<i16 x 2>)
532 // 5. tmp_val(i16) = extractelement idx % 2
533 // 6. dst_val(char2) = bitcast tmp_val, (char2)
534 // ==> if types are same between src and dst, it will be
535 // igonored
536 //
537 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
538 unsigned SubNumElement = SrcEleTyBitWidth / DstTyBitWidth;
539 if (SubNumElement != 2 && SubNumElement != 4) {
540 llvm_unreachable("Unsupported SubNumElement");
541 }
542
543 Value *TmpOrgGEPIdx = Builder.CreateLShr(
544 OrgGEPIdx, Builder.getInt32(std::log2(SubNumElement)));
545 Value *TmpVal = Builder.CreateExtractElement(
546 LDValues[0], TmpOrgGEPIdx, "tmp_val");
547 TmpVal = Builder.CreateBitCast(
548 TmpVal,
549 VectorType::get(
550 IntegerType::get(DstTy->getContext(), DstTyBitWidth),
551 SubNumElement));
552 TmpOrgGEPIdx = Builder.CreateAnd(
553 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
554 TmpVal = Builder.CreateExtractElement(TmpVal, TmpOrgGEPIdx,
555 "tmp_val");
556 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
557 } else {
558 Inst->print(errs());
559 llvm_unreachable("Handle this bitcast");
560 }
561 }
562 } else if (SrcTyBitWidth < DstTyBitWidth) {
563 //
564 // Consider below case.
565 //
566 // Original IR (float2* --> float4*)
567 // 1. src_addr = bitcast float2*, float4*
568 // 2. element_addr = gep (float4*) src_addr, idx
569 // 3. load (float4*) element_addr
570 //
571 // Transformed IR
572 // 1. src_addr = gep (float2*) src, idx * 2
573 // 2. src_val1(float2) = load (float2*) src_addr
574 // 3. src_addr2 = gep (float2*) src_addr, 1
575 // 4. src_val2(float2) = load (float2*) src_addr2
576 // 5. dst_val(float4) = shufflevector src_val1, src_val2, <0, 1>
577 // 6. dst_val(float4) = bitcast dst_val, (float4)
578 // ==> if types are same between src and dst, it will be igonored
579 //
580 unsigned NumElement = 1;
581 if (SrcTy->isVectorTy()) {
582 NumElement = SrcTy->getVectorNumElements() * 2;
583 }
584
585 // Handle scalar type.
586 if (NumElement == 1) {
587 if (SrcTyBitWidth * 4 <= DstTyBitWidth) {
588 unsigned NumVecElement = DstTyBitWidth / SrcTyBitWidth;
589 unsigned NumVector = 1;
590 if (NumVecElement > 4) {
591 NumVector = NumVecElement >> 2;
592 NumVecElement = 4;
593 }
594
595 SmallVector<Value *, 4> Values;
596 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
597 // In this case, generate only insert element. It generates
598 // less instructions than using shuffle vector.
599 VectorType *TmpVecTy = VectorType::get(SrcTy, NumVecElement);
600 Value *TmpVal = UndefValue::get(TmpVecTy);
601 for (unsigned i = 0; i < NumVecElement; i++) {
602 TmpVal = Builder.CreateInsertElement(
603 TmpVal, LDValues[i + (VIdx * 4)], Builder.getInt32(i));
604 }
605 Values.push_back(TmpVal);
606 }
607
608 if (Values.size() > 2) {
609 Inst->print(errs());
610 llvm_unreachable("Support above bitcast");
611 }
612
613 if (Values.size() > 1) {
614 Type *TmpEleTy =
615 Type::getIntNTy(M.getContext(), SrcEleTyBitWidth * 2);
616 VectorType *TmpVecTy = VectorType::get(TmpEleTy, NumVector);
617 for (unsigned i = 0; i < Values.size(); i++) {
618 Values[i] = Builder.CreateBitCast(Values[i], TmpVecTy);
619 }
620 SmallVector<uint32_t, 4> Idxs;
621 for (unsigned i = 0; i < (NumVector * 2); i++) {
622 Idxs.push_back(i);
623 }
624 for (unsigned i = 0; i < Values.size(); i = i + 2) {
625 Values[i] = Builder.CreateShuffleVector(
626 Values[i], Values[i + 1], Idxs);
627 }
628 }
629
630 LDValues.clear();
631 LDValues.push_back(Values[0]);
632 } else {
633 SmallVector<Value *, 4> TmpLDValues;
634 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
635 VectorType *TmpVecTy = VectorType::get(SrcTy, 2);
636 Value *TmpVal = UndefValue::get(TmpVecTy);
637 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i],
638 Builder.getInt32(0));
639 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i + 1],
640 Builder.getInt32(1));
641 TmpLDValues.push_back(TmpVal);
642 }
643 LDValues.clear();
644 LDValues = std::move(TmpLDValues);
645 NumElement = 4;
646 }
647 }
648
649 // Handle vector type.
650 while (LDValues.size() != 1) {
651 SmallVector<Value *, 4> TmpLDValues;
652 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
653 SmallVector<uint32_t, 4> Idxs;
654 for (unsigned j = 0; j < NumElement; j++) {
655 Idxs.push_back(j);
656 }
657 Value *TmpVal = Builder.CreateShuffleVector(
658 LDValues[i], LDValues[i + 1], Idxs);
659 TmpLDValues.push_back(TmpVal);
660 }
661 LDValues.clear();
662 LDValues = std::move(TmpLDValues);
663 NumElement *= 2;
664 }
665
666 DstVal = Builder.CreateBitCast(LDValues[0], DstTy);
667 } else {
668 //
669 // Consider below case.
670 //
671 // Original IR (float4* --> int4*)
672 // 1. src_addr = bitcast float4*, int4*
673 // 2. element_addr = gep (int4*) src_addr, idx, 0
674 // 3. load (int) element_addr
675 //
676 // Transformed IR
677 // 1. element_addr = gep (float4*) src_addr, idx, 0
678 // 2. src_val = load (float*) element_addr
679 // 3. val = bitcast (float) src_val to (int)
680 //
681 Value *SrcAddr = Src;
682 if (IsGEPUser) {
683 SmallVector<Value *, 4> Idxs;
684 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
685 Idxs.push_back(GEP->getOperand(i));
686 }
687 SrcAddr = Builder.CreateGEP(Src, Idxs);
688 }
689 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
690
691 Type *TmpDstTy = DstTy;
692 if (IsGEPUser) {
693 if (GEP->getNumOperands() > 2) {
694 TmpDstTy = DstEleTy;
695 }
696 }
697 DstVal = Builder.CreateBitCast(SrcVal, TmpDstTy);
698 }
699
700 // Update LD's users with DstVal.
701 LD->replaceAllUsesWith(DstVal);
702 } else {
703 U->print(errs());
704 llvm_unreachable(
705 "Handle above user of gep on ReplacePointerBitcastPass");
706 }
707
708 ToBeDeleted.push_back(cast<Instruction>(U));
709 }
710
711 if (IsGEPUser) {
712 ToBeDeleted.push_back(GEP);
713 }
714 }
715
716 ToBeDeleted.push_back(Inst);
717 }
718
719 for (Instruction *Inst : ScalarWorkList) {
David Neto8e138142018-05-29 10:19:21 -0400720 // Some tests have a stray bitcast from pointer-to-array to
721 // pointer to i8*, but the bitcast has no uses. Exit early
722 // but be sure to delete it later.
723 //
724 // Example:
725 // %1 = bitcast [25 x float]* %dst to i8*
726
727 // errs () << " Scalar bitcast is " << *Inst << "\n";
728
729 if (!Inst->hasNUsesOrMore(1)) {
730 ToBeDeleted.push_back(Inst);
731 continue;
732 }
733
David Neto22f144c2017-06-12 14:26:21 -0400734 Value *Src = Inst->getOperand(0);
David Neto8e138142018-05-29 10:19:21 -0400735 Type *SrcTy; // Original type
736 Type *DstTy; // Type that SrcTy is cast to.
737 unsigned SrcTyBitWidth;
738 unsigned DstTyBitWidth;
739
740 SrcTy = Src->getType()->getPointerElementType();
741 DstTy = Inst->getType()->getPointerElementType();
742 int iter_count = 0;
743 while (++iter_count) {
744 SrcTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(SrcTy));
745 DstTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(DstTy));
746#if 0
747 errs() << " Try Src " << *Src << "\n";
748 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
749 << "\n";
750 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
751 << "\n";
752#endif
753
754 // The normal case that we can handle is source type is smaller than
755 // the dest type.
756 if (SrcTyBitWidth <= DstTyBitWidth)
757 break;
758
759 // The Source type is bigger than the destination type.
760 // Walk into the source type to break it down.
761 if (SrcTy->isArrayTy()) {
762 // If it's an array, consider only the first element.
763 Value *Zero = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400764 Instruction *NewSrc =
765 GetElementPtrInst::CreateInBounds(Src, {Zero, Zero});
David Neto8e138142018-05-29 10:19:21 -0400766 // errs() << "NewSrc is " << *NewSrc << "\n";
767 if (auto *SrcInst = dyn_cast<Instruction>(Src)) {
768 // errs() << " instruction case\n";
769 NewSrc->insertAfter(SrcInst);
770 } else {
771 // Could be a parameter.
772 auto where = Inst->getParent()
773 ->getParent()
774 ->getEntryBlock()
775 .getFirstInsertionPt();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400776 Instruction &whereInst = *where;
David Neto8e138142018-05-29 10:19:21 -0400777 // errs() << "insert " << *NewSrc << " before " << whereInst << "\n";
778 NewSrc->insertBefore(&whereInst);
779 }
780 Src = NewSrc;
781 SrcTy = Src->getType()->getPointerElementType();
782 } else {
783 errs() << "Replace pointer bitcasts: unhandled case: non-array "
784 "non-vector source type "
785 << *SrcTy << " is wider than dest type " << *DstTy << "\n";
786 llvm_unreachable("ReplacePointerBitcastPass: non-array non-vector "
787 "source type is wider than dest type");
788 }
789 if (iter_count > 1000) {
790 llvm_unreachable("ReplacePointerBitcastPass: Too many iterations!");
791 }
792 };
793#if 0
794 errs() << " Src is " << *Src << "\n";
795 errs() << " Dst is " << *Inst << "\n";
796 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
797 << "\n";
798 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
799 << "\n";
800#endif
David Neto22f144c2017-06-12 14:26:21 -0400801
802 for (User *BitCastUser : Inst->users()) {
803 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
804 // It consist of User* and bool whether user is gep or not.
805 SmallVector<std::pair<User *, bool>, 32> Users;
806
807 GetElementPtrInst *GEP = nullptr;
808 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400809 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400810 IRBuilder<> Builder(GEP);
811
812 // Build new src/dst address.
813 OrgGEPIdx = GEP->getOperand(1);
814 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
815
816 // If bitcast's user is gep, investigate gep's users too.
817 for (User *GEPUser : GEP->users()) {
818 Users.push_back(std::make_pair(GEPUser, true));
819 }
820 } else {
821 Users.push_back(std::make_pair(BitCastUser, false));
822 }
823
824 // Handle users.
825 bool IsGEPUser = false;
826 for (auto UserIter : Users) {
827 User *U = UserIter.first;
828 IsGEPUser = UserIter.second;
829
830 IRBuilder<> Builder(cast<Instruction>(U));
831
832 // Handle store instruction with gep.
833 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400834 // errs() << " store is " << *ST << "\n";
David Neto22f144c2017-06-12 14:26:21 -0400835 if (SrcTyBitWidth == DstTyBitWidth) {
836 auto STVal = Builder.CreateBitCast(ST->getValueOperand(), SrcTy);
837 Value *DstAddr = Builder.CreateGEP(Src, NewAddrIdx);
838 Builder.CreateStore(STVal, DstAddr);
839 } else if (SrcTyBitWidth < DstTyBitWidth) {
840 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
841
842 // Create Mask.
843 Constant *Mask = nullptr;
844 if (NumElement == 1) {
845 Mask = Builder.getInt32(0xFF);
846 } else if (NumElement == 2) {
847 Mask = Builder.getInt32(0xFFFF);
848 } else {
849 llvm_unreachable("strange type on bitcast");
850 }
851
852 // Create store values.
853 Value *STVal = ST->getValueOperand();
854 SmallVector<Value *, 8> STValues;
855 for (unsigned i = 0; i < NumElement; i++) {
856 Type *TmpTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
857 Value *TmpVal = Builder.CreateBitCast(STVal, TmpTy);
858 TmpVal = Builder.CreateLShr(TmpVal, Builder.getInt32(i));
859 TmpVal = Builder.CreateAnd(TmpVal, Mask);
860 TmpVal = Builder.CreateTrunc(TmpVal, SrcTy);
861 STValues.push_back(TmpVal);
862 }
863
864 // Generate stores.
865 Value *SrcAddrIdx = NewAddrIdx;
866 Value *BaseAddr = Src;
867 for (unsigned i = 0; i < NumElement; i++) {
868 // Calculate store address.
869 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
870 Builder.CreateStore(STValues[i], DstAddr);
871
872 if (i + 1 < NumElement) {
873 // Calculate next store address
874 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
875 }
876 }
877
878 } else {
879 Inst->print(errs());
880 llvm_unreachable("Handle different size store with scalar "
881 "bitcast on ReplacePointerBitcastPass");
882 }
883 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
884 if (SrcTyBitWidth == DstTyBitWidth) {
885 Value *SrcAddr = Builder.CreateGEP(Src, NewAddrIdx);
886 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
887 LD->replaceAllUsesWith(Builder.CreateBitCast(SrcVal, DstTy));
888 } else if (SrcTyBitWidth < DstTyBitWidth) {
889 Value *SrcAddrIdx = NewAddrIdx;
890
891 // Load value from src.
892 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
893 SmallVector<Value *, 8> LDValues;
894 for (unsigned i = 1; i <= NumIter; i++) {
895 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
896 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
897 LDValues.push_back(SrcVal);
898
899 if (i + 1 <= NumIter) {
900 // Calculate next SrcAddrIdx.
901 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
902 }
903 }
904
905 // Merge Load.
906 Type *TmpSrcTy = Type::getIntNTy(M.getContext(), SrcTyBitWidth);
907 Value *DstVal = Builder.CreateBitCast(LDValues[0], TmpSrcTy);
908 Type *TmpDstTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
909 DstVal = Builder.CreateZExt(DstVal, TmpDstTy);
910 for (unsigned i = 1; i < LDValues.size(); i++) {
911 Value *TmpVal = Builder.CreateBitCast(LDValues[i], TmpSrcTy);
912 TmpVal = Builder.CreateZExt(TmpVal, TmpDstTy);
913 TmpVal = Builder.CreateShl(TmpVal,
914 Builder.getInt32(i * SrcTyBitWidth));
915 DstVal = Builder.CreateOr(DstVal, TmpVal);
916 }
917
918 DstVal = Builder.CreateBitCast(DstVal, DstTy);
919 LD->replaceAllUsesWith(DstVal);
920
921 } else {
922 Inst->print(errs());
923 llvm_unreachable("Handle different size load with scalar "
924 "bitcast on ReplacePointerBitcastPass");
925 }
926 } else {
David Neto22f144c2017-06-12 14:26:21 -0400927 Inst->print(errs());
928 llvm_unreachable("Handle above user of scalar bitcast with gep on "
929 "ReplacePointerBitcastPass");
930 }
931
932 ToBeDeleted.push_back(cast<Instruction>(U));
933 }
934
935 if (IsGEPUser) {
936 ToBeDeleted.push_back(GEP);
937 }
938 }
939
940 ToBeDeleted.push_back(Inst);
941 }
942
943 for (Instruction *Inst : ToBeDeleted) {
944 Inst->eraseFromParent();
945 }
946
947 return Changed;
948}