blob: bbda9e9ab5b1c565aeadef8d877a0100cc9c09fc [file] [log] [blame]
David Netodbb61f32017-09-29 15:50:36 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040016#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040017#include "llvm/IR/Module.h"
18#include "llvm/Pass.h"
19#include "llvm/Support/raw_ostream.h"
David Netodbb61f32017-09-29 15:50:36 -040020
21using namespace llvm;
22
23#define DEBUG_TYPE "splatselectcond"
24
25namespace {
26struct SplatSelectConditionPass : public ModulePass {
27 static char ID;
28 SplatSelectConditionPass() : ModulePass(ID) {}
29
30 bool runOnModule(Module &M) override;
31};
32} // namespace
33
34char SplatSelectConditionPass::ID = 0;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035static RegisterPass<SplatSelectConditionPass> X("SplatSelectCond",
36 "Splat Select Condition Pass");
David Netodbb61f32017-09-29 15:50:36 -040037
38namespace clspv {
39llvm::ModulePass *createSplatSelectConditionPass() {
40 return new SplatSelectConditionPass();
41}
42} // namespace clspv
43
David Netodbb61f32017-09-29 15:50:36 -040044bool SplatSelectConditionPass::runOnModule(Module &M) {
45 bool Changed = false;
46
47 SmallVector<SelectInst *, 16> WorkList;
48 for (Function &F : M) {
49 for (BasicBlock &BB : F) {
50 for (Instruction &I : BB) {
51 if (SelectInst *sel = dyn_cast<SelectInst>(&I)) {
52 auto cond = sel->getCondition();
53 if (cond->getType()->isIntegerTy(1)) {
54 Type *valueTy = sel->getTrueValue()->getType();
55 if (valueTy->isVectorTy()) {
56 WorkList.push_back(sel);
57 }
58 }
59 }
60 }
61 }
62 }
63
64 if (WorkList.size() == 0)
65 return Changed;
66
67 IRBuilder<> Builder(WorkList.front());
68
69 for (SelectInst *sel : WorkList) {
70 Changed = true;
71 auto cond = sel->getCondition();
72 auto numElems = sel->getTrueValue()->getType()->getVectorNumElements();
73 Builder.SetInsertPoint(sel);
74 auto splat = Builder.CreateVectorSplat(numElems, cond);
75 sel->setCondition(splat);
76 }
77
78 return Changed;
79}