blob: d3af411fc33a3f1fcd1d625e2fd351aafb292e64 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include <utility>

using namespace llvm;

// Debug-output tag for this pass, following LLVM's DEBUG_TYPE convention.
// NOTE(review): not referenced by any code visible in this file.
#define DEBUG_TYPE "splatarg"

25namespace {
26struct SplatArgPass : public ModulePass {
27 static char ID;
28 SplatArgPass() : ModulePass(ID) {}
29
30 const char* getSplatName(StringRef Name);
31 bool runOnModule(Module &M) override;
32};
33}
34
35char SplatArgPass::ID = 0;
36static RegisterPass<SplatArgPass> X("SplatArg", "Splat Argument Pass");
37
38namespace clspv {
39llvm::ModulePass *createSplatArgPass() { return new SplatArgPass(); }
40}
41
42const char* SplatArgPass::getSplatName(StringRef Name) {
43 if (Name.equals("_Z5clampDv2_iii")) {
44 return "_Z5clampDv2_iS_S_";
45 } else if (Name.equals("_Z5clampDv3_iii")) {
46 return "_Z5clampDv3_iS_S_";
47 } else if (Name.equals("_Z5clampDv4_iii")) {
48 return "_Z5clampDv4_iS_S_";
49 } else if (Name.equals("_Z5clampDv2_jjj")) {
50 return "_Z5clampDv2_jS_S_";
51 } else if (Name.equals("_Z5clampDv3_jjj")) {
52 return "_Z5clampDv3_jS_S_";
53 } else if (Name.equals("_Z5clampDv4_jjj")) {
54 return "_Z5clampDv4_jS_S_";
55 } else if (Name.equals("_Z5clampDv2_fff")) {
56 return "_Z5clampDv2_fS_S_";
57 } else if (Name.equals("_Z5clampDv3_fff")) {
58 return "_Z5clampDv3_fS_S_";
59 } else if (Name.equals("_Z5clampDv4_fff")) {
60 return "_Z5clampDv4_fS_S_";
61
62 } else if (Name.equals("_Z3maxDv2_ii")) {
63 return "_Z3maxDv2_iS_";
64 } else if (Name.equals("_Z3maxDv3_ii")) {
65 return "_Z3maxDv3_iS_";
66 } else if (Name.equals("_Z3maxDv4_ii")) {
67 return "_Z3maxDv4_iS_";
68 } else if (Name.equals("_Z3maxDv2_jj")) {
69 return "_Z3maxDv2_jS_";
70 } else if (Name.equals("_Z3maxDv3_jj")) {
71 return "_Z3maxDv3_jS_";
72 } else if (Name.equals("_Z3maxDv4_jj")) {
73 return "_Z3maxDv4_jS_";
74 } else if (Name.equals("_Z3maxDv2_ff")) {
75 return "_Z3maxDv2_fS_";
76 } else if (Name.equals("_Z3maxDv3_ff")) {
77 return "_Z3maxDv3_fS_";
78 } else if (Name.equals("_Z3maxDv4_ff")) {
79 return "_Z3maxDv4_fS_";
80 } else if (Name.equals("_Z4fmaxDv2_ff")) {
81 return "_Z4fmaxDv2_fS_";
82 } else if (Name.equals("_Z4fmaxDv3_ff")) {
83 return "_Z4fmaxDv3_fS_";
84 } else if (Name.equals("_Z4fmaxDv4_ff")) {
85 return "_Z4fmaxDv4_fS_";
86
87 } else if (Name.equals("_Z3minDv2_ii")) {
88 return "_Z3minDv2_iS_";
89 } else if (Name.equals("_Z3minDv3_ii")) {
90 return "_Z3minDv3_iS_";
91 } else if (Name.equals("_Z3minDv4_ii")) {
92 return "_Z3minDv4_iS_";
93 } else if (Name.equals("_Z3minDv2_jj")) {
94 return "_Z3minDv2_jS_";
95 } else if (Name.equals("_Z3minDv3_jj")) {
96 return "_Z3minDv3_jS_";
97 } else if (Name.equals("_Z3minDv4_jj")) {
98 return "_Z3minDv4_jS_";
99 } else if (Name.equals("_Z3minDv2_ff")) {
100 return "_Z3minDv2_fS_";
101 } else if (Name.equals("_Z3minDv3_ff")) {
102 return "_Z3minDv3_fS_";
103 } else if (Name.equals("_Z3minDv4_ff")) {
104 return "_Z3minDv4_fS_";
105 } else if (Name.equals("_Z4fminDv2_ff")) {
106 return "_Z4fminDv2_fS_";
107 } else if (Name.equals("_Z4fminDv3_ff")) {
108 return "_Z4fminDv3_fS_";
109 } else if (Name.equals("_Z4fminDv4_ff")) {
110 return "_Z4fminDv4_fS_";
111
112 } else if (Name.equals("_Z3mixDv2_fS_f")) {
113 return "_Z3mixDv2_fS_S_";
114 } else if (Name.equals("_Z3mixDv3_fS_f")) {
115 return "_Z3mixDv3_fS_S_";
116 } else if (Name.equals("_Z3mixDv4_fS_f")) {
117 return "_Z3mixDv4_fS_S_";
118 }
119
120 return nullptr;
121}
122
123bool SplatArgPass::runOnModule(Module &M) {
124 bool Changed = false;
125
126 SmallVector<CallInst *, 16> WorkList;
127 for (Function &F : M) {
128 for (BasicBlock &BB : F) {
129 for (Instruction &I : BB) {
130 if (CallInst *Call = dyn_cast<CallInst>(&I)) {
131 Function *Callee = Call->getCalledFunction();
132 if (Callee) {
133 // If min/max/mix/clamp function call has scalar type argument, we
134 // need to splat the scalar type one to vector type.
135 if (getSplatName(Callee->getName())) {
136 WorkList.push_back(Call);
137 Changed = true;
138 }
139 }
140 }
141 }
142 }
143 }
144
145 for (CallInst *Call : WorkList) {
146 Function *Callee = Call->getCalledFunction();
147 FunctionType *CalleeTy = Callee->getFunctionType();
148
149 // Create new callee function type with vector type.
150 SmallVector<Type *, 4> NewCalleeParamTys;
151 for (const auto &Arg : Callee->args()) {
152 if (Arg.getType()->isVectorTy()) {
153 NewCalleeParamTys.push_back(Arg.getType());
154 } else {
155 NewCalleeParamTys.push_back(Call->getType());
156 }
157 }
158
159 FunctionType *NewCalleeTy =
160 FunctionType::get(Call->getType(), NewCalleeParamTys, false);
161
162 // Create new callee function declaration with new function type.
163 StringRef NewCallName(getSplatName(Callee->getName()));
alan-bakerbccf62c2019-03-29 10:32:41 -0400164 Function *NewCallee = cast<Function>(
165 M.getOrInsertFunction(NewCallName, NewCalleeTy).getCallee());
David Neto22f144c2017-06-12 14:26:21 -0400166 NewCallee->setCallingConv(CallingConv::SPIR_FUNC);
167
168 // Change target of call instruction.
169 Call->setCalledFunction(NewCalleeTy, NewCallee);
170
171 // Change operands of call instruction.
172 IRBuilder<> Builder(Call);
173 for (unsigned i = 0; i < CalleeTy->getNumParams(); i++) {
174 if (!CalleeTy->getParamType(i)->isVectorTy()) {
175 VectorType *VTy = cast<VectorType>(Call->getType());
176 Value *NewArg = Builder.CreateVectorSplat(
177 VTy->getNumElements(), Call->getArgOperand(i), "arg_splat");
178 Call->setArgOperand(i, NewArg);
179 }
180 }
181
182 Call->setCallingConv(CallingConv::SPIR_FUNC);
183 }
184
185 return Changed;
186}