blob: 4525e24950f7ce0796556b18aa956b3e040f3dcc [file] [log] [blame]
David Netodd992212017-06-23 17:47:55 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <cassert>
16
17#include <llvm/IR/Constants.h>
18#include <llvm/IR/DerivedTypes.h>
19#include <llvm/IR/Function.h>
20#include <llvm/IR/Instructions.h>
21#include <llvm/IR/IRBuilder.h>
22#include <llvm/IR/Module.h>
23#include <llvm/Pass.h>
24#include <llvm/Support/raw_ostream.h>
25
26//#include <llvm/Transforms/Utils/Cloning.h>
27
28using namespace llvm;
29
30#define DEBUG_TYPE "clusterpodkernelargs"
31
32namespace {
33struct ClusterPodKernelArgumentsPass : public ModulePass {
34 static char ID;
35 ClusterPodKernelArgumentsPass() : ModulePass(ID) {}
36
37 bool runOnModule(Module &M) override;
38};
39} // namespace
40
41char ClusterPodKernelArgumentsPass::ID = 0;
42static RegisterPass<ClusterPodKernelArgumentsPass>
43 X("ClusterPodKernelArgumentsPass",
44 "Cluster POD Kernel Arguments Pass");
45
46namespace clspv {
47llvm::ModulePass *createClusterPodKernelArgumentsPass() {
48 return new ClusterPodKernelArgumentsPass();
49}
50} // namespace clspv
51
52bool ClusterPodKernelArgumentsPass::runOnModule(Module &M) {
53 bool Changed = false;
54 LLVMContext &Context = M.getContext();
55
56 SmallVector<Function *, 8> WorkList;
57
58 for (Function &F : M) {
59 if (F.isDeclaration() || F.getCallingConv() != CallingConv::SPIR_KERNEL) {
60 continue;
61 }
62 for (Argument &Arg : F.args()) {
63 if (!isa<PointerType>(Arg.getType())) {
64 WorkList.push_back(&F);
65 break;
66 }
67 }
68 }
69
70 //WorkList.clear();
71
72 for (Function* F : WorkList) {
73 Changed = true;
74
75 // In OpenCL, kernel arguments are either pointers or POD. A composite with
76 // an element or memeber that is a pointer is not allowed. So we'll use POD
77 // as a shorthand for non-pointer.
78
79 SmallVector<Type *, 8> PtrArgTys;
80 SmallVector<Type *, 8> PodArgTys;
81 for (Argument &Arg : F->args()) {
82 Type *ArgTy = Arg.getType();
83 if (isa<PointerType>(ArgTy)) {
84 PtrArgTys.push_back(ArgTy);
85 } else {
86 PodArgTys.push_back(ArgTy);
87 }
88 }
89
90
91 // Put the pointer arguments first, and then POD arguments struct last.
92 auto PodArgsStructTy =
93 StructType::create(PodArgTys, F->getName().str() + ".podargs");
94 SmallVector<Type *, 8> NewFuncParamTys(PtrArgTys);
95 NewFuncParamTys.push_back(PodArgsStructTy);
96
97 FunctionType *NewFuncTy =
98 FunctionType::get(F->getReturnType(), NewFuncParamTys, false);
99
100 // Create the new function and set key properties.
101 auto NewFunc = Function::Create(NewFuncTy, F->getLinkage());
102 // The new function adopts the real name so that linkage to the outside
103 // world remains the same.
104 NewFunc->setName(F->getName());
105 F->setName(NewFunc->getName().str() + ".inner");
106
107 NewFunc->setCallingConv(F->getCallingConv());
108 F->setCallingConv(CallingConv::SPIR_FUNC);
109
110 NewFunc->setAttributes(F->getAttributes());
111 // Move OpenCL kernel named attributes.
112 // TODO(dneto): Attributes starting with kernel_arg_* should be rewritten
113 // to reflect change in the argument shape.
114 std::vector<const char *> Metadatas{
115 "reqd_work_group_size", "kernel_arg_addr_space",
116 "kernel_arg_access_qual", "kernel_arg_type",
117 "kernel_arg_base_type", "kernel_arg_type_qual"};
118 for (auto name : Metadatas) {
119 NewFunc->setMetadata(name, F->getMetadata(name));
120 F->setMetadata(name, nullptr);
121 }
122
123 // Insert the function after the original, to preserve ordering
124 // in the module as much as possible.
125 auto &FunctionList = M.getFunctionList();
126 for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
127 Iter != IterEnd; ++Iter) {
128 if (&*Iter == F) {
129 FunctionList.insertAfter(Iter, NewFunc);
130 break;
131 }
132 }
133
134 // The body of the wrapper is essentially a call to the original function,
135 // but we have to unwrap the non-pointer arguments from the struct.
136 IRBuilder<> Builder(BasicBlock::Create(Context, "entry", NewFunc));
137
138 // Map the wrapper's arguments to the callee's arguments.
139 SmallVector<Argument *, 8> CallerArgs;
140 for (Argument &Arg : NewFunc->args()) {
141 CallerArgs.push_back(&Arg);
142 }
143 Argument *PodArg = CallerArgs.back();
144 PodArg->setName("podargs");
145
146 SmallVector<Value *, 8> CalleeArgs;
147 unsigned podIndex = 0;
148 unsigned ptrIndex = 0;
149 for (const Argument &Arg : F->args()) {
150 if (isa<PointerType>(Arg.getType())) {
151 CalleeArgs.push_back(CallerArgs[ptrIndex++]);
152 } else {
153 CalleeArgs.push_back(Builder.CreateExtractValue(PodArg, {podIndex++}));
154 }
155 CalleeArgs.back()->setName(Arg.getName());
156 }
157 assert(ptrIndex + podIndex == F->arg_size());
158 assert(ptrIndex = PtrArgTys.size());
159 assert(podIndex = PodArgTys.size());
160
161 auto Call = Builder.CreateCall(F, CalleeArgs);
162 Call->setCallingConv(F->getCallingConv());
163
164 Builder.CreateRetVoid();
165 }
166
167 return Changed;
168}