blob: 20658977f8e234316b6cd31f9f2ffee71d9ec280 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/DataLayout.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/raw_ostream.h"
David Neto22f144c2017-06-12 14:26:21 -040021
Diego Novilloa4c44fa2019-04-11 10:56:15 -040022#include "Passes.h"
23
David Neto22f144c2017-06-12 14:26:21 -040024using namespace llvm;
25
26#define DEBUG_TYPE "replacepointerbitcast"
27
28namespace {
29struct ReplacePointerBitcastPass : public ModulePass {
30 static char ID;
31 ReplacePointerBitcastPass() : ModulePass(ID) {}
32
David Neto30ae05e2017-09-06 19:58:36 -040033 // Returns the number of chunks of source data required to exactly
34 // cover the destination data, if the source and destination types are
35 // different sizes. Otherwise returns 0.
David Neto22f144c2017-06-12 14:26:21 -040036 unsigned CalculateNumIter(unsigned SrcTyBitWidth, unsigned DstTyBitWidth);
37 Value *CalculateNewGEPIdx(unsigned SrcTyBitWidth, unsigned DstTyBitWidth,
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040038 GetElementPtrInst *GEP);
David Neto22f144c2017-06-12 14:26:21 -040039
40 bool runOnModule(Module &M) override;
41};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace
David Neto22f144c2017-06-12 14:26:21 -040043
44char ReplacePointerBitcastPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040045INITIALIZE_PASS(ReplacePointerBitcastPass, "ReplacePointerBitcast",
46 "Replace Pointer Bitcast Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040047
48namespace clspv {
49ModulePass *createReplacePointerBitcastPass() {
50 return new ReplacePointerBitcastPass();
51}
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040052} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040053
54unsigned ReplacePointerBitcastPass::CalculateNumIter(unsigned SrcTyBitWidth,
55 unsigned DstTyBitWidth) {
56 unsigned NumIter = 0;
57 if (SrcTyBitWidth > DstTyBitWidth) {
58 if (SrcTyBitWidth % DstTyBitWidth) {
59 llvm_unreachable(
60 "Src type bitwidth should be multiple of Dest type bitwidth");
61 }
62 NumIter = 1;
63 } else if (SrcTyBitWidth < DstTyBitWidth) {
64 if (DstTyBitWidth % SrcTyBitWidth) {
65 llvm_unreachable(
66 "Dest type bitwidth should be multiple of Src type bitwidth");
67 }
68 NumIter = DstTyBitWidth / SrcTyBitWidth;
69 } else {
70 NumIter = 0;
71 }
72
73 return NumIter;
74}
75
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040076Value *ReplacePointerBitcastPass::CalculateNewGEPIdx(unsigned SrcTyBitWidth,
77 unsigned DstTyBitWidth,
78 GetElementPtrInst *GEP) {
David Neto22f144c2017-06-12 14:26:21 -040079 Value *NewGEPIdx = GEP->getOperand(1);
80 IRBuilder<> Builder(GEP);
81
82 if (SrcTyBitWidth > DstTyBitWidth) {
83 if (GEP->getNumOperands() > 2) {
84 GEP->print(errs());
85 llvm_unreachable("Support above GEP on PointerBitcastPass");
86 }
87
88 NewGEPIdx = Builder.CreateLShr(
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040089 NewGEPIdx, Builder.getInt32(std::log2(SrcTyBitWidth / DstTyBitWidth)));
David Neto22f144c2017-06-12 14:26:21 -040090 } else if (DstTyBitWidth > SrcTyBitWidth) {
91 if (GEP->getNumOperands() > 2) {
92 GEP->print(errs());
93 llvm_unreachable("Support above GEP on PointerBitcastPass");
94 }
95
96 NewGEPIdx = Builder.CreateShl(
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040097 NewGEPIdx, Builder.getInt32(std::log2(DstTyBitWidth / SrcTyBitWidth)));
David Neto22f144c2017-06-12 14:26:21 -040098 }
99
100 return NewGEPIdx;
101}
102
103bool ReplacePointerBitcastPass::runOnModule(Module &M) {
104 bool Changed = false;
105
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400106 const DataLayout &DL = M.getDataLayout();
David Neto8e138142018-05-29 10:19:21 -0400107
David Neto22f144c2017-06-12 14:26:21 -0400108 SmallVector<Instruction *, 16> VectorWorkList;
109 SmallVector<Instruction *, 16> ScalarWorkList;
110 for (Function &F : M) {
111 for (BasicBlock &BB : F) {
112 for (Instruction &I : BB) {
113 // Find pointer bitcast instruction.
114 if (isa<BitCastInst>(&I) && isa<PointerType>(I.getType())) {
115 Value *Src = I.getOperand(0);
116 if (isa<PointerType>(Src->getType())) {
117 Type *SrcEleTy =
118 I.getOperand(0)->getType()->getPointerElementType();
119 Type *DstEleTy = I.getType()->getPointerElementType();
120 if (SrcEleTy->isVectorTy() || DstEleTy->isVectorTy()) {
121 // Handle case either operand is vector type like char4* -> int4*.
122 VectorWorkList.push_back(&I);
123 } else {
124 // Handle case all operands are scalar type like char* -> int*.
125 ScalarWorkList.push_back(&I);
126 }
127
128 Changed = true;
129 } else {
130 llvm_unreachable("Unsupported bitcast");
131 }
132 }
133 }
134 }
135 }
136
137 SmallVector<Instruction *, 16> ToBeDeleted;
138 for (Instruction *Inst : VectorWorkList) {
139 Value *Src = Inst->getOperand(0);
140 Type *SrcTy = Src->getType()->getPointerElementType();
141 Type *DstTy = Inst->getType()->getPointerElementType();
142 Type *SrcEleTy =
143 SrcTy->isVectorTy() ? SrcTy->getSequentialElementType() : SrcTy;
144 Type *DstEleTy =
145 DstTy->isVectorTy() ? DstTy->getSequentialElementType() : DstTy;
David Neto30ae05e2017-09-06 19:58:36 -0400146 // These are bit widths of the source and destination types, even
147 // if they are vector types. E.g. bit width of float4 is 64.
David Neto8e138142018-05-29 10:19:21 -0400148 unsigned SrcTyBitWidth = DL.getTypeStoreSizeInBits(SrcTy);
149 unsigned DstTyBitWidth = DL.getTypeStoreSizeInBits(DstTy);
150 unsigned SrcEleTyBitWidth = DL.getTypeStoreSizeInBits(SrcEleTy);
151 unsigned DstEleTyBitWidth = DL.getTypeStoreSizeInBits(DstEleTy);
David Neto22f144c2017-06-12 14:26:21 -0400152 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
153
154 // Investigate pointer bitcast's users.
155 for (User *BitCastUser : Inst->users()) {
156 Value *BitCastSrc = Inst->getOperand(0);
157 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
158
159 // It consist of User* and bool whether user is gep or not.
160 SmallVector<std::pair<User *, bool>, 32> Users;
161
162 GetElementPtrInst *GEP = nullptr;
163 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400164 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400165 OrgGEPIdx = GEP->getOperand(1);
166
167 // Build new src/dst address index.
168 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
169
170 // Record gep's users.
171 for (User *GEPUser : GEP->users()) {
172 Users.push_back(std::make_pair(GEPUser, true));
173 }
174 } else {
175 // Record bitcast's users.
176 Users.push_back(std::make_pair(BitCastUser, false));
177 }
178
179 // Handle users.
180 bool IsGEPUser = false;
181 for (auto UserIter : Users) {
182 User *U = UserIter.first;
183 IsGEPUser = UserIter.second;
184
185 IRBuilder<> Builder(cast<Instruction>(U));
186
187 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
188 if (SrcTyBitWidth < DstTyBitWidth) {
189 //
190 // Consider below case.
191 //
192 // Original IR (float2* --> float4*)
193 // 1. val = load (float4*) src_addr
194 // 2. dst_addr = bitcast float2*, float4*
195 // 3. dst_addr = gep (float4*) dst_addr, idx
196 // 4. store (float4*) dst_addr
197 //
198 // Transformed IR
199 // 1. val(float4) = load (float4*) src_addr
200 // 2. val1(float2) = shufflevector (float4)val, (float4)undef,
201 // (float2)<0, 1>
202 // 3. val2(float2) = shufflevector (float4)val, (float4)undef,
203 // (float2)<2, 3>
204 // 4. dst_addr1(float2*) = gep (float2*)dst_addr, idx * 2
205 // 5. dst_addr2(float2*) = gep (float2*)dst_addr, idx * 2 + 1
206 // 6. store (float2)val1, (float2*)dst_addr1
207 // 7. store (float2)val2, (float2*)dst_addr2
208 //
209
210 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
211 unsigned NumVector = 1;
212 // Vulkan SPIR-V does not support over 4 components for
213 // TypeVector.
214 if (NumElement > 4) {
215 NumVector = NumElement >> 2;
216 NumElement = 4;
217 }
218
219 // Create store values.
220 Type *TmpValTy = SrcTy;
221 if (DstTy->isVectorTy()) {
222 if (SrcEleTyBitWidth == DstEleTyBitWidth) {
223 TmpValTy =
224 VectorType::get(SrcEleTy, DstTy->getVectorNumElements());
225 } else {
226 TmpValTy = VectorType::get(SrcEleTy, NumElement);
227 }
228 }
229
230 Value *STVal = ST->getValueOperand();
231 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
232 Value *TmpSTVal = nullptr;
233 if (NumVector == 1) {
234 TmpSTVal = Builder.CreateBitCast(STVal, TmpValTy);
235 } else {
236 unsigned DstVecTyNumElement =
237 DstTy->getVectorNumElements() / NumVector;
238 SmallVector<uint32_t, 4> Idxs;
239 for (unsigned i = 0; i < DstVecTyNumElement; i++) {
240 Idxs.push_back(i + (DstVecTyNumElement * VIdx));
241 }
242 Value *UndefVal = UndefValue::get(DstTy);
243 TmpSTVal = Builder.CreateShuffleVector(STVal, UndefVal, Idxs);
244 TmpSTVal = Builder.CreateBitCast(TmpSTVal, TmpValTy);
245 }
246
247 SmallVector<Value *, 8> STValues;
248 if (!SrcTy->isVectorTy()) {
249 // Handle scalar type.
250 for (unsigned i = 0; i < NumElement; i++) {
251 Value *TmpVal = Builder.CreateExtractElement(
252 TmpSTVal, Builder.getInt32(i));
253 STValues.push_back(TmpVal);
254 }
255 } else {
256 // Handle vector type.
257 unsigned SrcNumElement = SrcTy->getVectorNumElements();
258 unsigned DstNumElement = DstTy->getVectorNumElements();
259 for (unsigned i = 0; i < NumElement; i++) {
260 SmallVector<uint32_t, 4> Idxs;
261 for (unsigned j = 0; j < SrcNumElement; j++) {
262 Idxs.push_back(i * SrcNumElement + j);
263 }
264
265 VectorType *TmpVecTy =
266 VectorType::get(SrcEleTy, DstNumElement);
267 Value *UndefVal = UndefValue::get(TmpVecTy);
268 Value *TmpVal =
269 Builder.CreateShuffleVector(TmpSTVal, UndefVal, Idxs);
270 STValues.push_back(TmpVal);
271 }
272 }
273
274 // Generate stores.
275 Value *SrcAddrIdx = NewAddrIdx;
276 Value *BaseAddr = BitCastSrc;
277 for (unsigned i = 0; i < NumElement; i++) {
278 // Calculate store address.
279 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
280 Builder.CreateStore(STValues[i], DstAddr);
281
282 if (i + 1 < NumElement) {
283 // Calculate next store address
284 SrcAddrIdx =
285 Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
286 }
287 }
288 }
289 } else if (SrcTyBitWidth > DstTyBitWidth) {
290 //
291 // Consider below case.
292 //
293 // Original IR (float4* --> float2*)
294 // 1. val = load (float2*) src_addr
295 // 2. dst_addr = bitcast float4*, float2*
296 // 3. dst_addr = gep (float2*) dst_addr, idx
297 // 4. store (float2) val, (float2*) dst_addr
298 //
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400299 // Transformed IR: Decompose the source vector into elements, then
300 // write them one at a time.
David Neto22f144c2017-06-12 14:26:21 -0400301 // 1. val = load (float2*) src_addr
302 // 2. val1 = (float)extract_element val, 0
303 // 3. val2 = (float)extract_element val, 1
David Neto30ae05e2017-09-06 19:58:36 -0400304 // // Source component k maps to destination component k * idxscale
305 // 3a. idxscale = sizeof(float4)/sizeof(float2)
306 // 3b. idxbase = idx / idxscale
307 // 3c. newarrayidx = idxbase * idxscale
308 // 4. dst_addr1 = gep (float4*) dst, newarrayidx
309 // 5. dst_addr2 = gep (float4*) dst, newarrayidx + 1
David Neto22f144c2017-06-12 14:26:21 -0400310 // 6. store (float)val1, (float*) dst_addr1
311 // 7. store (float)val2, (float*) dst_addr2
312 //
313
314 if (SrcTyBitWidth <= DstEleTyBitWidth) {
315 SrcTy->print(errs());
316 DstTy->print(errs());
317 llvm_unreachable("Handle above src/dst type.");
318 }
319
320 // Create store values.
321 Value *STVal = ST->getValueOperand();
322
323 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
324 VectorType *TmpVecTy =
325 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
326 STVal = Builder.CreateBitCast(STVal, TmpVecTy);
327 }
328
329 SmallVector<Value *, 8> STValues;
David Neto30ae05e2017-09-06 19:58:36 -0400330 // How many destination writes are required?
David Neto22f144c2017-06-12 14:26:21 -0400331 unsigned DstNumElement = 1;
332 if (!DstTy->isVectorTy() || SrcEleTyBitWidth == DstTyBitWidth) {
333 // Handle scalar type.
334 STValues.push_back(STVal);
335 } else {
336 // Handle vector type.
337 DstNumElement = DstTy->getVectorNumElements();
338 for (unsigned i = 0; i < DstNumElement; i++) {
339 Value *Idx = Builder.getInt32(i);
340 Value *TmpVal = Builder.CreateExtractElement(STVal, Idx);
341 STValues.push_back(TmpVal);
342 }
343 }
344
345 // Generate stores.
346 Value *BaseAddr = BitCastSrc;
347 Value *SubEleIdx = Builder.getInt32(0);
348 if (IsGEPUser) {
David Neto30ae05e2017-09-06 19:58:36 -0400349 // Compute SubNumElement = idxscale
David Neto22f144c2017-06-12 14:26:21 -0400350 unsigned SubNumElement = SrcTy->getVectorNumElements();
351 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
David Neto30ae05e2017-09-06 19:58:36 -0400352 // Same condition under which DstNumElements > 1
David Neto22f144c2017-06-12 14:26:21 -0400353 SubNumElement = SrcTy->getVectorNumElements() /
354 DstTy->getVectorNumElements();
355 }
356
David Neto30ae05e2017-09-06 19:58:36 -0400357 // Compute SubEleIdx = idxbase * idxscale
David Neto22f144c2017-06-12 14:26:21 -0400358 SubEleIdx = Builder.CreateAnd(
359 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
David Neto30ae05e2017-09-06 19:58:36 -0400360 if (DstTy->isVectorTy() && (SrcEleTyBitWidth != DstTyBitWidth)) {
361 SubEleIdx = Builder.CreateShl(
362 SubEleIdx, Builder.getInt32(std::log2(SubNumElement)));
363 }
David Neto22f144c2017-06-12 14:26:21 -0400364 }
365
366 for (unsigned i = 0; i < DstNumElement; i++) {
367 // Calculate address.
368 if (i > 0) {
369 SubEleIdx = Builder.CreateAdd(SubEleIdx, Builder.getInt32(i));
370 }
371
372 Value *Idxs[] = {NewAddrIdx, SubEleIdx};
373 Value *DstAddr = Builder.CreateGEP(BaseAddr, Idxs);
374 Type *TmpSrcTy = SrcEleTy;
375 if (TmpSrcTy->isVectorTy()) {
376 TmpSrcTy = TmpSrcTy->getVectorElementType();
377 }
378 Value *TmpVal = Builder.CreateBitCast(STValues[i], TmpSrcTy);
379
380 Builder.CreateStore(TmpVal, DstAddr);
381 }
382 } else {
383 // if SrcTyBitWidth == DstTyBitWidth
384 Type *TmpSrcTy = SrcTy;
385 Value *DstAddr = Src;
386
387 if (IsGEPUser) {
388 SmallVector<Value *, 4> Idxs;
389 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
390 Idxs.push_back(GEP->getOperand(i));
391 }
392 DstAddr = Builder.CreateGEP(BitCastSrc, Idxs);
393
394 if (GEP->getNumOperands() > 2) {
395 TmpSrcTy = SrcEleTy;
396 }
397 }
398
399 Value *TmpVal =
400 Builder.CreateBitCast(ST->getValueOperand(), TmpSrcTy);
401 Builder.CreateStore(TmpVal, DstAddr);
402 }
403 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
404 Value *SrcAddrIdx = Builder.getInt32(0);
405 if (IsGEPUser) {
406 SrcAddrIdx = NewAddrIdx;
407 }
408
409 // Load value from src.
410 SmallVector<Value *, 8> LDValues;
411
412 for (unsigned i = 1; i <= NumIter; i++) {
413 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
414 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
415 LDValues.push_back(SrcVal);
416
417 if (i + 1 <= NumIter) {
418 // Calculate next SrcAddrIdx.
419 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
420 }
421 }
422
423 Value *DstVal = nullptr;
424 if (SrcTyBitWidth > DstTyBitWidth) {
425 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
426
427 if (SrcEleTyBitWidth == DstTyBitWidth) {
428 //
429 // Consider below case.
430 //
431 // Original IR (int4* --> char4*)
432 // 1. src_addr = bitcast int4*, char4*
433 // 2. element_addr = gep (char4*) src_addr, idx
434 // 3. load (char4*) element_addr
435 //
436 // Transformed IR
437 // 1. src_addr = gep (int4*) src, idx / 4
438 // 2. src_val(int4) = load (int4*) src_addr
439 // 3. tmp_val(int4) = extractelement src_val, idx % 4
440 // 4. dst_val(char4) = bitcast tmp_val, (char4)
441 //
442 Value *EleIdx = Builder.getInt32(0);
443 if (IsGEPUser) {
444 EleIdx = Builder.CreateAnd(OrgGEPIdx,
445 Builder.getInt32(NumElement - 1));
446 }
447 Value *TmpVal =
448 Builder.CreateExtractElement(LDValues[0], EleIdx, "tmp_val");
449 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
450 } else if (SrcEleTyBitWidth < DstTyBitWidth) {
451 if (IsGEPUser) {
452 //
453 // Consider below case.
454 //
455 // Original IR (float4* --> float2*)
456 // 1. src_addr = bitcast float4*, float2*
457 // 2. element_addr = gep (float2*) src_addr, idx
458 // 3. load (float2*) element_addr
459 //
460 // Transformed IR
461 // 1. src_addr = gep (float4*) src, idx / 2
462 // 2. src_val(float4) = load (float4*) src_addr
463 // 3. tmp_val1(float) = extractelement (idx % 2) * 2
464 // 4. tmp_val2(float) = extractelement (idx % 2) * 2 + 1
465 // 5. dst_val(float2) = insertelement undef(float2), tmp_val1, 0
466 // 6. dst_val(float2) = insertelement undef(float2), tmp_val2, 1
467 // 7. dst_val(float2) = bitcast dst_val, (float2)
468 // ==> if types are same between src and dst, it will be
469 // igonored
470 //
471 VectorType *TmpVecTy =
472 VectorType::get(SrcEleTy, DstTyBitWidth / SrcEleTyBitWidth);
473 DstVal = UndefValue::get(TmpVecTy);
474 Value *EleIdx = Builder.CreateAnd(
475 OrgGEPIdx, Builder.getInt32(NumElement - 1));
476 EleIdx = Builder.CreateShl(
477 EleIdx, Builder.getInt32(
478 std::log2(DstTyBitWidth / SrcEleTyBitWidth)));
479 Value *TmpOrgGEPIdx = EleIdx;
480 for (unsigned i = 0; i < NumElement; i++) {
481 Value *TmpVal = Builder.CreateExtractElement(
482 LDValues[0], TmpOrgGEPIdx, "tmp_val");
483 DstVal = Builder.CreateInsertElement(DstVal, TmpVal,
484 Builder.getInt32(i));
485
486 if (i + 1 < NumElement) {
487 TmpOrgGEPIdx =
488 Builder.CreateAdd(TmpOrgGEPIdx, Builder.getInt32(1));
489 }
490 }
491 } else {
492 //
493 // Consider below case.
494 //
495 // Original IR (float4* --> int2*)
496 // 1. src_addr = bitcast float4*, int2*
497 // 2. load (int2*) src_addr
498 //
499 // Transformed IR
500 // 1. src_val(float4) = load (float4*) src_addr
501 // 2. tmp_val(float2) = shufflevector (float4)src_val,
502 // (float4)undef,
503 // (float2)<0, 1>
504 // 3. dst_val(int2) = bitcast (float2)tmp_val, (int2)
505 //
506 unsigned NumElement = DstTyBitWidth / SrcEleTyBitWidth;
507 Value *Undef = UndefValue::get(SrcTy);
508
509 SmallVector<uint32_t, 4> Idxs;
510 for (unsigned i = 0; i < NumElement; i++) {
511 Idxs.push_back(i);
512 }
513 DstVal = Builder.CreateShuffleVector(LDValues[0], Undef, Idxs);
514
515 DstVal = Builder.CreateBitCast(DstVal, DstTy);
516 }
517
518 DstVal = Builder.CreateBitCast(DstVal, DstTy);
519 } else {
520 if (IsGEPUser) {
521 //
522 // Consider below case.
523 //
524 // Original IR (int4* --> char2*)
525 // 1. src_addr = bitcast int4*, char2*
526 // 2. element_addr = gep (char2*) src_addr, idx
527 // 3. load (char2*) element_addr
528 //
529 // Transformed IR
530 // 1. src_addr = gep (int4*) src, idx / 8
531 // 2. src_val(int4) = load (int4*) src_addr
532 // 3. tmp_val(int) = extractelement idx / 2
533 // 4. tmp_val(<i16 x 2>) = bitcast tmp_val(int), (<i16 x 2>)
534 // 5. tmp_val(i16) = extractelement idx % 2
535 // 6. dst_val(char2) = bitcast tmp_val, (char2)
536 // ==> if types are same between src and dst, it will be
537 // igonored
538 //
539 unsigned NumElement = SrcTyBitWidth / DstTyBitWidth;
540 unsigned SubNumElement = SrcEleTyBitWidth / DstTyBitWidth;
541 if (SubNumElement != 2 && SubNumElement != 4) {
542 llvm_unreachable("Unsupported SubNumElement");
543 }
544
545 Value *TmpOrgGEPIdx = Builder.CreateLShr(
546 OrgGEPIdx, Builder.getInt32(std::log2(SubNumElement)));
547 Value *TmpVal = Builder.CreateExtractElement(
548 LDValues[0], TmpOrgGEPIdx, "tmp_val");
549 TmpVal = Builder.CreateBitCast(
550 TmpVal,
551 VectorType::get(
552 IntegerType::get(DstTy->getContext(), DstTyBitWidth),
553 SubNumElement));
554 TmpOrgGEPIdx = Builder.CreateAnd(
555 OrgGEPIdx, Builder.getInt32(SubNumElement - 1));
556 TmpVal = Builder.CreateExtractElement(TmpVal, TmpOrgGEPIdx,
557 "tmp_val");
558 DstVal = Builder.CreateBitCast(TmpVal, DstTy);
559 } else {
560 Inst->print(errs());
561 llvm_unreachable("Handle this bitcast");
562 }
563 }
564 } else if (SrcTyBitWidth < DstTyBitWidth) {
565 //
566 // Consider below case.
567 //
568 // Original IR (float2* --> float4*)
569 // 1. src_addr = bitcast float2*, float4*
570 // 2. element_addr = gep (float4*) src_addr, idx
571 // 3. load (float4*) element_addr
572 //
573 // Transformed IR
574 // 1. src_addr = gep (float2*) src, idx * 2
575 // 2. src_val1(float2) = load (float2*) src_addr
576 // 3. src_addr2 = gep (float2*) src_addr, 1
577 // 4. src_val2(float2) = load (float2*) src_addr2
578 // 5. dst_val(float4) = shufflevector src_val1, src_val2, <0, 1>
579 // 6. dst_val(float4) = bitcast dst_val, (float4)
580 // ==> if types are same between src and dst, it will be igonored
581 //
582 unsigned NumElement = 1;
583 if (SrcTy->isVectorTy()) {
584 NumElement = SrcTy->getVectorNumElements() * 2;
585 }
586
587 // Handle scalar type.
588 if (NumElement == 1) {
589 if (SrcTyBitWidth * 4 <= DstTyBitWidth) {
590 unsigned NumVecElement = DstTyBitWidth / SrcTyBitWidth;
591 unsigned NumVector = 1;
592 if (NumVecElement > 4) {
593 NumVector = NumVecElement >> 2;
594 NumVecElement = 4;
595 }
596
597 SmallVector<Value *, 4> Values;
598 for (unsigned VIdx = 0; VIdx < NumVector; VIdx++) {
599 // In this case, generate only insert element. It generates
600 // less instructions than using shuffle vector.
601 VectorType *TmpVecTy = VectorType::get(SrcTy, NumVecElement);
602 Value *TmpVal = UndefValue::get(TmpVecTy);
603 for (unsigned i = 0; i < NumVecElement; i++) {
604 TmpVal = Builder.CreateInsertElement(
605 TmpVal, LDValues[i + (VIdx * 4)], Builder.getInt32(i));
606 }
607 Values.push_back(TmpVal);
608 }
609
610 if (Values.size() > 2) {
611 Inst->print(errs());
612 llvm_unreachable("Support above bitcast");
613 }
614
615 if (Values.size() > 1) {
616 Type *TmpEleTy =
617 Type::getIntNTy(M.getContext(), SrcEleTyBitWidth * 2);
618 VectorType *TmpVecTy = VectorType::get(TmpEleTy, NumVector);
619 for (unsigned i = 0; i < Values.size(); i++) {
620 Values[i] = Builder.CreateBitCast(Values[i], TmpVecTy);
621 }
622 SmallVector<uint32_t, 4> Idxs;
623 for (unsigned i = 0; i < (NumVector * 2); i++) {
624 Idxs.push_back(i);
625 }
626 for (unsigned i = 0; i < Values.size(); i = i + 2) {
627 Values[i] = Builder.CreateShuffleVector(
628 Values[i], Values[i + 1], Idxs);
629 }
630 }
631
632 LDValues.clear();
633 LDValues.push_back(Values[0]);
634 } else {
635 SmallVector<Value *, 4> TmpLDValues;
636 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
637 VectorType *TmpVecTy = VectorType::get(SrcTy, 2);
638 Value *TmpVal = UndefValue::get(TmpVecTy);
639 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i],
640 Builder.getInt32(0));
641 TmpVal = Builder.CreateInsertElement(TmpVal, LDValues[i + 1],
642 Builder.getInt32(1));
643 TmpLDValues.push_back(TmpVal);
644 }
645 LDValues.clear();
646 LDValues = std::move(TmpLDValues);
647 NumElement = 4;
648 }
649 }
650
651 // Handle vector type.
652 while (LDValues.size() != 1) {
653 SmallVector<Value *, 4> TmpLDValues;
654 for (unsigned i = 0; i < LDValues.size(); i = i + 2) {
655 SmallVector<uint32_t, 4> Idxs;
656 for (unsigned j = 0; j < NumElement; j++) {
657 Idxs.push_back(j);
658 }
659 Value *TmpVal = Builder.CreateShuffleVector(
660 LDValues[i], LDValues[i + 1], Idxs);
661 TmpLDValues.push_back(TmpVal);
662 }
663 LDValues.clear();
664 LDValues = std::move(TmpLDValues);
665 NumElement *= 2;
666 }
667
668 DstVal = Builder.CreateBitCast(LDValues[0], DstTy);
669 } else {
670 //
671 // Consider below case.
672 //
673 // Original IR (float4* --> int4*)
674 // 1. src_addr = bitcast float4*, int4*
675 // 2. element_addr = gep (int4*) src_addr, idx, 0
676 // 3. load (int) element_addr
677 //
678 // Transformed IR
679 // 1. element_addr = gep (float4*) src_addr, idx, 0
680 // 2. src_val = load (float*) element_addr
681 // 3. val = bitcast (float) src_val to (int)
682 //
683 Value *SrcAddr = Src;
684 if (IsGEPUser) {
685 SmallVector<Value *, 4> Idxs;
686 for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
687 Idxs.push_back(GEP->getOperand(i));
688 }
689 SrcAddr = Builder.CreateGEP(Src, Idxs);
690 }
691 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
692
693 Type *TmpDstTy = DstTy;
694 if (IsGEPUser) {
695 if (GEP->getNumOperands() > 2) {
696 TmpDstTy = DstEleTy;
697 }
698 }
699 DstVal = Builder.CreateBitCast(SrcVal, TmpDstTy);
700 }
701
702 // Update LD's users with DstVal.
703 LD->replaceAllUsesWith(DstVal);
704 } else {
705 U->print(errs());
706 llvm_unreachable(
707 "Handle above user of gep on ReplacePointerBitcastPass");
708 }
709
710 ToBeDeleted.push_back(cast<Instruction>(U));
711 }
712
713 if (IsGEPUser) {
714 ToBeDeleted.push_back(GEP);
715 }
716 }
717
718 ToBeDeleted.push_back(Inst);
719 }
720
721 for (Instruction *Inst : ScalarWorkList) {
David Neto8e138142018-05-29 10:19:21 -0400722 // Some tests have a stray bitcast from pointer-to-array to
723 // pointer to i8*, but the bitcast has no uses. Exit early
724 // but be sure to delete it later.
725 //
726 // Example:
727 // %1 = bitcast [25 x float]* %dst to i8*
728
729 // errs () << " Scalar bitcast is " << *Inst << "\n";
730
731 if (!Inst->hasNUsesOrMore(1)) {
732 ToBeDeleted.push_back(Inst);
733 continue;
734 }
735
David Neto22f144c2017-06-12 14:26:21 -0400736 Value *Src = Inst->getOperand(0);
David Neto8e138142018-05-29 10:19:21 -0400737 Type *SrcTy; // Original type
738 Type *DstTy; // Type that SrcTy is cast to.
739 unsigned SrcTyBitWidth;
740 unsigned DstTyBitWidth;
741
742 SrcTy = Src->getType()->getPointerElementType();
743 DstTy = Inst->getType()->getPointerElementType();
744 int iter_count = 0;
745 while (++iter_count) {
746 SrcTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(SrcTy));
747 DstTyBitWidth = unsigned(DL.getTypeStoreSizeInBits(DstTy));
748#if 0
749 errs() << " Try Src " << *Src << "\n";
750 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
751 << "\n";
752 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
753 << "\n";
754#endif
755
756 // The normal case that we can handle is source type is smaller than
757 // the dest type.
758 if (SrcTyBitWidth <= DstTyBitWidth)
759 break;
760
761 // The Source type is bigger than the destination type.
762 // Walk into the source type to break it down.
763 if (SrcTy->isArrayTy()) {
764 // If it's an array, consider only the first element.
765 Value *Zero = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400766 Instruction *NewSrc =
767 GetElementPtrInst::CreateInBounds(Src, {Zero, Zero});
David Neto8e138142018-05-29 10:19:21 -0400768 // errs() << "NewSrc is " << *NewSrc << "\n";
769 if (auto *SrcInst = dyn_cast<Instruction>(Src)) {
770 // errs() << " instruction case\n";
771 NewSrc->insertAfter(SrcInst);
772 } else {
773 // Could be a parameter.
774 auto where = Inst->getParent()
775 ->getParent()
776 ->getEntryBlock()
777 .getFirstInsertionPt();
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400778 Instruction &whereInst = *where;
David Neto8e138142018-05-29 10:19:21 -0400779 // errs() << "insert " << *NewSrc << " before " << whereInst << "\n";
780 NewSrc->insertBefore(&whereInst);
781 }
782 Src = NewSrc;
783 SrcTy = Src->getType()->getPointerElementType();
784 } else {
785 errs() << "Replace pointer bitcasts: unhandled case: non-array "
786 "non-vector source type "
787 << *SrcTy << " is wider than dest type " << *DstTy << "\n";
788 llvm_unreachable("ReplacePointerBitcastPass: non-array non-vector "
789 "source type is wider than dest type");
790 }
791 if (iter_count > 1000) {
792 llvm_unreachable("ReplacePointerBitcastPass: Too many iterations!");
793 }
794 };
795#if 0
796 errs() << " Src is " << *Src << "\n";
797 errs() << " Dst is " << *Inst << "\n";
798 errs() << " SrcTy elem " << *SrcTy << " bit width " << SrcTyBitWidth
799 << "\n";
800 errs() << " DstTy elem " << *DstTy << " bit width " << DstTyBitWidth
801 << "\n";
802#endif
David Neto22f144c2017-06-12 14:26:21 -0400803
804 for (User *BitCastUser : Inst->users()) {
805 Value *NewAddrIdx = ConstantInt::get(Type::getInt32Ty(M.getContext()), 0);
806 // It consist of User* and bool whether user is gep or not.
807 SmallVector<std::pair<User *, bool>, 32> Users;
808
809 GetElementPtrInst *GEP = nullptr;
810 Value *OrgGEPIdx = nullptr;
Jason Gavrise44af072018-08-14 20:44:50 -0400811 if ((GEP = dyn_cast<GetElementPtrInst>(BitCastUser))) {
David Neto22f144c2017-06-12 14:26:21 -0400812 IRBuilder<> Builder(GEP);
813
814 // Build new src/dst address.
815 OrgGEPIdx = GEP->getOperand(1);
816 NewAddrIdx = CalculateNewGEPIdx(SrcTyBitWidth, DstTyBitWidth, GEP);
817
818 // If bitcast's user is gep, investigate gep's users too.
819 for (User *GEPUser : GEP->users()) {
820 Users.push_back(std::make_pair(GEPUser, true));
821 }
822 } else {
823 Users.push_back(std::make_pair(BitCastUser, false));
824 }
825
826 // Handle users.
827 bool IsGEPUser = false;
828 for (auto UserIter : Users) {
829 User *U = UserIter.first;
830 IsGEPUser = UserIter.second;
831
832 IRBuilder<> Builder(cast<Instruction>(U));
833
834 // Handle store instruction with gep.
835 if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400836 // errs() << " store is " << *ST << "\n";
David Neto22f144c2017-06-12 14:26:21 -0400837 if (SrcTyBitWidth == DstTyBitWidth) {
838 auto STVal = Builder.CreateBitCast(ST->getValueOperand(), SrcTy);
839 Value *DstAddr = Builder.CreateGEP(Src, NewAddrIdx);
840 Builder.CreateStore(STVal, DstAddr);
841 } else if (SrcTyBitWidth < DstTyBitWidth) {
842 unsigned NumElement = DstTyBitWidth / SrcTyBitWidth;
843
844 // Create Mask.
845 Constant *Mask = nullptr;
846 if (NumElement == 1) {
847 Mask = Builder.getInt32(0xFF);
848 } else if (NumElement == 2) {
849 Mask = Builder.getInt32(0xFFFF);
alan-baker4130faa2019-04-10 17:26:17 -0400850 } else if (NumElement == 4) {
851 Mask = Builder.getInt32(0xFFFFFFFF);
David Neto22f144c2017-06-12 14:26:21 -0400852 } else {
853 llvm_unreachable("strange type on bitcast");
854 }
855
856 // Create store values.
857 Value *STVal = ST->getValueOperand();
858 SmallVector<Value *, 8> STValues;
859 for (unsigned i = 0; i < NumElement; i++) {
860 Type *TmpTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
861 Value *TmpVal = Builder.CreateBitCast(STVal, TmpTy);
alan-baker4130faa2019-04-10 17:26:17 -0400862 TmpVal = Builder.CreateLShr(TmpVal,
863 Builder.getInt32(i * SrcTyBitWidth));
David Neto22f144c2017-06-12 14:26:21 -0400864 TmpVal = Builder.CreateAnd(TmpVal, Mask);
865 TmpVal = Builder.CreateTrunc(TmpVal, SrcTy);
866 STValues.push_back(TmpVal);
867 }
868
869 // Generate stores.
870 Value *SrcAddrIdx = NewAddrIdx;
871 Value *BaseAddr = Src;
872 for (unsigned i = 0; i < NumElement; i++) {
873 // Calculate store address.
874 Value *DstAddr = Builder.CreateGEP(BaseAddr, SrcAddrIdx);
875 Builder.CreateStore(STValues[i], DstAddr);
876
877 if (i + 1 < NumElement) {
878 // Calculate next store address
879 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
880 }
881 }
882
883 } else {
884 Inst->print(errs());
885 llvm_unreachable("Handle different size store with scalar "
886 "bitcast on ReplacePointerBitcastPass");
887 }
888 } else if (LoadInst *LD = dyn_cast<LoadInst>(U)) {
889 if (SrcTyBitWidth == DstTyBitWidth) {
890 Value *SrcAddr = Builder.CreateGEP(Src, NewAddrIdx);
891 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
892 LD->replaceAllUsesWith(Builder.CreateBitCast(SrcVal, DstTy));
893 } else if (SrcTyBitWidth < DstTyBitWidth) {
894 Value *SrcAddrIdx = NewAddrIdx;
895
896 // Load value from src.
897 unsigned NumIter = CalculateNumIter(SrcTyBitWidth, DstTyBitWidth);
898 SmallVector<Value *, 8> LDValues;
899 for (unsigned i = 1; i <= NumIter; i++) {
900 Value *SrcAddr = Builder.CreateGEP(Src, SrcAddrIdx);
901 LoadInst *SrcVal = Builder.CreateLoad(SrcAddr, "src_val");
902 LDValues.push_back(SrcVal);
903
904 if (i + 1 <= NumIter) {
905 // Calculate next SrcAddrIdx.
906 SrcAddrIdx = Builder.CreateAdd(SrcAddrIdx, Builder.getInt32(1));
907 }
908 }
909
910 // Merge Load.
911 Type *TmpSrcTy = Type::getIntNTy(M.getContext(), SrcTyBitWidth);
912 Value *DstVal = Builder.CreateBitCast(LDValues[0], TmpSrcTy);
913 Type *TmpDstTy = Type::getIntNTy(M.getContext(), DstTyBitWidth);
914 DstVal = Builder.CreateZExt(DstVal, TmpDstTy);
915 for (unsigned i = 1; i < LDValues.size(); i++) {
916 Value *TmpVal = Builder.CreateBitCast(LDValues[i], TmpSrcTy);
917 TmpVal = Builder.CreateZExt(TmpVal, TmpDstTy);
918 TmpVal = Builder.CreateShl(TmpVal,
919 Builder.getInt32(i * SrcTyBitWidth));
920 DstVal = Builder.CreateOr(DstVal, TmpVal);
921 }
922
923 DstVal = Builder.CreateBitCast(DstVal, DstTy);
924 LD->replaceAllUsesWith(DstVal);
925
926 } else {
927 Inst->print(errs());
928 llvm_unreachable("Handle different size load with scalar "
929 "bitcast on ReplacePointerBitcastPass");
930 }
931 } else {
David Neto22f144c2017-06-12 14:26:21 -0400932 Inst->print(errs());
933 llvm_unreachable("Handle above user of scalar bitcast with gep on "
934 "ReplacePointerBitcastPass");
935 }
936
937 ToBeDeleted.push_back(cast<Instruction>(U));
938 }
939
940 if (IsGEPUser) {
941 ToBeDeleted.push_back(GEP);
942 }
943 }
944
945 ToBeDeleted.push_back(Inst);
946 }
947
948 for (Instruction *Inst : ToBeDeleted) {
949 Inst->eraseFromParent();
950 }
951
952 return Changed;
953}