// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/MathExtras.h>
#include <llvm/Support/raw_ostream.h>
20
21using namespace llvm;
22
23#define DEBUG_TYPE "replacepointerbitcast"
24
25namespace {
26struct ReplacePointerBitcastPass : public ModulePass {
27 static char ID;
28 ReplacePointerBitcastPass() : ModulePass(ID) {}
29
30 unsigned CalculateNumIter(unsigned SrcTyBitWidth, unsigned DstTyBitWidth);
31 Value *CalculateNewGEPIdx(unsigned SrcTyBitWidth, unsigned DstTyBitWidth,
32 GetElementPtrInst *GEP);
33
34 bool runOnModule(Module &M) override;
35};
36}
37
38char ReplacePointerBitcastPass::ID = 0;
39static RegisterPass<ReplacePointerBitcastPass>
40 X("ReplacePointerBitcast", "Replace Pointer Bitcast Pass");
41
42namespace clspv {
43ModulePass *createReplacePointerBitcastPass() {
44 return new ReplacePointerBitcastPass();
45}
46}
47
48unsigned ReplacePointerBitcastPass::CalculateNumIter(unsigned SrcTyBitWidth,
49 unsigned DstTyBitWidth) {
50 unsigned NumIter = 0;
51 if (SrcTyBitWidth > DstTyBitWidth) {
52 if (SrcTyBitWidth % DstTyBitWidth) {
53 llvm_unreachable(
54 "Src type bitwidth should be multiple of Dest type bitwidth");
55 }
56 NumIter = 1;
57 } else if (SrcTyBitWidth < DstTyBitWidth) {
58 if (DstTyBitWidth % SrcTyBitWidth) {
59 llvm_unreachable(
60 "Dest type bitwidth should be multiple of Src type bitwidth");
61 }
62 NumIter = DstTyBitWidth / SrcTyBitWidth;
63 } else {
64 NumIter = 0;
65 }
66
67 return NumIter;
68}
69
70Value *ReplacePointerBitcastPass::CalculateNewGEPIdx(
71 unsigned SrcTyBitWidth, unsigned DstTyBitWidth, GetElementPtrInst *GEP) {
72 Value *NewGEPIdx = GEP->getOperand(1);
73 IRBuilder<> Builder(GEP);
74
75 if (SrcTyBitWidth > DstTyBitWidth) {
76 if (GEP->getNumOperands() > 2) {
77 GEP->print(errs());
78 llvm_unreachable("Support above GEP on PointerBitcastPass");
79 }
80
81 NewGEPIdx = Builder.CreateLShr(
82 NewGEPIdx,
83 Builder.getInt32(std::log2(SrcTyBitWidth / DstTyBitWidth)));
84 } else if (DstTyBitWidth > SrcTyBitWidth) {
85 if (GEP->getNumOperands() > 2) {
86 GEP->print(errs());
87 llvm_unreachable("Support above GEP on PointerBitcastPass");
88 }
89
90 NewGEPIdx = Builder.CreateShl(
91 NewGEPIdx,
92 Builder.getInt32(std::log2(DstTyBitWidth / SrcTyBitWidth)));
93 }
94
95 return NewGEPIdx;
96}
97
98bool ReplacePointerBitcastPass::runOnModule(Module &M) {
99 bool Changed = false;
100
101 SmallVector<Instruction *, 16> VectorWorkList;
102 SmallVector<Instruction *, 16> ScalarWorkList;
103 for (Function &F : M) {
104 for (BasicBlock &BB : F) {
105 for (Instruction &I : BB) {
106 // Find pointer bitcast instruction.
107 if (isa<BitCastInst>(&I) && isa<PointerType>(I.getType())) {
108 Value *Src = I.getOperand(0);
109 if (isa<PointerType>(Src->getType())) {
110 Type *SrcEleTy =
111 I.getOperand(0)->getType()->getPointerElementType();
112 Type *DstEleTy = I.getType()->getPointerElementType();
113 if (SrcEleTy->isVectorTy() || DstEleTy->isVectorTy()) {
114 // Handle case either operand is vector type like char4* -> int4*.
115 VectorWorkList.push_back(&I);
116 } else {
117 // Handle case all operands are scalar type like char* -> int*.
118 ScalarWorkList.push_back(&I);
119 }
120
121 Changed = true;
122 } else {
123 llvm_unreachable("Unsupported bitcast");
124 }
125 }
126 }
127 }
128 }
129
130 SmallVector<Instruction *, 16> ToBeDeleted;
131 for (Instruction *Inst : VectorWorkList) {
132 Value *Src = Inst->getOperand(0);
133 Type *SrcTy = Src->getType()->getPointerElementType();
134 Type *DstTy = Inst->getType()->getPointerElementType();
135 Type *SrcEleTy =
136 SrcTy->isVectorTy() ? SrcTy->getSequentialElementType() : SrcTy;
137 Type *DstEleTy =
138 DstTy->isVectorTy() ? DstTy->getSequentialElementType() : DstTy;
139 unsigned SrcTyBitWidth = SrcTy->getPrimitiveSizeInBits();
140 unsigned DstTyBitWidth = DstTy->getPrimitiveSizeInBits();
141 unsigned SrcEleTyBitWidth = SrcEleTy->getPrimitiveSizeInBits();
142 unsigned DstEleTyBitWidth = DstEleTy->getPrimitiveSizeInBits();
143 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
144
145 // Investigate pointer bitcast's users.
146 for (User *BitCastUser : Inst->users()) {
147 Value *BitCastSrc = Inst->getOperand(0);
148 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
149
150 // It consist of User* and bool whether user is gep or not.
151 SmallVector<std::pair<User *, bool>, 32> Users;
152
153 GetElementPtrInst *GEP = nullptr;
154 Value *OrgGEPIdx = nullptr;
155 if (GEP = dyn_cast<GetElementPtrInst>(BitCastUser)) {
156 OrgGEPIdx = GEP->getOperand(1);
157
158 // Build new src/dst address index.
159 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
160
161 // Record gep's users.
162 for (User *GEPUser : GEP->users()) {
163 Users.push_back(std::make_pair(GEPUser, true));
164 }
165 } else {
166 // Record bitcast's users.
167 Users.push_back(std::make_pair(BitCastUser, false));
168 }
169
170 // Handle users.
171 bool IsGEPUser = false;
172 for (auto UserIter : Users) {
173 User *U = UserIter.first;
174 IsGEPUser = UserIter.second;
175
176 IRBuilder<> Builder(cast<Instruction>(U));
177
178 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
179 if (SrcTyBitWidth < DstTyBitWidth) {
180 //
181 // Consider below case.
182 //
183 // Original IR (float2* --> float4*)
184 // 1. val = load (float4*) src_addr
185 // 2. dst_addr = bitcast float2*, float4*
186 // 3. dst_addr = gep (float4*) dst_addr, idx
187 // 4. store (float4*) dst_addr
188 //
189 // Transformed IR
190 // 1. val(float4) = load (float4*) src_addr
191 // 2. val1(float2) = shufflevector (float4)val, (float4)undef,
192 // (float2)<0, 1>
193 // 3. val2(float2) = shufflevector (float4)val, (float4)undef,
194 // (float2)<2, 3>
195 // 4. dst_addr1(float2*) = gep (float2*)dst_addr, idx * 2
196 // 5. dst_addr2(float2*) = gep (float2*)dst_addr, idx * 2 + 1
197 // 6. store (float2)val1, (float2*)dst_addr1
198 // 7. store (float2)val2, (float2*)dst_addr2
199 //
200
201 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
202 unsigned NumVector = 1;
203 // Vulkan SPIR-V does not support over 4 components for
204 // TypeVector.
205 if (NumElement > 4) {
206 NumVector = NumElement >> 2;
207 NumElement = 4;
208 }
209
210 // Create store values.
211 Type *TmpValTy = SrcTy;
212 if (DstTy->isVectorTy()) {
213 if (SrcEleTyBitWidth == DstEleTyBitWidth) {
214 TmpValTy =
215 VectorType::get(SrcEleTy, DstTy->getVectorNumElements());
216 } else {
217 TmpValTy = VectorType::get(SrcEleTy, NumElement);
218 }
219 }
220
221 Value *STVal = ST->getValueOperand();
222 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
223 Value *TmpSTVal = nullptr;
224 if (NumVector == 1) {
225 TmpSTVal = Builder.CreateBitCast(STVal, TmpValTy);
226 } else {
227 unsigned DstVecTyNumElement =
228 DstTy->getVectorNumElements() / NumVector;
229 SmallVector<uint32_t, 4> Idxs;
230 for (unsigned i = 0; i < DstVecTyNumElement; i++) {
231 Idxs.push_back(i + (DstVecTyNumElement * VIdx));
232 }
233 Value *UndefVal = UndefValue::get(DstTy);
234 TmpSTVal = Builder.CreateShuffleVector(STVal, UndefVal, Idxs);
235 TmpSTVal = Builder.CreateBitCast(TmpSTVal, TmpValTy);
236 }
237
238 SmallVector<Value *, 8> STValues;
239 if (!SrcTy->isVectorTy()) {
240 // Handle scalar type.
241 for (unsigned i = 0; i < NumElement; i++) {
242 Value *TmpVal = Builder.CreateExtractElement(
243 TmpSTVal, Builder.getInt32(i));
244 STValues.push_back(TmpVal);
245 }
246 } else {
247 // Handle vector type.
248 unsigned SrcNumElement = SrcTy->getVectorNumElements();
249 unsigned DstNumElement = DstTy->getVectorNumElements();
250 for (unsigned i = 0; i < NumElement; i++) {
251 SmallVector<uint32_t, 4> Idxs;
252 for (unsigned j = 0; j < SrcNumElement; j++) {
253 Idxs.push_back(i * SrcNumElement + j);
254 }
255
256 VectorType *TmpVecTy =
257 VectorType::get(SrcEleTy, DstNumElement);
258 Value *UndefVal = UndefValue::get(TmpVecTy);
259 Value *TmpVal =
260 Builder.CreateShuffleVector(TmpSTVal, UndefVal, Idxs);
261 STValues.push_back(TmpVal);
262 }
263 }
264
265 // Generate stores.
266 Value *SrcAddrIdx = NewAddrIdx;
267 Value *BaseAddr = BitCastSrc;
268 for (unsigned i = 0; i < NumElement; i++) {
269 // Calculate store address.
270 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
271 Builder.CreateStore(STValues[i], DstAddr);
272
273 if (i + 1 < NumElement) {
274 // Calculate next store address
275 SrcAddrIdx =
276 Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
277 }
278 }
279 }
280 } else if (SrcTyBitWidth > DstTyBitWidth) {
281 //
282 // Consider below case.
283 //
284 // Original IR (float4* --> float2*)
285 // 1. val = load (float2*) src_addr
286 // 2. dst_addr = bitcast float4*, float2*
287 // 3. dst_addr = gep (float2*) dst_addr, idx
288 // 4. store (float2) val, (float2*) dst_addr
289 //
290 // Transformed IR
291 // 1. val = load (float2*) src_addr
292 // 2. val1 = (float)extract_element val, 0
293 // 3. val2 = (float)extract_element val, 1
294 // 4. dst_addr1 = gep (float4*) dst, idx / 2
295 // 5. dst_addr2 = gep (float4*) dst, idx / 2 + 1
296 // 6. store (float)val1, (float*) dst_addr1
297 // 7. store (float)val2, (float*) dst_addr2
298 //
299
300 if (SrcTyBitWidth <= DstEleTyBitWidth) {
301 SrcTy->print(errs());
302 DstTy->print(errs());
303 llvm_unreachable("Handle above src/dst type.");
304 }
305
306 // Create store values.
307 Value *STVal = ST->getValueOperand();
308
309 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
310 VectorType *TmpVecTy =
311 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
312 STVal = Builder.CreateBitCast(STVal, TmpVecTy);
313 }
314
315 SmallVector<Value *, 8> STValues;
316 unsigned DstNumElement = 1;
317 if (!DstTy->isVectorTy() || SrcEleTyBitWidth == DstTyBitWidth) {
318 // Handle scalar type.
319 STValues.push_back(STVal);
320 } else {
321 // Handle vector type.
322 DstNumElement = DstTy->getVectorNumElements();
323 for (unsigned i = 0; i < DstNumElement; i++) {
324 Value *Idx = Builder.getInt32(i);
325 Value *TmpVal = Builder.CreateExtractElement(STVal, Idx);
326 STValues.push_back(TmpVal);
327 }
328 }
329
330 // Generate stores.
331 Value *BaseAddr = BitCastSrc;
332 Value *SubEleIdx = Builder.getInt32(0);
333 if (IsGEPUser) {
334 unsigned SubNumElement = SrcTy->getVectorNumElements();
335 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
336 SubNumElement = SrcTy->getVectorNumElements() /
337 DstTy->getVectorNumElements();
338 }
339
340 SubEleIdx = Builder.CreateAnd(
341 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
342 }
343
344 for (unsigned i = 0; i < DstNumElement; i++) {
345 // Calculate address.
346 if (i > 0) {
347 SubEleIdx = Builder.CreateAdd(SubEleIdx, Builder.getInt32(i));
348 }
349
350 Value *Idxs[] = {NewAddrIdx, SubEleIdx};
351 Value *DstAddr = Builder.CreateGEP(BaseAddr, Idxs);
352 Type *TmpSrcTy = SrcEleTy;
353 if (TmpSrcTy->isVectorTy()) {
354 TmpSrcTy = TmpSrcTy->getVectorElementType();
355 }
356 Value *TmpVal = Builder.CreateBitCast(STValues[i], TmpSrcTy);
357
358 Builder.CreateStore(TmpVal, DstAddr);
359 }
360 } else {
361 // if SrcTyBitWidth == DstTyBitWidth
362 Type *TmpSrcTy = SrcTy;
363 Value *DstAddr = Src;
364
365 if (IsGEPUser) {
366 SmallVector<Value *, 4> Idxs;
367 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
368 Idxs.push_back(GEP->getOperand(i));
369 }
370 DstAddr = Builder.CreateGEP(BitCastSrc, Idxs);
371
372 if (GEP->getNumOperands() > 2) {
373 TmpSrcTy = SrcEleTy;
374 }
375 }
376
377 Value *TmpVal =
378 Builder.CreateBitCast(ST->getValueOperand(), TmpSrcTy);
379 Builder.CreateStore(TmpVal, DstAddr);
380 }
381 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
382 Value *SrcAddrIdx = Builder.getInt32(0);
383 if (IsGEPUser) {
384 SrcAddrIdx = NewAddrIdx;
385 }
386
387 // Load value from src.
388 SmallVector<Value *, 8> LDValues;
389
390 for (unsigned i = 1; i <= NumIter; i++) {
391 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
392 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
393 LDValues.push_back(SrcVal);
394
395 if (i + 1 <= NumIter) {
396 // Calculate next SrcAddrIdx.
397 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
398 }
399 }
400
401 Value *DstVal = nullptr;
402 if (SrcTyBitWidth > DstTyBitWidth) {
403 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
404
405 if (SrcEleTyBitWidth == DstTyBitWidth) {
406 //
407 // Consider below case.
408 //
409 // Original IR (int4* --> char4*)
410 // 1. src_addr = bitcast int4*, char4*
411 // 2. element_addr = gep (char4*) src_addr, idx
412 // 3. load (char4*) element_addr
413 //
414 // Transformed IR
415 // 1. src_addr = gep (int4*) src, idx / 4
416 // 2. src_val(int4) = load (int4*) src_addr
417 // 3. tmp_val(int4) = extractelement src_val, idx % 4
418 // 4. dst_val(char4) = bitcast tmp_val, (char4)
419 //
420 Value *EleIdx = Builder.getInt32(0);
421 if (IsGEPUser) {
422 EleIdx = Builder.CreateAnd(OrgGEPIdx,
423 Builder.getInt32(NumElement - 1));
424 }
425 Value *TmpVal =
426 Builder.CreateExtractElement(LDValues[0], EleIdx, "tmp_val");
427 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
428 } else if (SrcEleTyBitWidth < DstTyBitWidth) {
429 if (IsGEPUser) {
430 //
431 // Consider below case.
432 //
433 // Original IR (float4* --> float2*)
434 // 1. src_addr = bitcast float4*, float2*
435 // 2. element_addr = gep (float2*) src_addr, idx
436 // 3. load (float2*) element_addr
437 //
438 // Transformed IR
439 // 1. src_addr = gep (float4*) src, idx / 2
440 // 2. src_val(float4) = load (float4*) src_addr
441 // 3. tmp_val1(float) = extractelement (idx % 2) * 2
442 // 4. tmp_val2(float) = extractelement (idx % 2) * 2 + 1
443 // 5. dst_val(float2) = insertelement undef(float2), tmp_val1, 0
444 // 6. dst_val(float2) = insertelement undef(float2), tmp_val2, 1
445 // 7. dst_val(float2) = bitcast dst_val, (float2)
446 // ==> if types are same between src and dst, it will be
447 // igonored
448 //
449 VectorType *TmpVecTy =
450 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
451 DstVal = UndefValue::get(TmpVecTy);
452 Value *EleIdx = Builder.CreateAnd(
453 OrgGEPIdx, Builder.getInt32(NumElement - 1));
454 EleIdx = Builder.CreateShl(
455 EleIdx, Builder.getInt32(
456 std::log2(DstTyBitWidth / SrcEleTyBitWidth)));
457 Value *TmpOrgGEPIdx = EleIdx;
458 for (unsigned i = 0; i < NumElement; i++) {
459 Value *TmpVal = Builder.CreateExtractElement(
460 LDValues[0], TmpOrgGEPIdx, "tmp_val");
461 DstVal = Builder.CreateInsertElement(DstVal, TmpVal,
462 Builder.getInt32(i));
463
464 if (i + 1 < NumElement) {
465 TmpOrgGEPIdx =
466 Builder.CreateAdd(TmpOrgGEPIdx, Builder.getInt32(1));
467 }
468 }
469 } else {
470 //
471 // Consider below case.
472 //
473 // Original IR (float4* --> int2*)
474 // 1. src_addr = bitcast float4*, int2*
475 // 2. load (int2*) src_addr
476 //
477 // Transformed IR
478 // 1. src_val(float4) = load (float4*) src_addr
479 // 2. tmp_val(float2) = shufflevector (float4)src_val,
480 // (float4)undef,
481 // (float2)<0, 1>
482 // 3. dst_val(int2) = bitcast (float2)tmp_val, (int2)
483 //
484 unsigned NumElement = DstTyBitWidth / SrcEleTyBitWidth;
485 Value *Undef = UndefValue::get(SrcTy);
486
487 SmallVector<uint32_t, 4> Idxs;
488 for (unsigned i = 0; i < NumElement; i++) {
489 Idxs.push_back(i);
490 }
491 DstVal = Builder.CreateShuffleVector(LDValues[0], Undef, Idxs);
492
493 DstVal = Builder.CreateBitCast(DstVal, DstTy);
494 }
495
496 DstVal = Builder.CreateBitCast(DstVal, DstTy);
497 } else {
498 if (IsGEPUser) {
499 //
500 // Consider below case.
501 //
502 // Original IR (int4* --> char2*)
503 // 1. src_addr = bitcast int4*, char2*
504 // 2. element_addr = gep (char2*) src_addr, idx
505 // 3. load (char2*) element_addr
506 //
507 // Transformed IR
508 // 1. src_addr = gep (int4*) src, idx / 8
509 // 2. src_val(int4) = load (int4*) src_addr
510 // 3. tmp_val(int) = extractelement idx / 2
511 // 4. tmp_val(<i16 x 2>) = bitcast tmp_val(int), (<i16 x 2>)
512 // 5. tmp_val(i16) = extractelement idx % 2
513 // 6. dst_val(char2) = bitcast tmp_val, (char2)
514 // ==> if types are same between src and dst, it will be
515 // igonored
516 //
517 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
518 unsigned SubNumElement = SrcEleTyBitWidth / DstTyBitWidth;
519 if (SubNumElement != 2 && SubNumElement != 4) {
520 llvm_unreachable("Unsupported SubNumElement");
521 }
522
523 Value *TmpOrgGEPIdx = Builder.CreateLShr(
524 OrgGEPIdx, Builder.getInt32(std::log2(SubNumElement)));
525 Value *TmpVal = Builder.CreateExtractElement(
526 LDValues[0], TmpOrgGEPIdx, "tmp_val");
527 TmpVal = Builder.CreateBitCast(
528 TmpVal,
529 VectorType::get(
530 IntegerType::get(DstTy->getContext(), DstTyBitWidth),
531 SubNumElement));
532 TmpOrgGEPIdx = Builder.CreateAnd(
533 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
534 TmpVal = Builder.CreateExtractElement(TmpVal, TmpOrgGEPIdx,
535 "tmp_val");
536 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
537 } else {
538 Inst->print(errs());
539 llvm_unreachable("Handle this bitcast");
540 }
541 }
542 } else if (SrcTyBitWidth < DstTyBitWidth) {
543 //
544 // Consider below case.
545 //
546 // Original IR (float2* --> float4*)
547 // 1. src_addr = bitcast float2*, float4*
548 // 2. element_addr = gep (float4*) src_addr, idx
549 // 3. load (float4*) element_addr
550 //
551 // Transformed IR
552 // 1. src_addr = gep (float2*) src, idx * 2
553 // 2. src_val1(float2) = load (float2*) src_addr
554 // 3. src_addr2 = gep (float2*) src_addr, 1
555 // 4. src_val2(float2) = load (float2*) src_addr2
556 // 5. dst_val(float4) = shufflevector src_val1, src_val2, <0, 1>
557 // 6. dst_val(float4) = bitcast dst_val, (float4)
558 // ==> if types are same between src and dst, it will be igonored
559 //
560 unsigned NumElement = 1;
561 if (SrcTy->isVectorTy()) {
562 NumElement = SrcTy->getVectorNumElements() * 2;
563 }
564
565 // Handle scalar type.
566 if (NumElement == 1) {
567 if (SrcTyBitWidth * 4 <= DstTyBitWidth) {
568 unsigned NumVecElement = DstTyBitWidth / SrcTyBitWidth;
569 unsigned NumVector = 1;
570 if (NumVecElement > 4) {
571 NumVector = NumVecElement >> 2;
572 NumVecElement = 4;
573 }
574
575 SmallVector<Value *, 4> Values;
576 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
577 // In this case, generate only insert element. It generates
578 // less instructions than using shuffle vector.
579 VectorType *TmpVecTy = VectorType::get(SrcTy, NumVecElement);
580 Value *TmpVal = UndefValue::get(TmpVecTy);
581 for (unsigned i = 0; i < NumVecElement; i++) {
582 TmpVal = Builder.CreateInsertElement(
583 TmpVal, LDValues[i + (VIdx * 4)], Builder.getInt32(i));
584 }
585 Values.push_back(TmpVal);
586 }
587
588 if (Values.size() > 2) {
589 Inst->print(errs());
590 llvm_unreachable("Support above bitcast");
591 }
592
593 if (Values.size() > 1) {
594 Type *TmpEleTy =
595 Type::getIntNTy(M.getContext(), SrcEleTyBitWidth * 2);
596 VectorType *TmpVecTy = VectorType::get(TmpEleTy, NumVector);
597 for (unsigned i = 0; i < Values.size(); i++) {
598 Values[i] = Builder.CreateBitCast(Values[i], TmpVecTy);
599 }
600 SmallVector<uint32_t, 4> Idxs;
601 for (unsigned i = 0; i < (NumVector * 2); i++) {
602 Idxs.push_back(i);
603 }
604 for (unsigned i = 0; i < Values.size(); i = i + 2) {
605 Values[i] = Builder.CreateShuffleVector(
606 Values[i], Values[i + 1], Idxs);
607 }
608 }
609
610 LDValues.clear();
611 LDValues.push_back(Values[0]);
612 } else {
613 SmallVector<Value *, 4> TmpLDValues;
614 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
615 VectorType *TmpVecTy = VectorType::get(SrcTy, 2);
616 Value *TmpVal = UndefValue::get(TmpVecTy);
617 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i],
618 Builder.getInt32(0));
619 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i + 1],
620 Builder.getInt32(1));
621 TmpLDValues.push_back(TmpVal);
622 }
623 LDValues.clear();
624 LDValues = std::move(TmpLDValues);
625 NumElement = 4;
626 }
627 }
628
629 // Handle vector type.
630 while (LDValues.size() != 1) {
631 SmallVector<Value *, 4> TmpLDValues;
632 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
633 SmallVector<uint32_t, 4> Idxs;
634 for (unsigned j = 0; j < NumElement; j++) {
635 Idxs.push_back(j);
636 }
637 Value *TmpVal = Builder.CreateShuffleVector(
638 LDValues[i], LDValues[i + 1], Idxs);
639 TmpLDValues.push_back(TmpVal);
640 }
641 LDValues.clear();
642 LDValues = std::move(TmpLDValues);
643 NumElement *= 2;
644 }
645
646 DstVal = Builder.CreateBitCast(LDValues[0], DstTy);
647 } else {
648 //
649 // Consider below case.
650 //
651 // Original IR (float4* --> int4*)
652 // 1. src_addr = bitcast float4*, int4*
653 // 2. element_addr = gep (int4*) src_addr, idx, 0
654 // 3. load (int) element_addr
655 //
656 // Transformed IR
657 // 1. element_addr = gep (float4*) src_addr, idx, 0
658 // 2. src_val = load (float*) element_addr
659 // 3. val = bitcast (float) src_val to (int)
660 //
661 Value *SrcAddr = Src;
662 if (IsGEPUser) {
663 SmallVector<Value *, 4> Idxs;
664 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
665 Idxs.push_back(GEP->getOperand(i));
666 }
667 SrcAddr = Builder.CreateGEP(Src, Idxs);
668 }
669 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
670
671 Type *TmpDstTy = DstTy;
672 if (IsGEPUser) {
673 if (GEP->getNumOperands() > 2) {
674 TmpDstTy = DstEleTy;
675 }
676 }
677 DstVal = Builder.CreateBitCast(SrcVal, TmpDstTy);
678 }
679
680 // Update LD's users with DstVal.
681 LD->replaceAllUsesWith(DstVal);
682 } else {
683 U->print(errs());
684 llvm_unreachable(
685 "Handle above user of gep on ReplacePointerBitcastPass");
686 }
687
688 ToBeDeleted.push_back(cast<Instruction>(U));
689 }
690
691 if (IsGEPUser) {
692 ToBeDeleted.push_back(GEP);
693 }
694 }
695
696 ToBeDeleted.push_back(Inst);
697 }
698
699 for (Instruction *Inst : ScalarWorkList) {
700 Value *Src = Inst->getOperand(0);
701 Type *SrcTy = Src->getType()->getPointerElementType();
702 Type *DstTy = Inst->getType()->getPointerElementType();
703 unsigned SrcTyBitWidth = SrcTy->getPrimitiveSizeInBits();
704 unsigned DstTyBitWidth = DstTy->getPrimitiveSizeInBits();
705
706 for (User *BitCastUser : Inst->users()) {
707 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
708 // It consist of User* and bool whether user is gep or not.
709 SmallVector<std::pair<User *, bool>, 32> Users;
710
711 GetElementPtrInst *GEP = nullptr;
712 Value *OrgGEPIdx = nullptr;
713 if (GEP = dyn_cast<GetElementPtrInst>(BitCastUser)) {
714 IRBuilder<> Builder(GEP);
715
716 // Build new src/dst address.
717 OrgGEPIdx = GEP->getOperand(1);
718 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
719
720 // If bitcast's user is gep, investigate gep's users too.
721 for (User *GEPUser : GEP->users()) {
722 Users.push_back(std::make_pair(GEPUser, true));
723 }
724 } else {
725 Users.push_back(std::make_pair(BitCastUser, false));
726 }
727
728 // Handle users.
729 bool IsGEPUser = false;
730 for (auto UserIter : Users) {
731 User *U = UserIter.first;
732 IsGEPUser = UserIter.second;
733
734 IRBuilder<> Builder(cast<Instruction>(U));
735
736 // Handle store instruction with gep.
737 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
738 if (SrcTyBitWidth == DstTyBitWidth) {
739 auto STVal = Builder.CreateBitCast(ST->getValueOperand(), SrcTy);
740 Value *DstAddr = Builder.CreateGEP(Src, NewAddrIdx);
741 Builder.CreateStore(STVal, DstAddr);
742 } else if (SrcTyBitWidth < DstTyBitWidth) {
743 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
744
745 // Create Mask.
746 Constant *Mask = nullptr;
747 if (NumElement == 1) {
748 Mask = Builder.getInt32(0xFF);
749 } else if (NumElement == 2) {
750 Mask = Builder.getInt32(0xFFFF);
751 } else {
752 llvm_unreachable("strange type on bitcast");
753 }
754
755 // Create store values.
756 Value *STVal = ST->getValueOperand();
757 SmallVector<Value *, 8> STValues;
758 for (unsigned i = 0; i < NumElement; i++) {
759 Type *TmpTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
760 Value *TmpVal = Builder.CreateBitCast(STVal, TmpTy);
761 TmpVal = Builder.CreateLShr(TmpVal, Builder.getInt32(i));
762 TmpVal = Builder.CreateAnd(TmpVal, Mask);
763 TmpVal = Builder.CreateTrunc(TmpVal, SrcTy);
764 STValues.push_back(TmpVal);
765 }
766
767 // Generate stores.
768 Value *SrcAddrIdx = NewAddrIdx;
769 Value *BaseAddr = Src;
770 for (unsigned i = 0; i < NumElement; i++) {
771 // Calculate store address.
772 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
773 Builder.CreateStore(STValues[i], DstAddr);
774
775 if (i + 1 < NumElement) {
776 // Calculate next store address
777 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
778 }
779 }
780
781 } else {
782 Inst->print(errs());
783 llvm_unreachable("Handle different size store with scalar "
784 "bitcast on ReplacePointerBitcastPass");
785 }
786 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
787 if (SrcTyBitWidth == DstTyBitWidth) {
788 Value *SrcAddr = Builder.CreateGEP(Src, NewAddrIdx);
789 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
790 LD->replaceAllUsesWith(Builder.CreateBitCast(SrcVal, DstTy));
791 } else if (SrcTyBitWidth < DstTyBitWidth) {
792 Value *SrcAddrIdx = NewAddrIdx;
793
794 // Load value from src.
795 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
796 SmallVector<Value *, 8> LDValues;
797 for (unsigned i = 1; i <= NumIter; i++) {
798 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
799 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
800 LDValues.push_back(SrcVal);
801
802 if (i + 1 <= NumIter) {
803 // Calculate next SrcAddrIdx.
804 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
805 }
806 }
807
808 // Merge Load.
809 Type *TmpSrcTy = Type::getIntNTy(M.getContext(), SrcTyBitWidth);
810 Value *DstVal = Builder.CreateBitCast(LDValues[0], TmpSrcTy);
811 Type *TmpDstTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
812 DstVal = Builder.CreateZExt(DstVal, TmpDstTy);
813 for (unsigned i = 1; i < LDValues.size(); i++) {
814 Value *TmpVal = Builder.CreateBitCast(LDValues[i], TmpSrcTy);
815 TmpVal = Builder.CreateZExt(TmpVal, TmpDstTy);
816 TmpVal = Builder.CreateShl(TmpVal,
817 Builder.getInt32(i * SrcTyBitWidth));
818 DstVal = Builder.CreateOr(DstVal, TmpVal);
819 }
820
821 DstVal = Builder.CreateBitCast(DstVal, DstTy);
822 LD->replaceAllUsesWith(DstVal);
823
824 } else {
825 Inst->print(errs());
826 llvm_unreachable("Handle different size load with scalar "
827 "bitcast on ReplacePointerBitcastPass");
828 }
829 } else {
830
831 errs() << "About to crash at 820" << M << "Aboot\n\n";
832 Inst->print(errs());
833 llvm_unreachable("Handle above user of scalar bitcast with gep on "
834 "ReplacePointerBitcastPass");
835 }
836
837 ToBeDeleted.push_back(cast<Instruction>(U));
838 }
839
840 if (IsGEPUser) {
841 ToBeDeleted.push_back(GEP);
842 }
843 }
844
845 ToBeDeleted.push_back(Inst);
846 }
847
848 for (Instruction *Inst : ToBeDeleted) {
849 Inst->eraseFromParent();
850 }
851
852 return Changed;
853}