blob: 05bf00feba0476ca30be0f7e27c00ec65f087e41 [file] [log] [blame]
// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <string>
16
17#include <llvm/ADT/DenseMap.h>
18#include <llvm/IR/Attributes.h>
19#include <llvm/IR/DerivedTypes.h>
20#include <llvm/IR/Function.h>
21#include <llvm/IR/Instructions.h>
22#include <llvm/IR/Module.h>
23#include <llvm/Pass.h>
24#include <llvm/Support/raw_ostream.h>
25
26#include <clspv/AddressSpace.h>
27
28using namespace llvm;
29using std::string;
30
31#define DEBUG_TYPE "hideconstantloads"
32
33
34namespace {
35
// Name prefix shared by every generated wrap function.  The hide pass appends
// a per-type index to this prefix, and the unhide pass recognizes wrap
// functions by it.  Both the characters and the pointer are const so the
// prefix cannot be changed or retargeted after start-up.
const char *const kWrapFunctionPrefix = "clspv.wrap_constant_load.";
37
38class HideConstantLoadsPass : public ModulePass {
39 public:
40 static char ID;
41 HideConstantLoadsPass() : ModulePass(ID) {}
42
43 bool runOnModule(Module &M) override;
44
45 private:
46 // Return the name for the wrap function for the given type.
47 string &WrapFunctionNameForType(Type *type) {
48 auto where = function_for_type_.find(type);
49 if (where == function_for_type_.end()) {
50 // Insert it.
51 auto &result = function_for_type_[type] =
52 string(kWrapFunctionPrefix) +
53 std::to_string(function_for_type_.size());
54 return result;
55 } else {
56 return where->second;
57 }
58 }
59
60 // Maps a loaded type to the name of the wrap function for that type.
61 DenseMap<Type *, string> function_for_type_;
62};
63} // namespace
64
65char HideConstantLoadsPass::ID = 0;
66static RegisterPass<HideConstantLoadsPass>
67 X("HideConstantLoads", "Hide loads from __constant memory");
68
69namespace clspv {
70llvm::ModulePass *createHideConstantLoadsPass() {
71 return new HideConstantLoadsPass();
72}
73} // namespace clspv
74
75
76bool HideConstantLoadsPass::runOnModule(Module &M) {
77 bool Changed = false;
78
79 SmallVector<LoadInst *, 16> WorkList;
80 for (Function &F : M) {
81 for (BasicBlock &BB : F) {
82 for (Instruction &I : BB) {
83 if (LoadInst *load = dyn_cast<LoadInst>(&I)) {
84 if (clspv::AddressSpace::Constant == load->getPointerAddressSpace()) {
85 WorkList.push_back(load);
86 }
87 }
88 }
89 }
90 }
91
92 if (WorkList.size() == 0) {
93 return Changed;
94 }
95
96 for (LoadInst *load : WorkList) {
97 Changed = true;
98
99 auto loadedTy = load->getType();
100
101 // The wrap function conceptually maps the loaded value to itself.
102 const string& fn_name = WrapFunctionNameForType(loadedTy);
103 Function* fn = M.getFunction(fn_name);
104 if (!fn) {
105 // Make the function.
106 FunctionType* fnTy = FunctionType::get(loadedTy, {loadedTy}, false);
107 auto fn_constant = M.getOrInsertFunction(fn_name, fnTy);
108 fn = cast<Function>(fn_constant);
109 fn->addFnAttr(Attribute::ReadOnly);
110 fn->addFnAttr(Attribute::ReadNone);
111 }
112
113 // Wrap the load
114 auto call = CallInst::Create(fn, {load});
115 call->insertAfter(load);
116
117 // Replace other uses of the load with the result of the wrap call.
118 {
119 SmallVector<User *, 16> ToReplaceIn;
120 for (auto &use : load->uses()) {
121 User *user = use.getUser();
122 ToReplaceIn.push_back(user);
123 }
124 for (auto *user : ToReplaceIn) {
125 if (dyn_cast<CallInst>(user) != call) {
126 user->replaceUsesOfWith(load, call);
127 }
128 }
129 }
130 }
131
132 return Changed;
133}
134
135namespace {
136class UnhideConstantLoadsPass : public ModulePass {
137 public:
138 static char ID;
139 UnhideConstantLoadsPass() : ModulePass(ID) {}
140
141 bool runOnModule(Module &M) override;
142
143 private:
144
145 // Maps a loaded type to the name of the wrap function for that type.
146 DenseMap<Type *, string> function_for_type_;
147};
148
149} // namespace
150
151char UnhideConstantLoadsPass::ID = 0;
152static RegisterPass<UnhideConstantLoadsPass>
153 X2("UnhideConstantLoads", "Unhide loads from __constant memory");
154
155namespace clspv {
156llvm::ModulePass *createUnhideConstantLoadsPass() {
157 return new UnhideConstantLoadsPass();
158}
159} // namespace clspv
160
161bool UnhideConstantLoadsPass::runOnModule(Module &M) {
162 bool Changed = false;
163
164 SmallVector<Function *, 16> WorkList;
165 for (auto& F : M.getFunctionList()) {
166 if (F.getName().startswith(kWrapFunctionPrefix)) {
167 WorkList.push_back(&F);
168 }
169 }
170
171 if (WorkList.size() == 0)
172 return Changed;
173
174 SmallVector<CallInst *, 16> RemoveList;
175 for (auto* F : WorkList) {
176 for (auto& use : F->uses()) {
177 if (auto* call = dyn_cast<CallInst>(use.getUser())) {
178 assert(call->getNumArgOperands() == 1);
179 auto* load = call->getArgOperand(0);
180 call->replaceAllUsesWith(load);
181 RemoveList.push_back(call);
182 }
183 }
184 }
185 for (auto* call : RemoveList) {
186 call->eraseFromParent();
187 }
188 for (auto* F : WorkList) {
189 F->eraseFromParent();
190 }
191
192 return Changed;
193}