blob: 7c211067a9f0a32a27d5e82fe3ba59baf1d6b534 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/IRBuilder.h"
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040016#include "llvm/IR/Instructions.h"
David Neto118188e2018-08-24 11:27:54 -040017#include "llvm/IR/Module.h"
18#include "llvm/Pass.h"
19#include "llvm/Support/raw_ostream.h"
David Neto22f144c2017-06-12 14:26:21 -040020
Diego Novilloa4c44fa2019-04-11 10:56:15 -040021#include "Passes.h"
22
David Neto22f144c2017-06-12 14:26:21 -040023using namespace llvm;
24
25#define DEBUG_TYPE "splatarg"
26
27namespace {
28struct SplatArgPass : public ModulePass {
29 static char ID;
30 SplatArgPass() : ModulePass(ID) {}
31
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040032 const char *getSplatName(StringRef Name);
David Neto22f144c2017-06-12 14:26:21 -040033 bool runOnModule(Module &M) override;
34};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035} // namespace
David Neto22f144c2017-06-12 14:26:21 -040036
37char SplatArgPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040038INITIALIZE_PASS(SplatArgPass, "SplatArg", "Splat Argument Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040039
40namespace clspv {
41llvm::ModulePass *createSplatArgPass() { return new SplatArgPass(); }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040043
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040044const char *SplatArgPass::getSplatName(StringRef Name) {
David Neto22f144c2017-06-12 14:26:21 -040045 if (Name.equals("_Z5clampDv2_iii")) {
46 return "_Z5clampDv2_iS_S_";
47 } else if (Name.equals("_Z5clampDv3_iii")) {
48 return "_Z5clampDv3_iS_S_";
49 } else if (Name.equals("_Z5clampDv4_iii")) {
50 return "_Z5clampDv4_iS_S_";
51 } else if (Name.equals("_Z5clampDv2_jjj")) {
52 return "_Z5clampDv2_jS_S_";
53 } else if (Name.equals("_Z5clampDv3_jjj")) {
54 return "_Z5clampDv3_jS_S_";
55 } else if (Name.equals("_Z5clampDv4_jjj")) {
56 return "_Z5clampDv4_jS_S_";
57 } else if (Name.equals("_Z5clampDv2_fff")) {
58 return "_Z5clampDv2_fS_S_";
59 } else if (Name.equals("_Z5clampDv3_fff")) {
60 return "_Z5clampDv3_fS_S_";
61 } else if (Name.equals("_Z5clampDv4_fff")) {
62 return "_Z5clampDv4_fS_S_";
63
64 } else if (Name.equals("_Z3maxDv2_ii")) {
65 return "_Z3maxDv2_iS_";
66 } else if (Name.equals("_Z3maxDv3_ii")) {
67 return "_Z3maxDv3_iS_";
68 } else if (Name.equals("_Z3maxDv4_ii")) {
69 return "_Z3maxDv4_iS_";
70 } else if (Name.equals("_Z3maxDv2_jj")) {
71 return "_Z3maxDv2_jS_";
72 } else if (Name.equals("_Z3maxDv3_jj")) {
73 return "_Z3maxDv3_jS_";
74 } else if (Name.equals("_Z3maxDv4_jj")) {
75 return "_Z3maxDv4_jS_";
76 } else if (Name.equals("_Z3maxDv2_ff")) {
77 return "_Z3maxDv2_fS_";
78 } else if (Name.equals("_Z3maxDv3_ff")) {
79 return "_Z3maxDv3_fS_";
80 } else if (Name.equals("_Z3maxDv4_ff")) {
81 return "_Z3maxDv4_fS_";
82 } else if (Name.equals("_Z4fmaxDv2_ff")) {
83 return "_Z4fmaxDv2_fS_";
84 } else if (Name.equals("_Z4fmaxDv3_ff")) {
85 return "_Z4fmaxDv3_fS_";
86 } else if (Name.equals("_Z4fmaxDv4_ff")) {
87 return "_Z4fmaxDv4_fS_";
88
89 } else if (Name.equals("_Z3minDv2_ii")) {
90 return "_Z3minDv2_iS_";
91 } else if (Name.equals("_Z3minDv3_ii")) {
92 return "_Z3minDv3_iS_";
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040093 } else if (Name.equals("_Z3minDv4_ii")) {
David Neto22f144c2017-06-12 14:26:21 -040094 return "_Z3minDv4_iS_";
95 } else if (Name.equals("_Z3minDv2_jj")) {
96 return "_Z3minDv2_jS_";
97 } else if (Name.equals("_Z3minDv3_jj")) {
98 return "_Z3minDv3_jS_";
99 } else if (Name.equals("_Z3minDv4_jj")) {
100 return "_Z3minDv4_jS_";
101 } else if (Name.equals("_Z3minDv2_ff")) {
102 return "_Z3minDv2_fS_";
103 } else if (Name.equals("_Z3minDv3_ff")) {
104 return "_Z3minDv3_fS_";
105 } else if (Name.equals("_Z3minDv4_ff")) {
106 return "_Z3minDv4_fS_";
107 } else if (Name.equals("_Z4fminDv2_ff")) {
108 return "_Z4fminDv2_fS_";
109 } else if (Name.equals("_Z4fminDv3_ff")) {
110 return "_Z4fminDv3_fS_";
111 } else if (Name.equals("_Z4fminDv4_ff")) {
112 return "_Z4fminDv4_fS_";
113
114 } else if (Name.equals("_Z3mixDv2_fS_f")) {
115 return "_Z3mixDv2_fS_S_";
116 } else if (Name.equals("_Z3mixDv3_fS_f")) {
117 return "_Z3mixDv3_fS_S_";
118 } else if (Name.equals("_Z3mixDv4_fS_f")) {
119 return "_Z3mixDv4_fS_S_";
120 }
121
122 return nullptr;
123}
124
125bool SplatArgPass::runOnModule(Module &M) {
126 bool Changed = false;
127
128 SmallVector<CallInst *, 16> WorkList;
129 for (Function &F : M) {
130 for (BasicBlock &BB : F) {
131 for (Instruction &I : BB) {
132 if (CallInst *Call = dyn_cast<CallInst>(&I)) {
133 Function *Callee = Call->getCalledFunction();
134 if (Callee) {
135 // If min/max/mix/clamp function call has scalar type argument, we
136 // need to splat the scalar type one to vector type.
137 if (getSplatName(Callee->getName())) {
138 WorkList.push_back(Call);
139 Changed = true;
140 }
141 }
142 }
143 }
144 }
145 }
146
147 for (CallInst *Call : WorkList) {
148 Function *Callee = Call->getCalledFunction();
149 FunctionType *CalleeTy = Callee->getFunctionType();
150
151 // Create new callee function type with vector type.
152 SmallVector<Type *, 4> NewCalleeParamTys;
153 for (const auto &Arg : Callee->args()) {
154 if (Arg.getType()->isVectorTy()) {
155 NewCalleeParamTys.push_back(Arg.getType());
156 } else {
157 NewCalleeParamTys.push_back(Call->getType());
158 }
159 }
160
161 FunctionType *NewCalleeTy =
162 FunctionType::get(Call->getType(), NewCalleeParamTys, false);
163
164 // Create new callee function declaration with new function type.
165 StringRef NewCallName(getSplatName(Callee->getName()));
alan-bakerbccf62c2019-03-29 10:32:41 -0400166 Function *NewCallee = cast<Function>(
167 M.getOrInsertFunction(NewCallName, NewCalleeTy).getCallee());
David Neto22f144c2017-06-12 14:26:21 -0400168 NewCallee->setCallingConv(CallingConv::SPIR_FUNC);
169
170 // Change target of call instruction.
171 Call->setCalledFunction(NewCalleeTy, NewCallee);
172
173 // Change operands of call instruction.
174 IRBuilder<> Builder(Call);
175 for (unsigned i = 0; i < CalleeTy->getNumParams(); i++) {
176 if (!CalleeTy->getParamType(i)->isVectorTy()) {
177 VectorType *VTy = cast<VectorType>(Call->getType());
178 Value *NewArg = Builder.CreateVectorSplat(
179 VTy->getNumElements(), Call->getArgOperand(i), "arg_splat");
180 Call->setArgOperand(i, NewArg);
181 }
182 }
183
184 Call->setCallingConv(CallingConv::SPIR_FUNC);
185 }
186
187 return Changed;
188}