blob: 9ae7f76f19d3a70648b53cd640732b1acbd0a395 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
David Neto118188e2018-08-24 11:27:54 -040015#include "llvm/IR/Constants.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/IR/Module.h"
18#include "llvm/Pass.h"
19#include "llvm/Support/raw_ostream.h"
20#include "llvm/Transforms/Utils/Cloning.h"
David Neto22f144c2017-06-12 14:26:21 -040021
Diego Novilloa4c44fa2019-04-11 10:56:15 -040022#include "Passes.h"
23
David Neto22f144c2017-06-12 14:26:21 -040024using namespace llvm;
25
26#define DEBUG_TYPE "undosret"
27
28namespace {
29struct UndoSRetPass : public ModulePass {
30 static char ID;
31 UndoSRetPass() : ModulePass(ID) {}
32
33 bool runOnModule(Module &M) override;
34};
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035} // namespace
David Neto22f144c2017-06-12 14:26:21 -040036
37char UndoSRetPass::ID = 0;
Diego Novilloa4c44fa2019-04-11 10:56:15 -040038INITIALIZE_PASS(UndoSRetPass, "UndoSRet", "Undo SRet Pass", false, false)
David Neto22f144c2017-06-12 14:26:21 -040039
40namespace clspv {
41llvm::ModulePass *createUndoSRetPass() { return new UndoSRetPass(); }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040042} // namespace clspv
David Neto22f144c2017-06-12 14:26:21 -040043
44bool UndoSRetPass::runOnModule(Module &M) {
45 bool Changed = false;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040046 LLVMContext &Context = M.getContext();
David Neto22f144c2017-06-12 14:26:21 -040047
48 SmallVector<Function *, 8> WorkList;
49 for (Function &F : M) {
50 if (F.isDeclaration()) {
51 continue;
52 }
53
54 if (F.getReturnType()->isVoidTy()) {
55 for (Argument &Arg : F.args()) {
56 // Check sret attribute.
57 if (Arg.hasStructRetAttr()) {
58 // We found a function that needs to be modified!
59 WorkList.push_back(&F);
60 Changed = true;
61 }
62 }
63 }
64 }
65
66 for (Function *F : WorkList) {
67 auto InsertPoint = F->getEntryBlock().getFirstNonPHIOrDbg();
68
69 for (Argument &Arg : F->args()) {
70 // Check sret attribute.
71 if (Arg.hasStructRetAttr()) {
72 PointerType *PTy = cast<PointerType>(Arg.getType());
73 Type *RetTy = PTy->getElementType();
74 // Create alloca instruction for return value on function's entry
75 // block.
76 AllocaInst *RetVal =
77 new AllocaInst(RetTy, 0, nullptr, "retval", InsertPoint);
78
79 // Change arg's users with retval.
80 Arg.replaceAllUsesWith(RetVal);
81
82 // Create new function type with real return type instead of sret
83 // argument.
84 SmallVector<Type *, 8> NewFuncParamTys;
85 for (const auto &Arg : F->args()) {
86 // Ignore argument with sret attribute.
87 if (Arg.hasStructRetAttr()) {
88 continue;
89 }
90 NewFuncParamTys.push_back(Arg.getType());
91 }
92 FunctionType *NewFuncTy =
93 FunctionType::get(RetTy, NewFuncParamTys, false);
94
95 // Create new function.
96 Function *NewFunc = Function::Create(NewFuncTy, F->getLinkage());
97 NewFunc->takeName(F);
98
99 // Insert the function just after the original to preserve the ordering
100 // of the functions within the module.
101 auto &FunctionList = M.getFunctionList();
102
103 for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
104 Iter != IterEnd; ++Iter) {
105 // If we find our functions place in the iterator.
106 if (&*Iter == F) {
107 FunctionList.insertAfter(Iter, NewFunc);
108 break;
109 }
110 }
111
112 // Map original function's arguments to new function's arguments.
113 ValueToValueMapTy VMap;
114 auto NewArg = NewFunc->arg_begin();
115 for (auto &Arg : F->args()) {
116 if (Arg.hasStructRetAttr()) {
117 VMap[&Arg] = UndefValue::get(Arg.getType());
118 continue;
119 }
alan-baker038e9242019-04-19 22:14:41 -0400120 NewArg->setName(Arg.getName());
David Neto22f144c2017-06-12 14:26:21 -0400121 VMap[&Arg] = &*(NewArg++);
122 }
123
124 // Clone original function into new function.
125 SmallVector<ReturnInst *, 4> RetInsts;
alan-bakerb41cfd32021-02-17 08:28:59 -0500126 CloneFunctionInto(NewFunc, F, VMap,
127 CloneFunctionChangeType::LocalChangesOnly, RetInsts);
David Neto22f144c2017-06-12 14:26:21 -0400128
129 // Change return instruction like this.
130 //
131 // %retv = load %retval;
132 // ret %retv;
133 for (auto Ret : RetInsts) {
alan-baker741fd1f2020-04-14 17:38:15 -0400134 LoadInst *LD = new LoadInst(RetTy, VMap[RetVal], "", Ret);
David Neto862b7d82018-06-14 18:48:37 -0400135 ReturnInst *NewRet = ReturnInst::Create(Context, LD, Ret);
David Neto22f144c2017-06-12 14:26:21 -0400136 Ret->replaceAllUsesWith(NewRet);
137 Ret->eraseFromParent();
138 }
139
140 SmallVector<User *, 8> ToRemoves;
141
142 // Update caller site.
143 for (auto User : F->users()) {
144 if (CallInst *Call = dyn_cast<CallInst>(User)) {
145 // Create new call instruction for new function without sret.
146 SmallVector<Value *, 8> NewArgs(Call->arg_begin() + 1,
147 Call->arg_end());
148 CallInst *NewCall = CallInst::Create(NewFunc, NewArgs, "", Call);
149
150 NewCall->takeName(Call);
151 NewCall->setCallingConv(Call->getCallingConv());
David Neto22f144c2017-06-12 14:26:21 -0400152 NewCall->setDebugLoc(Call->getDebugLoc());
153
David Neto862b7d82018-06-14 18:48:37 -0400154 // Copy attributes over, but skip the attributes for the first
155 // parameter since it is removed. In particular, the old
156 // first parameter has a StructRet attribute that should disappear.
157 auto attrs(Call->getAttributes());
158 AttributeList new_attrs(
159 AttributeList::get(Context, AttributeList::FunctionIndex,
alan-bakera6001ae2021-08-18 17:08:27 -0400160 AttrBuilder(attrs.getFnAttrs())));
alan-baker56db84f2021-09-08 20:50:35 -0400161 new_attrs = new_attrs.addAttributesAtIndex(
162 Context, AttributeList::ReturnIndex,
163 AttrBuilder(attrs.getRetAttrs()));
David Neto862b7d82018-06-14 18:48:37 -0400164 for (unsigned i = 1; i < Call->getNumArgOperands(); i++) {
165 new_attrs = new_attrs.addParamAttributes(
alan-bakera6001ae2021-08-18 17:08:27 -0400166 Context, i - 1, AttrBuilder(attrs.getParamAttrs(i)));
David Neto862b7d82018-06-14 18:48:37 -0400167 }
168 NewCall->setAttributes(new_attrs);
169
David Neto22f144c2017-06-12 14:26:21 -0400170 // Store the value we returned from our function call into the
171 // the orignal destination.
172 new StoreInst(NewCall, Call->getArgOperand(0), Call);
173 }
174
175 ToRemoves.push_back(User);
176 }
177
178 for (User *U : ToRemoves) {
179 U->dropAllReferences();
180 if (Instruction *I = dyn_cast<Instruction>(U)) {
181 I->eraseFromParent();
182 }
183 }
184
185 // We found the argument that had sret, so we are done with this
186 // function!
187 break;
188 }
189 }
190
191 // Delete original functions with sret argument.
192 F->dropAllReferences();
193 F->eraseFromParent();
194 }
195
196 return Changed;
197}