blob: 19509dd6237046137e3fe18037c9c1d635d9995d [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Support/raw_ostream.h>

#include <cmath>
#include <utility>
21
22using namespace llvm;
23
24#define DEBUG_TYPE "replacepointerbitcast"
25
26namespace {
27struct ReplacePointerBitcastPass : public ModulePass {
28 static char ID;
29 ReplacePointerBitcastPass() : ModulePass(ID) {}
30
David Neto30ae05e2017-09-06 19:58:36 -040031 // Returns the number of chunks of source data required to exactly
32 // cover the destination data, if the source and destination types are
33 // different sizes. Otherwise returns 0.
David Neto22f144c2017-06-12 14:26:21 -040034 unsigned CalculateNumIter(unsigned SrcTyBitWidth, unsigned DstTyBitWidth);
35 Value *CalculateNewGEPIdx(unsigned SrcTyBitWidth, unsigned DstTyBitWidth,
36 GetElementPtrInst *GEP);
37
38 bool runOnModule(Module &M) override;
39};
40}
41
42char ReplacePointerBitcastPass::ID = 0;
43static RegisterPass<ReplacePointerBitcastPass>
44 X("ReplacePointerBitcast", "Replace Pointer Bitcast Pass");
45
46namespace clspv {
47ModulePass *createReplacePointerBitcastPass() {
48 return new ReplacePointerBitcastPass();
49}
50}
51
52unsigned ReplacePointerBitcastPass::CalculateNumIter(unsigned SrcTyBitWidth,
53 unsigned DstTyBitWidth) {
54 unsigned NumIter = 0;
55 if (SrcTyBitWidth > DstTyBitWidth) {
56 if (SrcTyBitWidth % DstTyBitWidth) {
57 llvm_unreachable(
58 "Src type bitwidth should be multiple of Dest type bitwidth");
59 }
60 NumIter = 1;
61 } else if (SrcTyBitWidth < DstTyBitWidth) {
62 if (DstTyBitWidth % SrcTyBitWidth) {
63 llvm_unreachable(
64 "Dest type bitwidth should be multiple of Src type bitwidth");
65 }
66 NumIter = DstTyBitWidth / SrcTyBitWidth;
67 } else {
68 NumIter = 0;
69 }
70
71 return NumIter;
72}
73
74Value *ReplacePointerBitcastPass::CalculateNewGEPIdx(
75 unsigned SrcTyBitWidth, unsigned DstTyBitWidth, GetElementPtrInst *GEP) {
76 Value *NewGEPIdx = GEP->getOperand(1);
77 IRBuilder<> Builder(GEP);
78
79 if (SrcTyBitWidth > DstTyBitWidth) {
80 if (GEP->getNumOperands() > 2) {
81 GEP->print(errs());
82 llvm_unreachable("Support above GEP on PointerBitcastPass");
83 }
84
85 NewGEPIdx = Builder.CreateLShr(
86 NewGEPIdx,
87 Builder.getInt32(std::log2(SrcTyBitWidth / DstTyBitWidth)));
88 } else if (DstTyBitWidth > SrcTyBitWidth) {
89 if (GEP->getNumOperands() > 2) {
90 GEP->print(errs());
91 llvm_unreachable("Support above GEP on PointerBitcastPass");
92 }
93
94 NewGEPIdx = Builder.CreateShl(
95 NewGEPIdx,
96 Builder.getInt32(std::log2(DstTyBitWidth / SrcTyBitWidth)));
97 }
98
99 return NewGEPIdx;
100}
101
102bool ReplacePointerBitcastPass::runOnModule(Module &M) {
103 bool Changed = false;
104
David Neto8e138142018-05-29 10:19:21 -0400105 const DataLayout& DL = M.getDataLayout();
106
David Neto22f144c2017-06-12 14:26:21 -0400107 SmallVector<Instruction *, 16> VectorWorkList;
108 SmallVector<Instruction *, 16> ScalarWorkList;
109 for (Function &F : M) {
110 for (BasicBlock &BB : F) {
111 for (Instruction &I : BB) {
112 // Find pointer bitcast instruction.
113 if (isa<BitCastInst>(&I) && isa<PointerType>(I.getType())) {
114 Value *Src = I.getOperand(0);
115 if (isa<PointerType>(Src->getType())) {
116 Type *SrcEleTy =
117 I.getOperand(0)->getType()->getPointerElementType();
118 Type *DstEleTy = I.getType()->getPointerElementType();
119 if (SrcEleTy->isVectorTy() || DstEleTy->isVectorTy()) {
120 // Handle case either operand is vector type like char4* -> int4*.
121 VectorWorkList.push_back(&I);
122 } else {
123 // Handle case all operands are scalar type like char* -> int*.
124 ScalarWorkList.push_back(&I);
125 }
126
127 Changed = true;
128 } else {
129 llvm_unreachable("Unsupported bitcast");
130 }
131 }
132 }
133 }
134 }
135
136 SmallVector<Instruction *, 16> ToBeDeleted;
137 for (Instruction *Inst : VectorWorkList) {
138 Value *Src = Inst->getOperand(0);
139 Type *SrcTy = Src->getType()->getPointerElementType();
140 Type *DstTy = Inst->getType()->getPointerElementType();
141 Type *SrcEleTy =
142 SrcTy->isVectorTy() ? SrcTy->getSequentialElementType() : SrcTy;
143 Type *DstEleTy =
144 DstTy->isVectorTy() ? DstTy->getSequentialElementType() : DstTy;
David Neto30ae05e2017-09-06 19:58:36 -0400145 // These are bit widths of the source and destination types, even
146 // if they are vector types. E.g. bit width of float4 is 64.
David Neto8e138142018-05-29 10:19:21 -0400147 unsigned SrcTyBitWidth = DL.getTypeStoreSizeInBits(SrcTy);
148 unsigned DstTyBitWidth = DL.getTypeStoreSizeInBits(DstTy);
149 unsigned SrcEleTyBitWidth = DL.getTypeStoreSizeInBits(SrcEleTy);
150 unsigned DstEleTyBitWidth = DL.getTypeStoreSizeInBits(DstEleTy);
David Neto22f144c2017-06-12 14:26:21 -0400151 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
152
153 // Investigate pointer bitcast's users.
154 for (User *BitCastUser : Inst->users()) {
155 Value *BitCastSrc = Inst->getOperand(0);
156 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
157
158 // It consist of User* and bool whether user is gep or not.
159 SmallVector<std::pair<User *, bool>, 32> Users;
160
161 GetElementPtrInst *GEP = nullptr;
162 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400163 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400164 OrgGEPIdx = GEP->getOperand(1);
165
166 // Build new src/dst address index.
167 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
168
169 // Record gep's users.
170 for (User *GEPUser : GEP->users()) {
171 Users.push_back(std::make_pair(GEPUser, true));
172 }
173 } else {
174 // Record bitcast's users.
175 Users.push_back(std::make_pair(BitCastUser, false));
176 }
177
178 // Handle users.
179 bool IsGEPUser = false;
180 for (auto UserIter : Users) {
181 User *U = UserIter.first;
182 IsGEPUser = UserIter.second;
183
184 IRBuilder<> Builder(cast<Instruction>(U));
185
186 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
187 if (SrcTyBitWidth < DstTyBitWidth) {
188 //
189 // Consider below case.
190 //
191 // Original IR (float2* --> float4*)
192 // 1. val = load (float4*) src_addr
193 // 2. dst_addr = bitcast float2*, float4*
194 // 3. dst_addr = gep (float4*) dst_addr, idx
195 // 4. store (float4*) dst_addr
196 //
197 // Transformed IR
198 // 1. val(float4) = load (float4*) src_addr
199 // 2. val1(float2) = shufflevector (float4)val, (float4)undef,
200 // (float2)<0, 1>
201 // 3. val2(float2) = shufflevector (float4)val, (float4)undef,
202 // (float2)<2, 3>
203 // 4. dst_addr1(float2*) = gep (float2*)dst_addr, idx * 2
204 // 5. dst_addr2(float2*) = gep (float2*)dst_addr, idx * 2 + 1
205 // 6. store (float2)val1, (float2*)dst_addr1
206 // 7. store (float2)val2, (float2*)dst_addr2
207 //
208
209 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
210 unsigned NumVector = 1;
211 // Vulkan SPIR-V does not support over 4 components for
212 // TypeVector.
213 if (NumElement > 4) {
214 NumVector = NumElement >> 2;
215 NumElement = 4;
216 }
217
218 // Create store values.
219 Type *TmpValTy = SrcTy;
220 if (DstTy->isVectorTy()) {
221 if (SrcEleTyBitWidth == DstEleTyBitWidth) {
222 TmpValTy =
223 VectorType::get(SrcEleTy, DstTy->getVectorNumElements());
224 } else {
225 TmpValTy = VectorType::get(SrcEleTy, NumElement);
226 }
227 }
228
229 Value *STVal = ST->getValueOperand();
230 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
231 Value *TmpSTVal = nullptr;
232 if (NumVector == 1) {
233 TmpSTVal = Builder.CreateBitCast(STVal, TmpValTy);
234 } else {
235 unsigned DstVecTyNumElement =
236 DstTy->getVectorNumElements() / NumVector;
237 SmallVector<uint32_t, 4> Idxs;
238 for (unsigned i = 0; i < DstVecTyNumElement; i++) {
239 Idxs.push_back(i + (DstVecTyNumElement * VIdx));
240 }
241 Value *UndefVal = UndefValue::get(DstTy);
242 TmpSTVal = Builder.CreateShuffleVector(STVal, UndefVal, Idxs);
243 TmpSTVal = Builder.CreateBitCast(TmpSTVal, TmpValTy);
244 }
245
246 SmallVector<Value *, 8> STValues;
247 if (!SrcTy->isVectorTy()) {
248 // Handle scalar type.
249 for (unsigned i = 0; i < NumElement; i++) {
250 Value *TmpVal = Builder.CreateExtractElement(
251 TmpSTVal, Builder.getInt32(i));
252 STValues.push_back(TmpVal);
253 }
254 } else {
255 // Handle vector type.
256 unsigned SrcNumElement = SrcTy->getVectorNumElements();
257 unsigned DstNumElement = DstTy->getVectorNumElements();
258 for (unsigned i = 0; i < NumElement; i++) {
259 SmallVector<uint32_t, 4> Idxs;
260 for (unsigned j = 0; j < SrcNumElement; j++) {
261 Idxs.push_back(i * SrcNumElement + j);
262 }
263
264 VectorType *TmpVecTy =
265 VectorType::get(SrcEleTy, DstNumElement);
266 Value *UndefVal = UndefValue::get(TmpVecTy);
267 Value *TmpVal =
268 Builder.CreateShuffleVector(TmpSTVal, UndefVal, Idxs);
269 STValues.push_back(TmpVal);
270 }
271 }
272
273 // Generate stores.
274 Value *SrcAddrIdx = NewAddrIdx;
275 Value *BaseAddr = BitCastSrc;
276 for (unsigned i = 0; i < NumElement; i++) {
277 // Calculate store address.
278 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
279 Builder.CreateStore(STValues[i], DstAddr);
280
281 if (i + 1 < NumElement) {
282 // Calculate next store address
283 SrcAddrIdx =
284 Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
285 }
286 }
287 }
288 } else if (SrcTyBitWidth > DstTyBitWidth) {
289 //
290 // Consider below case.
291 //
292 // Original IR (float4* --> float2*)
293 // 1. val = load (float2*) src_addr
294 // 2. dst_addr = bitcast float4*, float2*
295 // 3. dst_addr = gep (float2*) dst_addr, idx
296 // 4. store (float2) val, (float2*) dst_addr
297 //
David Neto30ae05e2017-09-06 19:58:36 -0400298 // Transformed IR: Decompose the source vector into elements, then write
299 // them one at a time.
David Neto22f144c2017-06-12 14:26:21 -0400300 // 1. val = load (float2*) src_addr
301 // 2. val1 = (float)extract_element val, 0
302 // 3. val2 = (float)extract_element val, 1
David Neto30ae05e2017-09-06 19:58:36 -0400303 // // Source component k maps to destination component k * idxscale
304 // 3a. idxscale = sizeof(float4)/sizeof(float2)
305 // 3b. idxbase = idx / idxscale
306 // 3c. newarrayidx = idxbase * idxscale
307 // 4. dst_addr1 = gep (float4*) dst, newarrayidx
308 // 5. dst_addr2 = gep (float4*) dst, newarrayidx + 1
David Neto22f144c2017-06-12 14:26:21 -0400309 // 6. store (float)val1, (float*) dst_addr1
310 // 7. store (float)val2, (float*) dst_addr2
311 //
312
313 if (SrcTyBitWidth <= DstEleTyBitWidth) {
314 SrcTy->print(errs());
315 DstTy->print(errs());
316 llvm_unreachable("Handle above src/dst type.");
317 }
318
319 // Create store values.
320 Value *STVal = ST->getValueOperand();
321
322 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
323 VectorType *TmpVecTy =
324 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
325 STVal = Builder.CreateBitCast(STVal, TmpVecTy);
326 }
327
328 SmallVector<Value *, 8> STValues;
David Neto30ae05e2017-09-06 19:58:36 -0400329 // How many destination writes are required?
David Neto22f144c2017-06-12 14:26:21 -0400330 unsigned DstNumElement = 1;
331 if (!DstTy->isVectorTy() || SrcEleTyBitWidth == DstTyBitWidth) {
332 // Handle scalar type.
333 STValues.push_back(STVal);
334 } else {
335 // Handle vector type.
336 DstNumElement = DstTy->getVectorNumElements();
337 for (unsigned i = 0; i < DstNumElement; i++) {
338 Value *Idx = Builder.getInt32(i);
339 Value *TmpVal = Builder.CreateExtractElement(STVal, Idx);
340 STValues.push_back(TmpVal);
341 }
342 }
343
344 // Generate stores.
345 Value *BaseAddr = BitCastSrc;
346 Value *SubEleIdx = Builder.getInt32(0);
347 if (IsGEPUser) {
David Neto30ae05e2017-09-06 19:58:36 -0400348 // Compute SubNumElement = idxscale
David Neto22f144c2017-06-12 14:26:21 -0400349 unsigned SubNumElement = SrcTy->getVectorNumElements();
350 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
David Neto30ae05e2017-09-06 19:58:36 -0400351 // Same condition under which DstNumElements > 1
David Neto22f144c2017-06-12 14:26:21 -0400352 SubNumElement = SrcTy->getVectorNumElements() /
353 DstTy->getVectorNumElements();
354 }
355
David Neto30ae05e2017-09-06 19:58:36 -0400356 // Compute SubEleIdx = idxbase * idxscale
David Neto22f144c2017-06-12 14:26:21 -0400357 SubEleIdx = Builder.CreateAnd(
358 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
David Neto30ae05e2017-09-06 19:58:36 -0400359 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
360 SubEleIdx = Builder.CreateShl(
361 SubEleIdx, Builder.getInt32(std::log2(SubNumElement)));
362 }
David Neto22f144c2017-06-12 14:26:21 -0400363 }
364
365 for (unsigned i = 0; i < DstNumElement; i++) {
366 // Calculate address.
367 if (i > 0) {
368 SubEleIdx = Builder.CreateAdd(SubEleIdx, Builder.getInt32(i));
369 }
370
371 Value *Idxs[] = {NewAddrIdx, SubEleIdx};
372 Value *DstAddr = Builder.CreateGEP(BaseAddr, Idxs);
373 Type *TmpSrcTy = SrcEleTy;
374 if (TmpSrcTy->isVectorTy()) {
375 TmpSrcTy = TmpSrcTy->getVectorElementType();
376 }
377 Value *TmpVal = Builder.CreateBitCast(STValues[i], TmpSrcTy);
378
379 Builder.CreateStore(TmpVal, DstAddr);
380 }
381 } else {
382 // if SrcTyBitWidth == DstTyBitWidth
383 Type *TmpSrcTy = SrcTy;
384 Value *DstAddr = Src;
385
386 if (IsGEPUser) {
387 SmallVector<Value *, 4> Idxs;
388 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
389 Idxs.push_back(GEP->getOperand(i));
390 }
391 DstAddr = Builder.CreateGEP(BitCastSrc, Idxs);
392
393 if (GEP->getNumOperands() > 2) {
394 TmpSrcTy = SrcEleTy;
395 }
396 }
397
398 Value *TmpVal =
399 Builder.CreateBitCast(ST->getValueOperand(), TmpSrcTy);
400 Builder.CreateStore(TmpVal, DstAddr);
401 }
402 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
403 Value *SrcAddrIdx = Builder.getInt32(0);
404 if (IsGEPUser) {
405 SrcAddrIdx = NewAddrIdx;
406 }
407
408 // Load value from src.
409 SmallVector<Value *, 8> LDValues;
410
411 for (unsigned i = 1; i <= NumIter; i++) {
412 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
413 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
414 LDValues.push_back(SrcVal);
415
416 if (i + 1 <= NumIter) {
417 // Calculate next SrcAddrIdx.
418 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
419 }
420 }
421
422 Value *DstVal = nullptr;
423 if (SrcTyBitWidth > DstTyBitWidth) {
424 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
425
426 if (SrcEleTyBitWidth == DstTyBitWidth) {
427 //
428 // Consider below case.
429 //
430 // Original IR (int4* --> char4*)
431 // 1. src_addr = bitcast int4*, char4*
432 // 2. element_addr = gep (char4*) src_addr, idx
433 // 3. load (char4*) element_addr
434 //
435 // Transformed IR
436 // 1. src_addr = gep (int4*) src, idx / 4
437 // 2. src_val(int4) = load (int4*) src_addr
438 // 3. tmp_val(int4) = extractelement src_val, idx % 4
439 // 4. dst_val(char4) = bitcast tmp_val, (char4)
440 //
441 Value *EleIdx = Builder.getInt32(0);
442 if (IsGEPUser) {
443 EleIdx = Builder.CreateAnd(OrgGEPIdx,
444 Builder.getInt32(NumElement - 1));
445 }
446 Value *TmpVal =
447 Builder.CreateExtractElement(LDValues[0], EleIdx, "tmp_val");
448 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
449 } else if (SrcEleTyBitWidth < DstTyBitWidth) {
450 if (IsGEPUser) {
451 //
452 // Consider below case.
453 //
454 // Original IR (float4* --> float2*)
455 // 1. src_addr = bitcast float4*, float2*
456 // 2. element_addr = gep (float2*) src_addr, idx
457 // 3. load (float2*) element_addr
458 //
459 // Transformed IR
460 // 1. src_addr = gep (float4*) src, idx / 2
461 // 2. src_val(float4) = load (float4*) src_addr
462 // 3. tmp_val1(float) = extractelement (idx % 2) * 2
463 // 4. tmp_val2(float) = extractelement (idx % 2) * 2 + 1
464 // 5. dst_val(float2) = insertelement undef(float2), tmp_val1, 0
465 // 6. dst_val(float2) = insertelement undef(float2), tmp_val2, 1
466 // 7. dst_val(float2) = bitcast dst_val, (float2)
467 // ==> if types are same between src and dst, it will be
468 // igonored
469 //
470 VectorType *TmpVecTy =
471 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
472 DstVal = UndefValue::get(TmpVecTy);
473 Value *EleIdx = Builder.CreateAnd(
474 OrgGEPIdx, Builder.getInt32(NumElement - 1));
475 EleIdx = Builder.CreateShl(
476 EleIdx, Builder.getInt32(
477 std::log2(DstTyBitWidth / SrcEleTyBitWidth)));
478 Value *TmpOrgGEPIdx = EleIdx;
479 for (unsigned i = 0; i < NumElement; i++) {
480 Value *TmpVal = Builder.CreateExtractElement(
481 LDValues[0], TmpOrgGEPIdx, "tmp_val");
482 DstVal = Builder.CreateInsertElement(DstVal, TmpVal,
483 Builder.getInt32(i));
484
485 if (i + 1 < NumElement) {
486 TmpOrgGEPIdx =
487 Builder.CreateAdd(TmpOrgGEPIdx, Builder.getInt32(1));
488 }
489 }
490 } else {
491 //
492 // Consider below case.
493 //
494 // Original IR (float4* --> int2*)
495 // 1. src_addr = bitcast float4*, int2*
496 // 2. load (int2*) src_addr
497 //
498 // Transformed IR
499 // 1. src_val(float4) = load (float4*) src_addr
500 // 2. tmp_val(float2) = shufflevector (float4)src_val,
501 // (float4)undef,
502 // (float2)<0, 1>
503 // 3. dst_val(int2) = bitcast (float2)tmp_val, (int2)
504 //
505 unsigned NumElement = DstTyBitWidth / SrcEleTyBitWidth;
506 Value *Undef = UndefValue::get(SrcTy);
507
508 SmallVector<uint32_t, 4> Idxs;
509 for (unsigned i = 0; i < NumElement; i++) {
510 Idxs.push_back(i);
511 }
512 DstVal = Builder.CreateShuffleVector(LDValues[0], Undef, Idxs);
513
514 DstVal = Builder.CreateBitCast(DstVal, DstTy);
515 }
516
517 DstVal = Builder.CreateBitCast(DstVal, DstTy);
518 } else {
519 if (IsGEPUser) {
520 //
521 // Consider below case.
522 //
523 // Original IR (int4* --> char2*)
524 // 1. src_addr = bitcast int4*, char2*
525 // 2. element_addr = gep (char2*) src_addr, idx
526 // 3. load (char2*) element_addr
527 //
528 // Transformed IR
529 // 1. src_addr = gep (int4*) src, idx / 8
530 // 2. src_val(int4) = load (int4*) src_addr
531 // 3. tmp_val(int) = extractelement idx / 2
532 // 4. tmp_val(<i16 x 2>) = bitcast tmp_val(int), (<i16 x 2>)
533 // 5. tmp_val(i16) = extractelement idx % 2
534 // 6. dst_val(char2) = bitcast tmp_val, (char2)
535 // ==> if types are same between src and dst, it will be
536 // igonored
537 //
538 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
539 unsigned SubNumElement = SrcEleTyBitWidth / DstTyBitWidth;
540 if (SubNumElement != 2 && SubNumElement != 4) {
541 llvm_unreachable("Unsupported SubNumElement");
542 }
543
544 Value *TmpOrgGEPIdx = Builder.CreateLShr(
545 OrgGEPIdx, Builder.getInt32(std::log2(SubNumElement)));
546 Value *TmpVal = Builder.CreateExtractElement(
547 LDValues[0], TmpOrgGEPIdx, "tmp_val");
548 TmpVal = Builder.CreateBitCast(
549 TmpVal,
550 VectorType::get(
551 IntegerType::get(DstTy->getContext(), DstTyBitWidth),
552 SubNumElement));
553 TmpOrgGEPIdx = Builder.CreateAnd(
554 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
555 TmpVal = Builder.CreateExtractElement(TmpVal, TmpOrgGEPIdx,
556 "tmp_val");
557 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
558 } else {
559 Inst->print(errs());
560 llvm_unreachable("Handle this bitcast");
561 }
562 }
563 } else if (SrcTyBitWidth < DstTyBitWidth) {
564 //
565 // Consider below case.
566 //
567 // Original IR (float2* --> float4*)
568 // 1. src_addr = bitcast float2*, float4*
569 // 2. element_addr = gep (float4*) src_addr, idx
570 // 3. load (float4*) element_addr
571 //
572 // Transformed IR
573 // 1. src_addr = gep (float2*) src, idx * 2
574 // 2. src_val1(float2) = load (float2*) src_addr
575 // 3. src_addr2 = gep (float2*) src_addr, 1
576 // 4. src_val2(float2) = load (float2*) src_addr2
577 // 5. dst_val(float4) = shufflevector src_val1, src_val2, <0, 1>
578 // 6. dst_val(float4) = bitcast dst_val, (float4)
579 // ==> if types are same between src and dst, it will be igonored
580 //
581 unsigned NumElement = 1;
582 if (SrcTy->isVectorTy()) {
583 NumElement = SrcTy->getVectorNumElements() * 2;
584 }
585
586 // Handle scalar type.
587 if (NumElement == 1) {
588 if (SrcTyBitWidth * 4 <= DstTyBitWidth) {
589 unsigned NumVecElement = DstTyBitWidth / SrcTyBitWidth;
590 unsigned NumVector = 1;
591 if (NumVecElement > 4) {
592 NumVector = NumVecElement >> 2;
593 NumVecElement = 4;
594 }
595
596 SmallVector<Value *, 4> Values;
597 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
598 // In this case, generate only insert element. It generates
599 // less instructions than using shuffle vector.
600 VectorType *TmpVecTy = VectorType::get(SrcTy, NumVecElement);
601 Value *TmpVal = UndefValue::get(TmpVecTy);
602 for (unsigned i = 0; i < NumVecElement; i++) {
603 TmpVal = Builder.CreateInsertElement(
604 TmpVal, LDValues[i + (VIdx * 4)], Builder.getInt32(i));
605 }
606 Values.push_back(TmpVal);
607 }
608
609 if (Values.size() > 2) {
610 Inst->print(errs());
611 llvm_unreachable("Support above bitcast");
612 }
613
614 if (Values.size() > 1) {
615 Type *TmpEleTy =
616 Type::getIntNTy(M.getContext(), SrcEleTyBitWidth * 2);
617 VectorType *TmpVecTy = VectorType::get(TmpEleTy, NumVector);
618 for (unsigned i = 0; i < Values.size(); i++) {
619 Values[i] = Builder.CreateBitCast(Values[i], TmpVecTy);
620 }
621 SmallVector<uint32_t, 4> Idxs;
622 for (unsigned i = 0; i < (NumVector * 2); i++) {
623 Idxs.push_back(i);
624 }
625 for (unsigned i = 0; i < Values.size(); i = i + 2) {
626 Values[i] = Builder.CreateShuffleVector(
627 Values[i], Values[i + 1], Idxs);
628 }
629 }
630
631 LDValues.clear();
632 LDValues.push_back(Values[0]);
633 } else {
634 SmallVector<Value *, 4> TmpLDValues;
635 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
636 VectorType *TmpVecTy = VectorType::get(SrcTy, 2);
637 Value *TmpVal = UndefValue::get(TmpVecTy);
638 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i],
639 Builder.getInt32(0));
640 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i + 1],
641 Builder.getInt32(1));
642 TmpLDValues.push_back(TmpVal);
643 }
644 LDValues.clear();
645 LDValues = std::move(TmpLDValues);
646 NumElement = 4;
647 }
648 }
649
650 // Handle vector type.
651 while (LDValues.size() != 1) {
652 SmallVector<Value *, 4> TmpLDValues;
653 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
654 SmallVector<uint32_t, 4> Idxs;
655 for (unsigned j = 0; j < NumElement; j++) {
656 Idxs.push_back(j);
657 }
658 Value *TmpVal = Builder.CreateShuffleVector(
659 LDValues[i], LDValues[i + 1], Idxs);
660 TmpLDValues.push_back(TmpVal);
661 }
662 LDValues.clear();
663 LDValues = std::move(TmpLDValues);
664 NumElement *= 2;
665 }
666
667 DstVal = Builder.CreateBitCast(LDValues[0], DstTy);
668 } else {
669 //
670 // Consider below case.
671 //
672 // Original IR (float4* --> int4*)
673 // 1. src_addr = bitcast float4*, int4*
674 // 2. element_addr = gep (int4*) src_addr, idx, 0
675 // 3. load (int) element_addr
676 //
677 // Transformed IR
678 // 1. element_addr = gep (float4*) src_addr, idx, 0
679 // 2. src_val = load (float*) element_addr
680 // 3. val = bitcast (float) src_val to (int)
681 //
682 Value *SrcAddr = Src;
683 if (IsGEPUser) {
684 SmallVector<Value *, 4> Idxs;
685 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
686 Idxs.push_back(GEP->getOperand(i));
687 }
688 SrcAddr = Builder.CreateGEP(Src, Idxs);
689 }
690 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
691
692 Type *TmpDstTy = DstTy;
693 if (IsGEPUser) {
694 if (GEP->getNumOperands() > 2) {
695 TmpDstTy = DstEleTy;
696 }
697 }
698 DstVal = Builder.CreateBitCast(SrcVal, TmpDstTy);
699 }
700
701 // Update LD's users with DstVal.
702 LD->replaceAllUsesWith(DstVal);
703 } else {
704 U->print(errs());
705 llvm_unreachable(
706 "Handle above user of gep on ReplacePointerBitcastPass");
707 }
708
709 ToBeDeleted.push_back(cast<Instruction>(U));
710 }
711
712 if (IsGEPUser) {
713 ToBeDeleted.push_back(GEP);
714 }
715 }
716
717 ToBeDeleted.push_back(Inst);
718 }
719
720 for (Instruction *Inst : ScalarWorkList) {
David Neto8e138142018-05-29 10:19:21 -0400721 // Some tests have a stray bitcast from pointer-to-array to
722 // pointer to i8*, but the bitcast has no uses. Exit early
723 // but be sure to delete it later.
724 //
725 // Example:
726 // %1 = bitcast [25 x float]* %dst to i8*
727
728 // errs () << " Scalar bitcast is " << *Inst << "\n";
729
730 if (!Inst->hasNUsesOrMore(1)) {
731 ToBeDeleted.push_back(Inst);
732 continue;
733 }
734
David Neto22f144c2017-06-12 14:26:21 -0400735 Value *Src = Inst->getOperand(0);
David Neto8e138142018-05-29 10:19:21 -0400736 Type *SrcTy; // Original type
737 Type *DstTy; // Type that SrcTy is cast to.
738 unsigned SrcTyBitWidth;
739 unsigned DstTyBitWidth;
740
741 SrcTy = Src->getType()->getPointerElementType();
742 DstTy = Inst->getType()->getPointerElementType();
743 int iter_count = 0;
744 while (++iter_count) {
745 SrcTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(SrcTy));
746 DstTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(DstTy));
747#if 0
748 errs() << " Try Src " << *Src << "\n";
749 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
750 << "\n";
751 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
752 << "\n";
753#endif
754
755 // The normal case that we can handle is source type is smaller than
756 // the dest type.
757 if (SrcTyBitWidth <= DstTyBitWidth)
758 break;
759
760 // The Source type is bigger than the destination type.
761 // Walk into the source type to break it down.
762 if (SrcTy->isArrayTy()) {
763 // If it's an array, consider only the first element.
764 Value *Zero = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
765 Instruction* NewSrc = GetElementPtrInst::CreateInBounds(Src, {Zero, Zero});
766 // errs() << "NewSrc is " << *NewSrc << "\n";
767 if (auto *SrcInst = dyn_cast<Instruction>(Src)) {
768 // errs() << " instruction case\n";
769 NewSrc->insertAfter(SrcInst);
770 } else {
771 // Could be a parameter.
772 auto where = Inst->getParent()
773 ->getParent()
774 ->getEntryBlock()
775 .getFirstInsertionPt();
776 Instruction& whereInst = *where;
777 // errs() << "insert " << *NewSrc << " before " << whereInst << "\n";
778 NewSrc->insertBefore(&whereInst);
779 }
780 Src = NewSrc;
781 SrcTy = Src->getType()->getPointerElementType();
782 } else {
783 errs() << "Replace pointer bitcasts: unhandled case: non-array "
784 "non-vector source type "
785 << *SrcTy << " is wider than dest type " << *DstTy << "\n";
786 llvm_unreachable("ReplacePointerBitcastPass: non-array non-vector "
787 "source type is wider than dest type");
788 }
789 if (iter_count > 1000) {
790 llvm_unreachable("ReplacePointerBitcastPass: Too many iterations!");
791 }
792 };
793#if 0
794 errs() << " Src is " << *Src << "\n";
795 errs() << " Dst is " << *Inst << "\n";
796 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
797 << "\n";
798 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
799 << "\n";
800#endif
David Neto22f144c2017-06-12 14:26:21 -0400801
802 for (User *BitCastUser : Inst->users()) {
803 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
804 // It consist of User* and bool whether user is gep or not.
805 SmallVector<std::pair<User *, bool>, 32> Users;
806
807 GetElementPtrInst *GEP = nullptr;
808 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400809 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400810 IRBuilder<> Builder(GEP);
811
812 // Build new src/dst address.
813 OrgGEPIdx = GEP->getOperand(1);
814 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
815
816 // If bitcast's user is gep, investigate gep's users too.
817 for (User *GEPUser : GEP->users()) {
818 Users.push_back(std::make_pair(GEPUser, true));
819 }
820 } else {
821 Users.push_back(std::make_pair(BitCastUser, false));
822 }
823
824 // Handle users.
825 bool IsGEPUser = false;
826 for (auto UserIter : Users) {
827 User *U = UserIter.first;
828 IsGEPUser = UserIter.second;
829
830 IRBuilder<> Builder(cast<Instruction>(U));
831
832 // Handle store instruction with gep.
833 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
David Neto8e138142018-05-29 10:19:21 -0400834 //errs() << " store is " << *ST << "\n";
David Neto22f144c2017-06-12 14:26:21 -0400835 if (SrcTyBitWidth == DstTyBitWidth) {
836 auto STVal = Builder.CreateBitCast(ST->getValueOperand(), SrcTy);
837 Value *DstAddr = Builder.CreateGEP(Src, NewAddrIdx);
838 Builder.CreateStore(STVal, DstAddr);
839 } else if (SrcTyBitWidth < DstTyBitWidth) {
840 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
841
842 // Create Mask.
843 Constant *Mask = nullptr;
844 if (NumElement == 1) {
845 Mask = Builder.getInt32(0xFF);
846 } else if (NumElement == 2) {
847 Mask = Builder.getInt32(0xFFFF);
848 } else {
849 llvm_unreachable("strange type on bitcast");
850 }
851
852 // Create store values.
853 Value *STVal = ST->getValueOperand();
854 SmallVector<Value *, 8> STValues;
855 for (unsigned i = 0; i < NumElement; i++) {
856 Type *TmpTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
857 Value *TmpVal = Builder.CreateBitCast(STVal, TmpTy);
858 TmpVal = Builder.CreateLShr(TmpVal, Builder.getInt32(i));
859 TmpVal = Builder.CreateAnd(TmpVal, Mask);
860 TmpVal = Builder.CreateTrunc(TmpVal, SrcTy);
861 STValues.push_back(TmpVal);
862 }
863
864 // Generate stores.
865 Value *SrcAddrIdx = NewAddrIdx;
866 Value *BaseAddr = Src;
867 for (unsigned i = 0; i < NumElement; i++) {
868 // Calculate store address.
869 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
870 Builder.CreateStore(STValues[i], DstAddr);
871
872 if (i + 1 < NumElement) {
873 // Calculate next store address
874 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
875 }
876 }
877
878 } else {
879 Inst->print(errs());
880 llvm_unreachable("Handle different size store with scalar "
881 "bitcast on ReplacePointerBitcastPass");
882 }
883 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
884 if (SrcTyBitWidth == DstTyBitWidth) {
885 Value *SrcAddr = Builder.CreateGEP(Src, NewAddrIdx);
886 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
887 LD->replaceAllUsesWith(Builder.CreateBitCast(SrcVal, DstTy));
888 } else if (SrcTyBitWidth < DstTyBitWidth) {
889 Value *SrcAddrIdx = NewAddrIdx;
890
891 // Load value from src.
892 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
893 SmallVector<Value *, 8> LDValues;
894 for (unsigned i = 1; i <= NumIter; i++) {
895 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
896 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
897 LDValues.push_back(SrcVal);
898
899 if (i + 1 <= NumIter) {
900 // Calculate next SrcAddrIdx.
901 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
902 }
903 }
904
905 // Merge Load.
906 Type *TmpSrcTy = Type::getIntNTy(M.getContext(), SrcTyBitWidth);
907 Value *DstVal = Builder.CreateBitCast(LDValues[0], TmpSrcTy);
908 Type *TmpDstTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
909 DstVal = Builder.CreateZExt(DstVal, TmpDstTy);
910 for (unsigned i = 1; i < LDValues.size(); i++) {
911 Value *TmpVal = Builder.CreateBitCast(LDValues[i], TmpSrcTy);
912 TmpVal = Builder.CreateZExt(TmpVal, TmpDstTy);
913 TmpVal = Builder.CreateShl(TmpVal,
914 Builder.getInt32(i * SrcTyBitWidth));
915 DstVal = Builder.CreateOr(DstVal, TmpVal);
916 }
917
918 DstVal = Builder.CreateBitCast(DstVal, DstTy);
919 LD->replaceAllUsesWith(DstVal);
920
921 } else {
922 Inst->print(errs());
923 llvm_unreachable("Handle different size load with scalar "
924 "bitcast on ReplacePointerBitcastPass");
925 }
926 } else {
927
928 errs() << "About to crash at 820" << M << "Aboot\n\n";
929 Inst->print(errs());
930 llvm_unreachable("Handle above user of scalar bitcast with gep on "
931 "ReplacePointerBitcastPass");
932 }
933
934 ToBeDeleted.push_back(cast<Instruction>(U));
935 }
936
937 if (IsGEPUser) {
938 ToBeDeleted.push_back(GEP);
939 }
940 }
941
942 ToBeDeleted.push_back(Inst);
943 }
944
945 for (Instruction *Inst : ToBeDeleted) {
946 Inst->eraseFromParent();
947 }
948
949 return Changed;
950}