blob: 9b14a250c6cdee8497e850ac9ee4ef5531826280 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/MathExtras.h>
#include <llvm/Support/raw_ostream.h>

using namespace llvm;

23#define DEBUG_TYPE "replacepointerbitcast"
24
25namespace {
26struct ReplacePointerBitcastPass : public ModulePass {
27 static char ID;
28 ReplacePointerBitcastPass() : ModulePass(ID) {}
29
David Neto30ae05e2017-09-06 19:58:36 -040030 // Returns the number of chunks of source data required to exactly
31 // cover the destination data, if the source and destination types are
32 // different sizes. Otherwise returns 0.
David Neto22f144c2017-06-12 14:26:21 -040033 unsigned CalculateNumIter(unsigned SrcTyBitWidth, unsigned DstTyBitWidth);
34 Value *CalculateNewGEPIdx(unsigned SrcTyBitWidth, unsigned DstTyBitWidth,
35 GetElementPtrInst *GEP);
36
37 bool runOnModule(Module &M) override;
38};
39}
40
41char ReplacePointerBitcastPass::ID = 0;
42static RegisterPass<ReplacePointerBitcastPass>
43 X("ReplacePointerBitcast", "Replace Pointer Bitcast Pass");
44
45namespace clspv {
46ModulePass *createReplacePointerBitcastPass() {
47 return new ReplacePointerBitcastPass();
48}
49}
50
51unsigned ReplacePointerBitcastPass::CalculateNumIter(unsigned SrcTyBitWidth,
52 unsigned DstTyBitWidth) {
53 unsigned NumIter = 0;
54 if (SrcTyBitWidth > DstTyBitWidth) {
55 if (SrcTyBitWidth % DstTyBitWidth) {
56 llvm_unreachable(
57 "Src type bitwidth should be multiple of Dest type bitwidth");
58 }
59 NumIter = 1;
60 } else if (SrcTyBitWidth < DstTyBitWidth) {
61 if (DstTyBitWidth % SrcTyBitWidth) {
62 llvm_unreachable(
63 "Dest type bitwidth should be multiple of Src type bitwidth");
64 }
65 NumIter = DstTyBitWidth / SrcTyBitWidth;
66 } else {
67 NumIter = 0;
68 }
69
70 return NumIter;
71}
72
73Value *ReplacePointerBitcastPass::CalculateNewGEPIdx(
74 unsigned SrcTyBitWidth, unsigned DstTyBitWidth, GetElementPtrInst *GEP) {
75 Value *NewGEPIdx = GEP->getOperand(1);
76 IRBuilder<> Builder(GEP);
77
78 if (SrcTyBitWidth > DstTyBitWidth) {
79 if (GEP->getNumOperands() > 2) {
80 GEP->print(errs());
81 llvm_unreachable("Support above GEP on PointerBitcastPass");
82 }
83
84 NewGEPIdx = Builder.CreateLShr(
85 NewGEPIdx,
86 Builder.getInt32(std::log2(SrcTyBitWidth / DstTyBitWidth)));
87 } else if (DstTyBitWidth > SrcTyBitWidth) {
88 if (GEP->getNumOperands() > 2) {
89 GEP->print(errs());
90 llvm_unreachable("Support above GEP on PointerBitcastPass");
91 }
92
93 NewGEPIdx = Builder.CreateShl(
94 NewGEPIdx,
95 Builder.getInt32(std::log2(DstTyBitWidth / SrcTyBitWidth)));
96 }
97
98 return NewGEPIdx;
99}
100
101bool ReplacePointerBitcastPass::runOnModule(Module &M) {
102 bool Changed = false;
103
104 SmallVector<Instruction *, 16> VectorWorkList;
105 SmallVector<Instruction *, 16> ScalarWorkList;
106 for (Function &F : M) {
107 for (BasicBlock &BB : F) {
108 for (Instruction &I : BB) {
109 // Find pointer bitcast instruction.
110 if (isa<BitCastInst>(&I) && isa<PointerType>(I.getType())) {
111 Value *Src = I.getOperand(0);
112 if (isa<PointerType>(Src->getType())) {
113 Type *SrcEleTy =
114 I.getOperand(0)->getType()->getPointerElementType();
115 Type *DstEleTy = I.getType()->getPointerElementType();
116 if (SrcEleTy->isVectorTy() || DstEleTy->isVectorTy()) {
117 // Handle case either operand is vector type like char4* -> int4*.
118 VectorWorkList.push_back(&I);
119 } else {
120 // Handle case all operands are scalar type like char* -> int*.
121 ScalarWorkList.push_back(&I);
122 }
123
124 Changed = true;
125 } else {
126 llvm_unreachable("Unsupported bitcast");
127 }
128 }
129 }
130 }
131 }
132
133 SmallVector<Instruction *, 16> ToBeDeleted;
134 for (Instruction *Inst : VectorWorkList) {
135 Value *Src = Inst->getOperand(0);
136 Type *SrcTy = Src->getType()->getPointerElementType();
137 Type *DstTy = Inst->getType()->getPointerElementType();
138 Type *SrcEleTy =
139 SrcTy->isVectorTy() ? SrcTy->getSequentialElementType() : SrcTy;
140 Type *DstEleTy =
141 DstTy->isVectorTy() ? DstTy->getSequentialElementType() : DstTy;
David Neto30ae05e2017-09-06 19:58:36 -0400142 // These are bit widths of the source and destination types, even
143 // if they are vector types. E.g. bit width of float4 is 64.
David Neto22f144c2017-06-12 14:26:21 -0400144 unsigned SrcTyBitWidth = SrcTy->getPrimitiveSizeInBits();
145 unsigned DstTyBitWidth = DstTy->getPrimitiveSizeInBits();
146 unsigned SrcEleTyBitWidth = SrcEleTy->getPrimitiveSizeInBits();
147 unsigned DstEleTyBitWidth = DstEleTy->getPrimitiveSizeInBits();
148 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
149
150 // Investigate pointer bitcast's users.
151 for (User *BitCastUser : Inst->users()) {
152 Value *BitCastSrc = Inst->getOperand(0);
153 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
154
155 // It consist of User* and bool whether user is gep or not.
156 SmallVector<std::pair<User *, bool>, 32> Users;
157
158 GetElementPtrInst *GEP = nullptr;
159 Value *OrgGEPIdx = nullptr;
160 if (GEP = dyn_cast<GetElementPtrInst>(BitCastUser)) {
161 OrgGEPIdx = GEP->getOperand(1);
162
163 // Build new src/dst address index.
164 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
165
166 // Record gep's users.
167 for (User *GEPUser : GEP->users()) {
168 Users.push_back(std::make_pair(GEPUser, true));
169 }
170 } else {
171 // Record bitcast's users.
172 Users.push_back(std::make_pair(BitCastUser, false));
173 }
174
175 // Handle users.
176 bool IsGEPUser = false;
177 for (auto UserIter : Users) {
178 User *U = UserIter.first;
179 IsGEPUser = UserIter.second;
180
181 IRBuilder<> Builder(cast<Instruction>(U));
182
183 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
184 if (SrcTyBitWidth < DstTyBitWidth) {
185 //
186 // Consider below case.
187 //
188 // Original IR (float2* --> float4*)
189 // 1. val = load (float4*) src_addr
190 // 2. dst_addr = bitcast float2*, float4*
191 // 3. dst_addr = gep (float4*) dst_addr, idx
192 // 4. store (float4*) dst_addr
193 //
194 // Transformed IR
195 // 1. val(float4) = load (float4*) src_addr
196 // 2. val1(float2) = shufflevector (float4)val, (float4)undef,
197 // (float2)<0, 1>
198 // 3. val2(float2) = shufflevector (float4)val, (float4)undef,
199 // (float2)<2, 3>
200 // 4. dst_addr1(float2*) = gep (float2*)dst_addr, idx * 2
201 // 5. dst_addr2(float2*) = gep (float2*)dst_addr, idx * 2 + 1
202 // 6. store (float2)val1, (float2*)dst_addr1
203 // 7. store (float2)val2, (float2*)dst_addr2
204 //
205
206 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
207 unsigned NumVector = 1;
208 // Vulkan SPIR-V does not support over 4 components for
209 // TypeVector.
210 if (NumElement > 4) {
211 NumVector = NumElement >> 2;
212 NumElement = 4;
213 }
214
215 // Create store values.
216 Type *TmpValTy = SrcTy;
217 if (DstTy->isVectorTy()) {
218 if (SrcEleTyBitWidth == DstEleTyBitWidth) {
219 TmpValTy =
220 VectorType::get(SrcEleTy, DstTy->getVectorNumElements());
221 } else {
222 TmpValTy = VectorType::get(SrcEleTy, NumElement);
223 }
224 }
225
226 Value *STVal = ST->getValueOperand();
227 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
228 Value *TmpSTVal = nullptr;
229 if (NumVector == 1) {
230 TmpSTVal = Builder.CreateBitCast(STVal, TmpValTy);
231 } else {
232 unsigned DstVecTyNumElement =
233 DstTy->getVectorNumElements() / NumVector;
234 SmallVector<uint32_t, 4> Idxs;
235 for (unsigned i = 0; i < DstVecTyNumElement; i++) {
236 Idxs.push_back(i + (DstVecTyNumElement * VIdx));
237 }
238 Value *UndefVal = UndefValue::get(DstTy);
239 TmpSTVal = Builder.CreateShuffleVector(STVal, UndefVal, Idxs);
240 TmpSTVal = Builder.CreateBitCast(TmpSTVal, TmpValTy);
241 }
242
243 SmallVector<Value *, 8> STValues;
244 if (!SrcTy->isVectorTy()) {
245 // Handle scalar type.
246 for (unsigned i = 0; i < NumElement; i++) {
247 Value *TmpVal = Builder.CreateExtractElement(
248 TmpSTVal, Builder.getInt32(i));
249 STValues.push_back(TmpVal);
250 }
251 } else {
252 // Handle vector type.
253 unsigned SrcNumElement = SrcTy->getVectorNumElements();
254 unsigned DstNumElement = DstTy->getVectorNumElements();
255 for (unsigned i = 0; i < NumElement; i++) {
256 SmallVector<uint32_t, 4> Idxs;
257 for (unsigned j = 0; j < SrcNumElement; j++) {
258 Idxs.push_back(i * SrcNumElement + j);
259 }
260
261 VectorType *TmpVecTy =
262 VectorType::get(SrcEleTy, DstNumElement);
263 Value *UndefVal = UndefValue::get(TmpVecTy);
264 Value *TmpVal =
265 Builder.CreateShuffleVector(TmpSTVal, UndefVal, Idxs);
266 STValues.push_back(TmpVal);
267 }
268 }
269
270 // Generate stores.
271 Value *SrcAddrIdx = NewAddrIdx;
272 Value *BaseAddr = BitCastSrc;
273 for (unsigned i = 0; i < NumElement; i++) {
274 // Calculate store address.
275 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
276 Builder.CreateStore(STValues[i], DstAddr);
277
278 if (i + 1 < NumElement) {
279 // Calculate next store address
280 SrcAddrIdx =
281 Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
282 }
283 }
284 }
285 } else if (SrcTyBitWidth > DstTyBitWidth) {
286 //
287 // Consider below case.
288 //
289 // Original IR (float4* --> float2*)
290 // 1. val = load (float2*) src_addr
291 // 2. dst_addr = bitcast float4*, float2*
292 // 3. dst_addr = gep (float2*) dst_addr, idx
293 // 4. store (float2) val, (float2*) dst_addr
294 //
David Neto30ae05e2017-09-06 19:58:36 -0400295 // Transformed IR: Decompose the source vector into elements, then write
296 // them one at a time.
David Neto22f144c2017-06-12 14:26:21 -0400297 // 1. val = load (float2*) src_addr
298 // 2. val1 = (float)extract_element val, 0
299 // 3. val2 = (float)extract_element val, 1
David Neto30ae05e2017-09-06 19:58:36 -0400300 // // Source component k maps to destination component k * idxscale
301 // 3a. idxscale = sizeof(float4)/sizeof(float2)
302 // 3b. idxbase = idx / idxscale
303 // 3c. newarrayidx = idxbase * idxscale
304 // 4. dst_addr1 = gep (float4*) dst, newarrayidx
305 // 5. dst_addr2 = gep (float4*) dst, newarrayidx + 1
David Neto22f144c2017-06-12 14:26:21 -0400306 // 6. store (float)val1, (float*) dst_addr1
307 // 7. store (float)val2, (float*) dst_addr2
308 //
309
310 if (SrcTyBitWidth <= DstEleTyBitWidth) {
311 SrcTy->print(errs());
312 DstTy->print(errs());
313 llvm_unreachable("Handle above src/dst type.");
314 }
315
316 // Create store values.
317 Value *STVal = ST->getValueOperand();
318
319 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
320 VectorType *TmpVecTy =
321 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
322 STVal = Builder.CreateBitCast(STVal, TmpVecTy);
323 }
324
325 SmallVector<Value *, 8> STValues;
David Neto30ae05e2017-09-06 19:58:36 -0400326 // How many destination writes are required?
David Neto22f144c2017-06-12 14:26:21 -0400327 unsigned DstNumElement = 1;
328 if (!DstTy->isVectorTy() || SrcEleTyBitWidth == DstTyBitWidth) {
329 // Handle scalar type.
330 STValues.push_back(STVal);
331 } else {
332 // Handle vector type.
333 DstNumElement = DstTy->getVectorNumElements();
334 for (unsigned i = 0; i < DstNumElement; i++) {
335 Value *Idx = Builder.getInt32(i);
336 Value *TmpVal = Builder.CreateExtractElement(STVal, Idx);
337 STValues.push_back(TmpVal);
338 }
339 }
340
341 // Generate stores.
342 Value *BaseAddr = BitCastSrc;
343 Value *SubEleIdx = Builder.getInt32(0);
344 if (IsGEPUser) {
David Neto30ae05e2017-09-06 19:58:36 -0400345 // Compute SubNumElement = idxscale
David Neto22f144c2017-06-12 14:26:21 -0400346 unsigned SubNumElement = SrcTy->getVectorNumElements();
347 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
David Neto30ae05e2017-09-06 19:58:36 -0400348 // Same condition under which DstNumElements > 1
David Neto22f144c2017-06-12 14:26:21 -0400349 SubNumElement = SrcTy->getVectorNumElements() /
350 DstTy->getVectorNumElements();
351 }
352
David Neto30ae05e2017-09-06 19:58:36 -0400353 // Compute SubEleIdx = idxbase * idxscale
David Neto22f144c2017-06-12 14:26:21 -0400354 SubEleIdx = Builder.CreateAnd(
355 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
David Neto30ae05e2017-09-06 19:58:36 -0400356 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
357 SubEleIdx = Builder.CreateShl(
358 SubEleIdx, Builder.getInt32(std::log2(SubNumElement)));
359 }
David Neto22f144c2017-06-12 14:26:21 -0400360 }
361
362 for (unsigned i = 0; i < DstNumElement; i++) {
363 // Calculate address.
364 if (i > 0) {
365 SubEleIdx = Builder.CreateAdd(SubEleIdx, Builder.getInt32(i));
366 }
367
368 Value *Idxs[] = {NewAddrIdx, SubEleIdx};
369 Value *DstAddr = Builder.CreateGEP(BaseAddr, Idxs);
370 Type *TmpSrcTy = SrcEleTy;
371 if (TmpSrcTy->isVectorTy()) {
372 TmpSrcTy = TmpSrcTy->getVectorElementType();
373 }
374 Value *TmpVal = Builder.CreateBitCast(STValues[i], TmpSrcTy);
375
376 Builder.CreateStore(TmpVal, DstAddr);
377 }
378 } else {
379 // if SrcTyBitWidth == DstTyBitWidth
380 Type *TmpSrcTy = SrcTy;
381 Value *DstAddr = Src;
382
383 if (IsGEPUser) {
384 SmallVector<Value *, 4> Idxs;
385 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
386 Idxs.push_back(GEP->getOperand(i));
387 }
388 DstAddr = Builder.CreateGEP(BitCastSrc, Idxs);
389
390 if (GEP->getNumOperands() > 2) {
391 TmpSrcTy = SrcEleTy;
392 }
393 }
394
395 Value *TmpVal =
396 Builder.CreateBitCast(ST->getValueOperand(), TmpSrcTy);
397 Builder.CreateStore(TmpVal, DstAddr);
398 }
399 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
400 Value *SrcAddrIdx = Builder.getInt32(0);
401 if (IsGEPUser) {
402 SrcAddrIdx = NewAddrIdx;
403 }
404
405 // Load value from src.
406 SmallVector<Value *, 8> LDValues;
407
408 for (unsigned i = 1; i <= NumIter; i++) {
409 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
410 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
411 LDValues.push_back(SrcVal);
412
413 if (i + 1 <= NumIter) {
414 // Calculate next SrcAddrIdx.
415 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
416 }
417 }
418
419 Value *DstVal = nullptr;
420 if (SrcTyBitWidth > DstTyBitWidth) {
421 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
422
423 if (SrcEleTyBitWidth == DstTyBitWidth) {
424 //
425 // Consider below case.
426 //
427 // Original IR (int4* --> char4*)
428 // 1. src_addr = bitcast int4*, char4*
429 // 2. element_addr = gep (char4*) src_addr, idx
430 // 3. load (char4*) element_addr
431 //
432 // Transformed IR
433 // 1. src_addr = gep (int4*) src, idx / 4
434 // 2. src_val(int4) = load (int4*) src_addr
435 // 3. tmp_val(int4) = extractelement src_val, idx % 4
436 // 4. dst_val(char4) = bitcast tmp_val, (char4)
437 //
438 Value *EleIdx = Builder.getInt32(0);
439 if (IsGEPUser) {
440 EleIdx = Builder.CreateAnd(OrgGEPIdx,
441 Builder.getInt32(NumElement - 1));
442 }
443 Value *TmpVal =
444 Builder.CreateExtractElement(LDValues[0], EleIdx, "tmp_val");
445 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
446 } else if (SrcEleTyBitWidth < DstTyBitWidth) {
447 if (IsGEPUser) {
448 //
449 // Consider below case.
450 //
451 // Original IR (float4* --> float2*)
452 // 1. src_addr = bitcast float4*, float2*
453 // 2. element_addr = gep (float2*) src_addr, idx
454 // 3. load (float2*) element_addr
455 //
456 // Transformed IR
457 // 1. src_addr = gep (float4*) src, idx / 2
458 // 2. src_val(float4) = load (float4*) src_addr
459 // 3. tmp_val1(float) = extractelement (idx % 2) * 2
460 // 4. tmp_val2(float) = extractelement (idx % 2) * 2 + 1
461 // 5. dst_val(float2) = insertelement undef(float2), tmp_val1, 0
462 // 6. dst_val(float2) = insertelement undef(float2), tmp_val2, 1
463 // 7. dst_val(float2) = bitcast dst_val, (float2)
464 // ==> if types are same between src and dst, it will be
465 // igonored
466 //
467 VectorType *TmpVecTy =
468 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
469 DstVal = UndefValue::get(TmpVecTy);
470 Value *EleIdx = Builder.CreateAnd(
471 OrgGEPIdx, Builder.getInt32(NumElement - 1));
472 EleIdx = Builder.CreateShl(
473 EleIdx, Builder.getInt32(
474 std::log2(DstTyBitWidth / SrcEleTyBitWidth)));
475 Value *TmpOrgGEPIdx = EleIdx;
476 for (unsigned i = 0; i < NumElement; i++) {
477 Value *TmpVal = Builder.CreateExtractElement(
478 LDValues[0], TmpOrgGEPIdx, "tmp_val");
479 DstVal = Builder.CreateInsertElement(DstVal, TmpVal,
480 Builder.getInt32(i));
481
482 if (i + 1 < NumElement) {
483 TmpOrgGEPIdx =
484 Builder.CreateAdd(TmpOrgGEPIdx, Builder.getInt32(1));
485 }
486 }
487 } else {
488 //
489 // Consider below case.
490 //
491 // Original IR (float4* --> int2*)
492 // 1. src_addr = bitcast float4*, int2*
493 // 2. load (int2*) src_addr
494 //
495 // Transformed IR
496 // 1. src_val(float4) = load (float4*) src_addr
497 // 2. tmp_val(float2) = shufflevector (float4)src_val,
498 // (float4)undef,
499 // (float2)<0, 1>
500 // 3. dst_val(int2) = bitcast (float2)tmp_val, (int2)
501 //
502 unsigned NumElement = DstTyBitWidth / SrcEleTyBitWidth;
503 Value *Undef = UndefValue::get(SrcTy);
504
505 SmallVector<uint32_t, 4> Idxs;
506 for (unsigned i = 0; i < NumElement; i++) {
507 Idxs.push_back(i);
508 }
509 DstVal = Builder.CreateShuffleVector(LDValues[0], Undef, Idxs);
510
511 DstVal = Builder.CreateBitCast(DstVal, DstTy);
512 }
513
514 DstVal = Builder.CreateBitCast(DstVal, DstTy);
515 } else {
516 if (IsGEPUser) {
517 //
518 // Consider below case.
519 //
520 // Original IR (int4* --> char2*)
521 // 1. src_addr = bitcast int4*, char2*
522 // 2. element_addr = gep (char2*) src_addr, idx
523 // 3. load (char2*) element_addr
524 //
525 // Transformed IR
526 // 1. src_addr = gep (int4*) src, idx / 8
527 // 2. src_val(int4) = load (int4*) src_addr
528 // 3. tmp_val(int) = extractelement idx / 2
529 // 4. tmp_val(<i16 x 2>) = bitcast tmp_val(int), (<i16 x 2>)
530 // 5. tmp_val(i16) = extractelement idx % 2
531 // 6. dst_val(char2) = bitcast tmp_val, (char2)
532 // ==> if types are same between src and dst, it will be
533 // igonored
534 //
535 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
536 unsigned SubNumElement = SrcEleTyBitWidth / DstTyBitWidth;
537 if (SubNumElement != 2 && SubNumElement != 4) {
538 llvm_unreachable("Unsupported SubNumElement");
539 }
540
541 Value *TmpOrgGEPIdx = Builder.CreateLShr(
542 OrgGEPIdx, Builder.getInt32(std::log2(SubNumElement)));
543 Value *TmpVal = Builder.CreateExtractElement(
544 LDValues[0], TmpOrgGEPIdx, "tmp_val");
545 TmpVal = Builder.CreateBitCast(
546 TmpVal,
547 VectorType::get(
548 IntegerType::get(DstTy->getContext(), DstTyBitWidth),
549 SubNumElement));
550 TmpOrgGEPIdx = Builder.CreateAnd(
551 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
552 TmpVal = Builder.CreateExtractElement(TmpVal, TmpOrgGEPIdx,
553 "tmp_val");
554 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
555 } else {
556 Inst->print(errs());
557 llvm_unreachable("Handle this bitcast");
558 }
559 }
560 } else if (SrcTyBitWidth < DstTyBitWidth) {
561 //
562 // Consider below case.
563 //
564 // Original IR (float2* --> float4*)
565 // 1. src_addr = bitcast float2*, float4*
566 // 2. element_addr = gep (float4*) src_addr, idx
567 // 3. load (float4*) element_addr
568 //
569 // Transformed IR
570 // 1. src_addr = gep (float2*) src, idx * 2
571 // 2. src_val1(float2) = load (float2*) src_addr
572 // 3. src_addr2 = gep (float2*) src_addr, 1
573 // 4. src_val2(float2) = load (float2*) src_addr2
574 // 5. dst_val(float4) = shufflevector src_val1, src_val2, <0, 1>
575 // 6. dst_val(float4) = bitcast dst_val, (float4)
576 // ==> if types are same between src and dst, it will be igonored
577 //
578 unsigned NumElement = 1;
579 if (SrcTy->isVectorTy()) {
580 NumElement = SrcTy->getVectorNumElements() * 2;
581 }
582
583 // Handle scalar type.
584 if (NumElement == 1) {
585 if (SrcTyBitWidth * 4 <= DstTyBitWidth) {
586 unsigned NumVecElement = DstTyBitWidth / SrcTyBitWidth;
587 unsigned NumVector = 1;
588 if (NumVecElement > 4) {
589 NumVector = NumVecElement >> 2;
590 NumVecElement = 4;
591 }
592
593 SmallVector<Value *, 4> Values;
594 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
595 // In this case, generate only insert element. It generates
596 // less instructions than using shuffle vector.
597 VectorType *TmpVecTy = VectorType::get(SrcTy, NumVecElement);
598 Value *TmpVal = UndefValue::get(TmpVecTy);
599 for (unsigned i = 0; i < NumVecElement; i++) {
600 TmpVal = Builder.CreateInsertElement(
601 TmpVal, LDValues[i + (VIdx * 4)], Builder.getInt32(i));
602 }
603 Values.push_back(TmpVal);
604 }
605
606 if (Values.size() > 2) {
607 Inst->print(errs());
608 llvm_unreachable("Support above bitcast");
609 }
610
611 if (Values.size() > 1) {
612 Type *TmpEleTy =
613 Type::getIntNTy(M.getContext(), SrcEleTyBitWidth * 2);
614 VectorType *TmpVecTy = VectorType::get(TmpEleTy, NumVector);
615 for (unsigned i = 0; i < Values.size(); i++) {
616 Values[i] = Builder.CreateBitCast(Values[i], TmpVecTy);
617 }
618 SmallVector<uint32_t, 4> Idxs;
619 for (unsigned i = 0; i < (NumVector * 2); i++) {
620 Idxs.push_back(i);
621 }
622 for (unsigned i = 0; i < Values.size(); i = i + 2) {
623 Values[i] = Builder.CreateShuffleVector(
624 Values[i], Values[i + 1], Idxs);
625 }
626 }
627
628 LDValues.clear();
629 LDValues.push_back(Values[0]);
630 } else {
631 SmallVector<Value *, 4> TmpLDValues;
632 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
633 VectorType *TmpVecTy = VectorType::get(SrcTy, 2);
634 Value *TmpVal = UndefValue::get(TmpVecTy);
635 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i],
636 Builder.getInt32(0));
637 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i + 1],
638 Builder.getInt32(1));
639 TmpLDValues.push_back(TmpVal);
640 }
641 LDValues.clear();
642 LDValues = std::move(TmpLDValues);
643 NumElement = 4;
644 }
645 }
646
647 // Handle vector type.
648 while (LDValues.size() != 1) {
649 SmallVector<Value *, 4> TmpLDValues;
650 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
651 SmallVector<uint32_t, 4> Idxs;
652 for (unsigned j = 0; j < NumElement; j++) {
653 Idxs.push_back(j);
654 }
655 Value *TmpVal = Builder.CreateShuffleVector(
656 LDValues[i], LDValues[i + 1], Idxs);
657 TmpLDValues.push_back(TmpVal);
658 }
659 LDValues.clear();
660 LDValues = std::move(TmpLDValues);
661 NumElement *= 2;
662 }
663
664 DstVal = Builder.CreateBitCast(LDValues[0], DstTy);
665 } else {
666 //
667 // Consider below case.
668 //
669 // Original IR (float4* --> int4*)
670 // 1. src_addr = bitcast float4*, int4*
671 // 2. element_addr = gep (int4*) src_addr, idx, 0
672 // 3. load (int) element_addr
673 //
674 // Transformed IR
675 // 1. element_addr = gep (float4*) src_addr, idx, 0
676 // 2. src_val = load (float*) element_addr
677 // 3. val = bitcast (float) src_val to (int)
678 //
679 Value *SrcAddr = Src;
680 if (IsGEPUser) {
681 SmallVector<Value *, 4> Idxs;
682 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
683 Idxs.push_back(GEP->getOperand(i));
684 }
685 SrcAddr = Builder.CreateGEP(Src, Idxs);
686 }
687 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
688
689 Type *TmpDstTy = DstTy;
690 if (IsGEPUser) {
691 if (GEP->getNumOperands() > 2) {
692 TmpDstTy = DstEleTy;
693 }
694 }
695 DstVal = Builder.CreateBitCast(SrcVal, TmpDstTy);
696 }
697
698 // Update LD's users with DstVal.
699 LD->replaceAllUsesWith(DstVal);
700 } else {
701 U->print(errs());
702 llvm_unreachable(
703 "Handle above user of gep on ReplacePointerBitcastPass");
704 }
705
706 ToBeDeleted.push_back(cast<Instruction>(U));
707 }
708
709 if (IsGEPUser) {
710 ToBeDeleted.push_back(GEP);
711 }
712 }
713
714 ToBeDeleted.push_back(Inst);
715 }
716
717 for (Instruction *Inst : ScalarWorkList) {
718 Value *Src = Inst->getOperand(0);
719 Type *SrcTy = Src->getType()->getPointerElementType();
720 Type *DstTy = Inst->getType()->getPointerElementType();
721 unsigned SrcTyBitWidth = SrcTy->getPrimitiveSizeInBits();
722 unsigned DstTyBitWidth = DstTy->getPrimitiveSizeInBits();
723
724 for (User *BitCastUser : Inst->users()) {
725 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
726 // It consist of User* and bool whether user is gep or not.
727 SmallVector<std::pair<User *, bool>, 32> Users;
728
729 GetElementPtrInst *GEP = nullptr;
730 Value *OrgGEPIdx = nullptr;
731 if (GEP = dyn_cast<GetElementPtrInst>(BitCastUser)) {
732 IRBuilder<> Builder(GEP);
733
734 // Build new src/dst address.
735 OrgGEPIdx = GEP->getOperand(1);
736 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
737
738 // If bitcast's user is gep, investigate gep's users too.
739 for (User *GEPUser : GEP->users()) {
740 Users.push_back(std::make_pair(GEPUser, true));
741 }
742 } else {
743 Users.push_back(std::make_pair(BitCastUser, false));
744 }
745
746 // Handle users.
747 bool IsGEPUser = false;
748 for (auto UserIter : Users) {
749 User *U = UserIter.first;
750 IsGEPUser = UserIter.second;
751
752 IRBuilder<> Builder(cast<Instruction>(U));
753
754 // Handle store instruction with gep.
755 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
756 if (SrcTyBitWidth == DstTyBitWidth) {
757 auto STVal = Builder.CreateBitCast(ST->getValueOperand(), SrcTy);
758 Value *DstAddr = Builder.CreateGEP(Src, NewAddrIdx);
759 Builder.CreateStore(STVal, DstAddr);
760 } else if (SrcTyBitWidth < DstTyBitWidth) {
761 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
762
763 // Create Mask.
764 Constant *Mask = nullptr;
765 if (NumElement == 1) {
766 Mask = Builder.getInt32(0xFF);
767 } else if (NumElement == 2) {
768 Mask = Builder.getInt32(0xFFFF);
769 } else {
770 llvm_unreachable("strange type on bitcast");
771 }
772
773 // Create store values.
774 Value *STVal = ST->getValueOperand();
775 SmallVector<Value *, 8> STValues;
776 for (unsigned i = 0; i < NumElement; i++) {
777 Type *TmpTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
778 Value *TmpVal = Builder.CreateBitCast(STVal, TmpTy);
779 TmpVal = Builder.CreateLShr(TmpVal, Builder.getInt32(i));
780 TmpVal = Builder.CreateAnd(TmpVal, Mask);
781 TmpVal = Builder.CreateTrunc(TmpVal, SrcTy);
782 STValues.push_back(TmpVal);
783 }
784
785 // Generate stores.
786 Value *SrcAddrIdx = NewAddrIdx;
787 Value *BaseAddr = Src;
788 for (unsigned i = 0; i < NumElement; i++) {
789 // Calculate store address.
790 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
791 Builder.CreateStore(STValues[i], DstAddr);
792
793 if (i + 1 < NumElement) {
794 // Calculate next store address
795 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
796 }
797 }
798
799 } else {
800 Inst->print(errs());
801 llvm_unreachable("Handle different size store with scalar "
802 "bitcast on ReplacePointerBitcastPass");
803 }
804 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
805 if (SrcTyBitWidth == DstTyBitWidth) {
806 Value *SrcAddr = Builder.CreateGEP(Src, NewAddrIdx);
807 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
808 LD->replaceAllUsesWith(Builder.CreateBitCast(SrcVal, DstTy));
809 } else if (SrcTyBitWidth < DstTyBitWidth) {
810 Value *SrcAddrIdx = NewAddrIdx;
811
812 // Load value from src.
813 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
814 SmallVector<Value *, 8> LDValues;
815 for (unsigned i = 1; i <= NumIter; i++) {
816 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
817 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
818 LDValues.push_back(SrcVal);
819
820 if (i + 1 <= NumIter) {
821 // Calculate next SrcAddrIdx.
822 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
823 }
824 }
825
826 // Merge Load.
827 Type *TmpSrcTy = Type::getIntNTy(M.getContext(), SrcTyBitWidth);
828 Value *DstVal = Builder.CreateBitCast(LDValues[0], TmpSrcTy);
829 Type *TmpDstTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
830 DstVal = Builder.CreateZExt(DstVal, TmpDstTy);
831 for (unsigned i = 1; i < LDValues.size(); i++) {
832 Value *TmpVal = Builder.CreateBitCast(LDValues[i], TmpSrcTy);
833 TmpVal = Builder.CreateZExt(TmpVal, TmpDstTy);
834 TmpVal = Builder.CreateShl(TmpVal,
835 Builder.getInt32(i * SrcTyBitWidth));
836 DstVal = Builder.CreateOr(DstVal, TmpVal);
837 }
838
839 DstVal = Builder.CreateBitCast(DstVal, DstTy);
840 LD->replaceAllUsesWith(DstVal);
841
842 } else {
843 Inst->print(errs());
844 llvm_unreachable("Handle different size load with scalar "
845 "bitcast on ReplacePointerBitcastPass");
846 }
847 } else {
848
849 errs() << "About to crash at 820" << M << "Aboot\n\n";
850 Inst->print(errs());
851 llvm_unreachable("Handle above user of scalar bitcast with gep on "
852 "ReplacePointerBitcastPass");
853 }
854
855 ToBeDeleted.push_back(cast<Instruction>(U));
856 }
857
858 if (IsGEPUser) {
859 ToBeDeleted.push_back(GEP);
860 }
861 }
862
863 ToBeDeleted.push_back(Inst);
864 }
865
866 for (Instruction *Inst : ToBeDeleted) {
867 Inst->eraseFromParent();
868 }
869
870 return Changed;
871}