blob: ae4f4bda9f74057a48d53ffb6fe4a2687eda8507 [file] [log] [blame]
David Neto6373d822017-10-12 19:09:53 -04001// Copyright 2017 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <string>
16
David Neto118188e2018-08-24 11:27:54 -040017#include "llvm/ADT/DenseMap.h"
18#include "llvm/IR/Attributes.h"
19#include "llvm/IR/DerivedTypes.h"
20#include "llvm/IR/Function.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/Module.h"
23#include "llvm/Pass.h"
24#include "llvm/Support/raw_ostream.h"
David Neto6373d822017-10-12 19:09:53 -040025
David Neto118188e2018-08-24 11:27:54 -040026#include "clspv/AddressSpace.h"
David Neto6373d822017-10-12 19:09:53 -040027
28using namespace llvm;
29using std::string;
30
31#define DEBUG_TYPE "hideconstantloads"
32
David Neto6373d822017-10-12 19:09:53 -040033namespace {
34
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040035const char *kWrapFunctionPrefix = "clspv.wrap_constant_load.";
David Neto6373d822017-10-12 19:09:53 -040036
37class HideConstantLoadsPass : public ModulePass {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040038public:
David Neto6373d822017-10-12 19:09:53 -040039 static char ID;
40 HideConstantLoadsPass() : ModulePass(ID) {}
41
42 bool runOnModule(Module &M) override;
43
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040044private:
45 // Return the name for the wrap function for the given type.
46 string &WrapFunctionNameForType(Type *type) {
47 auto where = function_for_type_.find(type);
48 if (where == function_for_type_.end()) {
49 // Insert it.
50 auto &result = function_for_type_[type] =
51 string(kWrapFunctionPrefix) +
52 std::to_string(function_for_type_.size());
53 return result;
54 } else {
55 return where->second;
56 }
57 }
David Neto6373d822017-10-12 19:09:53 -040058
Diego Novillo3cc8d7a2019-04-10 13:30:34 -040059 // Maps a loaded type to the name of the wrap function for that type.
60 DenseMap<Type *, string> function_for_type_;
David Neto6373d822017-10-12 19:09:53 -040061};
62} // namespace
63
64char HideConstantLoadsPass::ID = 0;
65static RegisterPass<HideConstantLoadsPass>
66 X("HideConstantLoads", "Hide loads from __constant memory");
67
68namespace clspv {
69llvm::ModulePass *createHideConstantLoadsPass() {
70 return new HideConstantLoadsPass();
71}
72} // namespace clspv
73
David Neto6373d822017-10-12 19:09:53 -040074bool HideConstantLoadsPass::runOnModule(Module &M) {
75 bool Changed = false;
76
77 SmallVector<LoadInst *, 16> WorkList;
78 for (Function &F : M) {
79 for (BasicBlock &BB : F) {
80 for (Instruction &I : BB) {
81 if (LoadInst *load = dyn_cast<LoadInst>(&I)) {
82 if (clspv::AddressSpace::Constant == load->getPointerAddressSpace()) {
83 WorkList.push_back(load);
84 }
85 }
86 }
87 }
88 }
89
90 if (WorkList.size() == 0) {
91 return Changed;
92 }
93
94 for (LoadInst *load : WorkList) {
95 Changed = true;
96
97 auto loadedTy = load->getType();
98
99 // The wrap function conceptually maps the loaded value to itself.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400100 const string &fn_name = WrapFunctionNameForType(loadedTy);
101 Function *fn = M.getFunction(fn_name);
David Neto6373d822017-10-12 19:09:53 -0400102 if (!fn) {
103 // Make the function.
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400104 FunctionType *fnTy = FunctionType::get(loadedTy, {loadedTy}, false);
David Neto6373d822017-10-12 19:09:53 -0400105 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
alan-bakerbccf62c2019-03-29 10:32:41 -0400106 fn = cast<Function>(fn_constant.getCallee());
David Neto6373d822017-10-12 19:09:53 -0400107 fn->addFnAttr(Attribute::ReadOnly);
David Neto6373d822017-10-12 19:09:53 -0400108 }
109
110 // Wrap the load
111 auto call = CallInst::Create(fn, {load});
112 call->insertAfter(load);
113
114 // Replace other uses of the load with the result of the wrap call.
115 {
116 SmallVector<User *, 16> ToReplaceIn;
117 for (auto &use : load->uses()) {
118 User *user = use.getUser();
119 ToReplaceIn.push_back(user);
120 }
121 for (auto *user : ToReplaceIn) {
122 if (dyn_cast<CallInst>(user) != call) {
123 user->replaceUsesOfWith(load, call);
124 }
125 }
126 }
127 }
128
129 return Changed;
130}
131
132namespace {
133class UnhideConstantLoadsPass : public ModulePass {
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400134public:
David Neto6373d822017-10-12 19:09:53 -0400135 static char ID;
136 UnhideConstantLoadsPass() : ModulePass(ID) {}
137
138 bool runOnModule(Module &M) override;
139
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400140private:
141 // Maps a loaded type to the name of the wrap function for that type.
142 DenseMap<Type *, string> function_for_type_;
David Neto6373d822017-10-12 19:09:53 -0400143};
144
145} // namespace
146
147char UnhideConstantLoadsPass::ID = 0;
148static RegisterPass<UnhideConstantLoadsPass>
149 X2("UnhideConstantLoads", "Unhide loads from __constant memory");
150
151namespace clspv {
152llvm::ModulePass *createUnhideConstantLoadsPass() {
153 return new UnhideConstantLoadsPass();
154}
155} // namespace clspv
156
157bool UnhideConstantLoadsPass::runOnModule(Module &M) {
158 bool Changed = false;
159
160 SmallVector<Function *, 16> WorkList;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400161 for (auto &F : M.getFunctionList()) {
David Neto6373d822017-10-12 19:09:53 -0400162 if (F.getName().startswith(kWrapFunctionPrefix)) {
163 WorkList.push_back(&F);
164 }
165 }
166
167 if (WorkList.size() == 0)
168 return Changed;
169
170 SmallVector<CallInst *, 16> RemoveList;
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400171 for (auto *F : WorkList) {
172 for (auto &use : F->uses()) {
173 if (auto *call = dyn_cast<CallInst>(use.getUser())) {
David Neto6373d822017-10-12 19:09:53 -0400174 assert(call->getNumArgOperands() == 1);
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400175 auto *load = call->getArgOperand(0);
David Neto6373d822017-10-12 19:09:53 -0400176 call->replaceAllUsesWith(load);
177 RemoveList.push_back(call);
178 }
179 }
180 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400181 for (auto *call : RemoveList) {
David Neto6373d822017-10-12 19:09:53 -0400182 call->eraseFromParent();
183 }
Diego Novillo3cc8d7a2019-04-10 13:30:34 -0400184 for (auto *F : WorkList) {
David Neto6373d822017-10-12 19:09:53 -0400185 F->eraseFromParent();
186 }
187
188 return Changed;
189}