blob: b913642148ca91a208c9cf4a7c5b52831d082051 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/IR/Module.h"
18#include "llvm/Pass.h"
19#include "llvm/Support/raw_ostream.h"
20#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040021
22using namespace llvm;
23
24#define DEBUG_TYPE "undosret"
25
26namespace {
27struct UndoSRetPass : public ModulePass {
28 static char ID;
29 UndoSRetPass() : ModulePass(ID) {}
30
31 bool runOnModule(Module &M) override;
32};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040033} // namespace
David Neto22f144c2017-06-12 14:26:21 -040034
35char UndoSRetPass::ID = 0;
36static RegisterPass<UndoSRetPass> X("UndoSRet", "Undo SRet Pass");
37
38namespace clspv {
39llvm::ModulePass *createUndoSRetPass() { return new UndoSRetPass(); }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040040} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040041
42bool UndoSRetPass::runOnModule(Module &M) {
43 bool Changed = false;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040044 LLVMContext &Context = M.getContext();
David Neto22f144c2017-06-12 14:26:21 -040045
46 SmallVector<Function *, 8> WorkList;
47 for (Function &F : M) {
48 if (F.isDeclaration()) {
49 continue;
50 }
51
52 if (F.getReturnType()->isVoidTy()) {
53 for (Argument &Arg : F.args()) {
54 // Check sret attribute.
55 if (Arg.hasStructRetAttr()) {
56 // We found a function that needs to be modified!
57 WorkList.push_back(&F);
58 Changed = true;
59 }
60 }
61 }
62 }
63
64 for (Function *F : WorkList) {
65 auto InsertPoint = F->getEntryBlock().getFirstNonPHIOrDbg();
66
67 for (Argument &Arg : F->args()) {
68 // Check sret attribute.
69 if (Arg.hasStructRetAttr()) {
70 PointerType *PTy = cast<PointerType>(Arg.getType());
71 Type *RetTy = PTy->getElementType();
72 // Create alloca instruction for return value on function's entry
73 // block.
74 AllocaInst *RetVal =
75 new AllocaInst(RetTy, 0, nullptr, "retval", InsertPoint);
76
77 // Change arg's users with retval.
78 Arg.replaceAllUsesWith(RetVal);
79
80 // Create new function type with real return type instead of sret
81 // argument.
82 SmallVector<Type *, 8> NewFuncParamTys;
83 for (const auto &Arg : F->args()) {
84 // Ignore argument with sret attribute.
85 if (Arg.hasStructRetAttr()) {
86 continue;
87 }
88 NewFuncParamTys.push_back(Arg.getType());
89 }
90 FunctionType *NewFuncTy =
91 FunctionType::get(RetTy, NewFuncParamTys, false);
92
93 // Create new function.
94 Function *NewFunc = Function::Create(NewFuncTy, F->getLinkage());
95 NewFunc->takeName(F);
96
97 // Insert the function just after the original to preserve the ordering
98 // of the functions within the module.
99 auto &FunctionList = M.getFunctionList();
100
101 for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
102 Iter != IterEnd; ++Iter) {
103 // If we find our functions place in the iterator.
104 if (&*Iter == F) {
105 FunctionList.insertAfter(Iter, NewFunc);
106 break;
107 }
108 }
109
110 // Map original function's arguments to new function's arguments.
111 ValueToValueMapTy VMap;
112 auto NewArg = NewFunc->arg_begin();
113 for (auto &Arg : F->args()) {
114 if (Arg.hasStructRetAttr()) {
115 VMap[&Arg] = UndefValue::get(Arg.getType());
116 continue;
117 }
118 VMap[&Arg] = &*(NewArg++);
119 }
120
121 // Clone original function into new function.
122 SmallVector<ReturnInst *, 4> RetInsts;
123 CloneFunctionInto(NewFunc, F, VMap, false, RetInsts);
124
125 // Change return instruction like this.
126 //
127 // %retv = load %retval;
128 // ret %retv;
129 for (auto Ret : RetInsts) {
130 LoadInst *LD = new LoadInst(VMap[RetVal], "", Ret);
David Neto862b7d82018-06-14 18:48:37 -0400131 ReturnInst *NewRet = ReturnInst::Create(Context, LD, Ret);
David Neto22f144c2017-06-12 14:26:21 -0400132 Ret->replaceAllUsesWith(NewRet);
133 Ret->eraseFromParent();
134 }
135
136 SmallVector<User *, 8> ToRemoves;
137
138 // Update caller site.
139 for (auto User : F->users()) {
140 if (CallInst *Call = dyn_cast<CallInst>(User)) {
141 // Create new call instruction for new function without sret.
142 SmallVector<Value *, 8> NewArgs(Call->arg_begin() + 1,
143 Call->arg_end());
144 CallInst *NewCall = CallInst::Create(NewFunc, NewArgs, "", Call);
145
146 NewCall->takeName(Call);
147 NewCall->setCallingConv(Call->getCallingConv());
David Neto22f144c2017-06-12 14:26:21 -0400148 NewCall->setDebugLoc(Call->getDebugLoc());
149
David Neto862b7d82018-06-14 18:48:37 -0400150 // Copy attributes over, but skip the attributes for the first
151 // parameter since it is removed. In particular, the old
152 // first parameter has a StructRet attribute that should disappear.
153 auto attrs(Call->getAttributes());
154 AttributeList new_attrs(
155 AttributeList::get(Context, AttributeList::FunctionIndex,
156 AttrBuilder(attrs.getFnAttributes())));
157 new_attrs =
158 new_attrs.addAttributes(Context, AttributeList::ReturnIndex,
159 AttrBuilder(attrs.getRetAttributes()));
160 for (unsigned i = 1; i < Call->getNumArgOperands(); i++) {
161 new_attrs = new_attrs.addParamAttributes(
162 Context, i - 1, AttrBuilder(attrs.getParamAttributes(i)));
163 }
164 NewCall->setAttributes(new_attrs);
165
David Neto22f144c2017-06-12 14:26:21 -0400166 // Store the value we returned from our function call into the
167 // the orignal destination.
168 new StoreInst(NewCall, Call->getArgOperand(0), Call);
169 }
170
171 ToRemoves.push_back(User);
172 }
173
174 for (User *U : ToRemoves) {
175 U->dropAllReferences();
176 if (Instruction *I = dyn_cast<Instruction>(U)) {
177 I->eraseFromParent();
178 }
179 }
180
181 // We found the argument that had sret, so we are done with this
182 // function!
183 break;
184 }
185 }
186
187 // Delete original functions with sret argument.
188 F->dropAllReferences();
189 F->eraseFromParent();
190 }
191
192 return Changed;
193}