blob: 1fd2ba5500886361d17ae56d40652de644b187d3 [file] [log] [blame]
David Neto85082642018-03-24 06:55:20 -07001// Copyright 2018 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
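// Cluster module-scope __constant variables into a single struct-typed
// __constant variable. The pass itself does not test any option; it is meant
// to be run only when option ModuleScopeConstantsInUniformBuffer is true.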
17
18#include <cassert>
19
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/UniqueVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
David Neto85082642018-03-24 06:55:20 -070031
32#include "clspv/AddressSpace.h"
33#include "clspv/Option.h"
34
35#include "ArgKind.h"
36
37using namespace llvm;
38
39#define DEBUG_TYPE "clusterconstants"
40
41namespace {
42struct ClusterModuleScopeConstantVars : public ModulePass {
43 static char ID;
44 ClusterModuleScopeConstantVars() : ModulePass(ID) {}
45
46 bool runOnModule(Module &M) override;
47};
48
49} // namespace
50
51char ClusterModuleScopeConstantVars::ID = 0;
52static RegisterPass<ClusterModuleScopeConstantVars>
53 X("ClusterModuleScopeConstantVars",
54 "Cluster module-scope __constant variables");
55
56namespace clspv {
57llvm::ModulePass *createClusterModuleScopeConstantVars() {
58 return new ClusterModuleScopeConstantVars();
59}
60} // namespace clspv
61
62bool ClusterModuleScopeConstantVars::runOnModule(Module &M) {
63 bool Changed = false;
64 LLVMContext &Context = M.getContext();
65
66 SmallVector<GlobalVariable *, 8> global_constants;
67 UniqueVector<Constant *> initializers;
68 SmallVector<GlobalVariable *, 8> dead_global_constants;
69 for (GlobalVariable &GV : M.globals()) {
70 if (GV.hasInitializer() && GV.getType()->getPointerAddressSpace() ==
71 clspv::AddressSpace::Constant) {
72 // Only keep live __constant variables.
73 if (GV.use_empty()) {
74 dead_global_constants.push_back(&GV);
75 } else {
76 global_constants.push_back(&GV);
77 initializers.insert(GV.getInitializer());
78 }
79 }
80 }
81
82 for (GlobalVariable *GV : dead_global_constants) {
83 Changed = true;
84 GV->eraseFromParent();
85 }
86
87 if (global_constants.size() > 1 ||
88 (global_constants.size() == 1 &&
89 !global_constants[0]->getType()->isStructTy())) {
90
91 Changed = true;
92
93 // Make the struct type.
94 SmallVector<Type *, 8> types;
95 types.reserve(initializers.size());
96 for (Value *init : initializers) {
97 Type *ty = init->getType();
98 types.push_back(ty);
99 }
100 StructType *type = StructType::get(Context, types);
101
102 // Make the global variable.
103 SmallVector<Constant *, 8> initializers_as_vec(initializers.begin(),
104 initializers.end());
105 Constant *clustered_initializer =
106 ConstantStruct::get(type, initializers_as_vec);
107 GlobalVariable *clustered_gv = new GlobalVariable(
108 M, type, true, GlobalValue::InternalLinkage, clustered_initializer,
109 "clspv.clustered_constants", nullptr,
110 GlobalValue::ThreadLocalMode::NotThreadLocal,
111 clspv::AddressSpace::Constant);
112 assert(clustered_gv);
113
114 // Replace uses of the other globals with references to the members of the
115 // clustered constant.
116 IRBuilder<> Builder(Context);
117 Value *zero = Builder.getInt32(0);
118 for (GlobalVariable *GV : global_constants) {
119 SmallVector<User *, 8> users(GV->users());
120 for (User *user : users) {
121 if (GV == user) {
122 // This is the original global variable declaration. Skip it.
123 } else if (auto *inst = dyn_cast<Instruction>(user)) {
124 unsigned index = initializers.idFor(GV->getInitializer()) - 1;
125 Instruction *gep = GetElementPtrInst::CreateInBounds(
126 clustered_gv, {zero, Builder.getInt32(index)}, "", inst);
127 user->replaceUsesOfWith(GV, gep);
128 } else {
129 errs() << "Don't know how to handle updating user of __constant: "
130 << *user << "\n";
131 llvm_unreachable("Unhandled case replacing a user of __constant");
132 }
133 }
134 }
135
136 // Remove the old constants.
137 for (GlobalVariable *GV : global_constants) {
138 GV->eraseFromParent();
139 }
140 }
141
142 return Changed;
143}