blob: 40d086a737074d9b43093a567bd2569a836f0484 [file] [log] [blame]
David Neto22f144c2017-06-12 14:26:21 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/IR/Module.h"
18#include "llvm/Pass.h"
19#include "llvm/Support/raw_ostream.h"
20#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040021
Diego Novilloa4c44fa2019-04-11 10:56:15 -040022#include "Passes.h"
23
David Neto22f144c2017-06-12 14:26:21 -040024using namespace llvm;
25
26#define DEBUG_TYPE "undosret"
27
28namespace {
29struct UndoSRetPass : public ModulePass {
30 static char ID;
31 UndoSRetPass() : ModulePass(ID) {}
32
33 bool runOnModule(Module &M) override;
34};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035} // namespace
David Neto22f144c2017-06-12 14:26:21 -040036
37char UndoSRetPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040038INITIALIZE_PASS(UndoSRetPass, "UndoSRet", "Undo SRet Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040039
40namespace clspv {
41llvm::ModulePass *createUndoSRetPass() { return new UndoSRetPass(); }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040043
44bool UndoSRetPass::runOnModule(Module &M) {
45 bool Changed = false;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040046 LLVMContext &Context = M.getContext();
David Neto22f144c2017-06-12 14:26:21 -040047
48 SmallVector<Function *, 8> WorkList;
49 for (Function &F : M) {
50 if (F.isDeclaration()) {
51 continue;
52 }
53
54 if (F.getReturnType()->isVoidTy()) {
55 for (Argument &Arg : F.args()) {
56 // Check sret attribute.
57 if (Arg.hasStructRetAttr()) {
58 // We found a function that needs to be modified!
59 WorkList.push_back(&F);
60 Changed = true;
61 }
62 }
63 }
64 }
65
66 for (Function *F : WorkList) {
67 auto InsertPoint = F->getEntryBlock().getFirstNonPHIOrDbg();
68
69 for (Argument &Arg : F->args()) {
70 // Check sret attribute.
71 if (Arg.hasStructRetAttr()) {
72 PointerType *PTy = cast<PointerType>(Arg.getType());
73 Type *RetTy = PTy->getElementType();
74 // Create alloca instruction for return value on function's entry
75 // block.
76 AllocaInst *RetVal =
77 new AllocaInst(RetTy, 0, nullptr, "retval", InsertPoint);
78
79 // Change arg's users with retval.
80 Arg.replaceAllUsesWith(RetVal);
81
82 // Create new function type with real return type instead of sret
83 // argument.
84 SmallVector<Type *, 8> NewFuncParamTys;
85 for (const auto &Arg : F->args()) {
86 // Ignore argument with sret attribute.
87 if (Arg.hasStructRetAttr()) {
88 continue;
89 }
90 NewFuncParamTys.push_back(Arg.getType());
91 }
92 FunctionType *NewFuncTy =
93 FunctionType::get(RetTy, NewFuncParamTys, false);
94
95 // Create new function.
96 Function *NewFunc = Function::Create(NewFuncTy, F->getLinkage());
97 NewFunc->takeName(F);
98
99 // Insert the function just after the original to preserve the ordering
100 // of the functions within the module.
101 auto &FunctionList = M.getFunctionList();
102
103 for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
104 Iter != IterEnd; ++Iter) {
105 // If we find our functions place in the iterator.
106 if (&*Iter == F) {
107 FunctionList.insertAfter(Iter, NewFunc);
108 break;
109 }
110 }
111
112 // Map original function's arguments to new function's arguments.
113 ValueToValueMapTy VMap;
114 auto NewArg = NewFunc->arg_begin();
115 for (auto &Arg : F->args()) {
116 if (Arg.hasStructRetAttr()) {
117 VMap[&Arg] = UndefValue::get(Arg.getType());
118 continue;
119 }
alan-baker038e9242019-04-19 22:14:41 -0400120 NewArg->setName(Arg.getName());
David Neto22f144c2017-06-12 14:26:21 -0400121 VMap[&Arg] = &*(NewArg++);
122 }
123
124 // Clone original function into new function.
125 SmallVector<ReturnInst *, 4> RetInsts;
126 CloneFunctionInto(NewFunc, F, VMap, false, RetInsts);
127
128 // Change return instruction like this.
129 //
130 // %retv = load %retval;
131 // ret %retv;
132 for (auto Ret : RetInsts) {
alan-baker741fd1f2020-04-14 17:38:15 -0400133 LoadInst *LD = new LoadInst(RetTy, VMap[RetVal], "", Ret);
David Neto862b7d82018-06-14 18:48:37 -0400134 ReturnInst *NewRet = ReturnInst::Create(Context, LD, Ret);
David Neto22f144c2017-06-12 14:26:21 -0400135 Ret->replaceAllUsesWith(NewRet);
136 Ret->eraseFromParent();
137 }
138
139 SmallVector<User *, 8> ToRemoves;
140
141 // Update caller site.
142 for (auto User : F->users()) {
143 if (CallInst *Call = dyn_cast<CallInst>(User)) {
144 // Create new call instruction for new function without sret.
145 SmallVector<Value *, 8> NewArgs(Call->arg_begin() + 1,
146 Call->arg_end());
147 CallInst *NewCall = CallInst::Create(NewFunc, NewArgs, "", Call);
148
149 NewCall->takeName(Call);
150 NewCall->setCallingConv(Call->getCallingConv());
David Neto22f144c2017-06-12 14:26:21 -0400151 NewCall->setDebugLoc(Call->getDebugLoc());
152
David Neto862b7d82018-06-14 18:48:37 -0400153 // Copy attributes over, but skip the attributes for the first
154 // parameter since it is removed. In particular, the old
155 // first parameter has a StructRet attribute that should disappear.
156 auto attrs(Call->getAttributes());
157 AttributeList new_attrs(
158 AttributeList::get(Context, AttributeList::FunctionIndex,
159 AttrBuilder(attrs.getFnAttributes())));
160 new_attrs =
161 new_attrs.addAttributes(Context, AttributeList::ReturnIndex,
162 AttrBuilder(attrs.getRetAttributes()));
163 for (unsigned i = 1; i < Call->getNumArgOperands(); i++) {
164 new_attrs = new_attrs.addParamAttributes(
165 Context, i - 1, AttrBuilder(attrs.getParamAttributes(i)));
166 }
167 NewCall->setAttributes(new_attrs);
168
David Neto22f144c2017-06-12 14:26:21 -0400169 // Store the value we returned from our function call into the
170 // the orignal destination.
171 new StoreInst(NewCall, Call->getArgOperand(0), Call);
172 }
173
174 ToRemoves.push_back(User);
175 }
176
177 for (User *U : ToRemoves) {
178 U->dropAllReferences();
179 if (Instruction *I = dyn_cast<Instruction>(U)) {
180 I->eraseFromParent();
181 }
182 }
183
184 // We found the argument that had sret, so we are done with this
185 // function!
186 break;
187 }
188 }
189
190 // Delete original functions with sret argument.
191 F->dropAllReferences();
192 F->eraseFromParent();
193 }
194
195 return Changed;
196}