Reuse kernel arg variables for StorageBuffers
A module-scope variable created for one kernel's argument can be reused for another kernel's argument if:
- it has the same type
- it was last used for a different function
- it uses the same binding number
Fixes https://github.com/google/clspv/issues/103
diff --git a/lib/SPIRVProducerPass.cpp b/lib/SPIRVProducerPass.cpp
index 535bb3e..270a0ae 100644
--- a/lib/SPIRVProducerPass.cpp
+++ b/lib/SPIRVProducerPass.cpp
@@ -41,7 +41,9 @@
#include <list>
#include <iomanip>
+#include <set>
#include <sstream>
+#include <tuple>
#include <utility>
#if defined(_MSC_VER)
@@ -337,6 +339,10 @@
// This mimics what Glslang does, and that's what drivers are used to.
uint32_t WorkgroupSizeValueID;
uint32_t WorkgroupSizeVarID;
+
+ // The set of module-scope variables whose binding information has already
+ // been emitted.
+ DenseSet<Value*> GVarWithEmittedBindingInfo;
};
char SPIRVProducerPass::ID;
@@ -631,11 +637,33 @@
}
+
+ // Map kernel functions to their ordinal number in the compilation unit.
+ UniqueVector<Function*> KernelOrdinal;
+
+ // Map the global variables created for kernel args to their creation
+ // order.
+ UniqueVector<GlobalVariable*> KernelArgVarOrdinal;
+
+ // For each kernel argument type, record the kernel arg global variables
+ // generated for that type, the function in which that variable was most
+ // recently used, and the binding number it took. For reproducibility,
+ // we track things by ordinal number (rather than pointer), and we use a
+ // std::set rather than DenseSet since std::set maintains an ordering.
+ // Each tuple is the ordinal of the kernel function, the binding number,
+ // and the ordinal of the kernel-arg-var.
+ //
+ // This table lets us reuse module-scope StorageBuffer variables between
+ // different kernels.
+ DenseMap<Type *, std::set<std::tuple<unsigned, unsigned, unsigned>>>
+ GVarsForType;
+
for (Function &F : M) {
// Handle kernel function first.
if (F.isDeclaration() || F.getCallingConv() != CallingConv::SPIR_KERNEL) {
continue;
}
+ KernelOrdinal.insert(&F);
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
@@ -778,15 +806,57 @@
// In order to build type map between llvm type and spirv id, LLVM
// global variable is needed. It has llvm type and other instructions
// can access it with its type.
- GlobalVariable *NewGV = new GlobalVariable(
- M, GVTy, false, GlobalValue::ExternalLinkage, UndefValue::get(GVTy),
- F.getName() + ".arg." + std::to_string(Idx++), nullptr,
- GlobalValue::ThreadLocalMode::NotThreadLocal, AddrSpace);
+ //
+ // Reuse a global variable if it was created for a different entry point.
+
+ // Returns a new global variable for this kernel argument, and remembers
+ // it in KernelArgVarOrdinal.
+ auto make_gvar = [&]() {
+ auto result = new GlobalVariable(
+ M, GVTy, false, GlobalValue::ExternalLinkage, UndefValue::get(GVTy),
+ F.getName() + ".arg." + std::to_string(Idx), nullptr,
+ GlobalValue::ThreadLocalMode::NotThreadLocal, AddrSpace);
+ KernelArgVarOrdinal.insert(result);
+ return result;
+ };
+
+ // Reuse a variable of this type that was last used by a different
+ // function (so not yet reused for the current one) with the same binding.
+ // Otherwise make a new variable. Always make a new one for sampler and
+ // image types, or if we're forcing distinct descriptor sets.
+ GlobalVariable *GV = nullptr;
+ auto which_set = GVarsForType.find(GVTy);
+ if (IsSamplerType || IsImageType || which_set == GVarsForType.end() ||
+ distinct_kernel_descriptor_sets) {
+ GV = make_gvar();
+ } else {
+ auto &set = which_set->second;
+ // Reuse a variable if it was associated with a different function.
+ for (auto iter = set.begin(), end = set.end();
+ iter != end; ++iter) {
+ const unsigned fn_ordinal = std::get<0>(*iter);
+ const unsigned binding = std::get<1>(*iter);
+ if (fn_ordinal != KernelOrdinal.idFor(&F) && binding == Idx) {
+ GV = KernelArgVarOrdinal[std::get<2>(*iter)];
+ // Remove it from the set. We'll add it back later.
+ set.erase(iter);
+ break;
+ }
+ }
+ if (!GV) {
+ GV = make_gvar();
+ }
+ }
+ assert(GV);
+ GVarsForType[GVTy].insert(std::make_tuple(KernelOrdinal.idFor(&F), Idx,
+ KernelArgVarOrdinal.idFor(GV)));
// Generate type info for argument global variable.
- FindType(NewGV->getType());
+ FindType(GV->getType());
- ArgGVMap[&Arg] = NewGV;
+ ArgGVMap[&Arg] = GV;
+
+ Idx++;
// Generate pointer type of argument type for OpAccessChain of argument.
if (!Arg.use_empty()) {
@@ -2626,6 +2696,12 @@
<< clspv::GetArgKindForType(Arg.getType()) << "\n";
}
+ if (GVarWithEmittedBindingInfo.count(NewGV)) {
+ BindingIdx++;
+ continue;
+ }
+ GVarWithEmittedBindingInfo.insert(NewGV);
+
// Ops[0] = Target ID
// Ops[1] = Decoration (DescriptorSet)
// Ops[2] = LiteralNumber according to Decoration