Add argOrdinal field to descriptor map
Also change how the descriptor map is derived. When clustering
POD kernel arguments, attach a "kernel_arg_map" metadata node
to the kernel function to fully describe the mapping. This could
be more robust, in the face of potential inlining that could occur
later.
When that metadata is present, it has full information on the kernel
argument mapping. Otherwise, we just derive the descriptor map
entries directly from the kernel arguments.
diff --git a/lib/SPIRVProducerPass.cpp b/lib/SPIRVProducerPass.cpp
index a981ef3..595144b 100644
--- a/lib/SPIRVProducerPass.cpp
+++ b/lib/SPIRVProducerPass.cpp
@@ -16,6 +16,7 @@
#pragma warning(push, 0)
#endif
+#include <cassert>
#include <clspv/Passes.h>
#include <llvm/ADT/StringSwitch.h>
@@ -43,6 +44,7 @@
using namespace llvm;
using namespace clspv;
+using namespace mdconst;
namespace {
enum SPIRVOperandType {
@@ -2351,46 +2353,41 @@
}
}
+ const auto *ArgMap = F.getMetadata("kernel_arg_map");
+ // Emit descriptor map entries, if there was explicit metadata
+ // attached.
+ if (ArgMap) {
+ for (const auto &arg : ArgMap->operands()) {
+ const MDNode *arg_node = dyn_cast<MDNode>(arg.get());
+ assert(arg_node->getNumOperands() == 4);
+ const auto name =
+ dyn_cast<MDString>(arg_node->getOperand(0))->getString();
+ const auto old_index =
+ dyn_extract<ConstantInt>(arg_node->getOperand(1))->getZExtValue();
+ const auto new_index =
+ dyn_extract<ConstantInt>(arg_node->getOperand(2))->getZExtValue();
+ const auto offset =
+ dyn_extract<ConstantInt>(arg_node->getOperand(3))->getZExtValue();
+ descriptorMapOut << "kernel," << F.getName() << ",arg," << name
+ << ",argOrdinal," << old_index << ",descriptorSet,"
+ << DescriptorSetIdx << ",binding," << new_index
+ << ",offset," << offset << "\n";
+ }
+ }
+
uint32_t BindingIdx = 0;
for (auto &Arg : F.args()) {
Value *NewGV = ArgGVMap[&Arg];
VMap[&Arg] = VMap[NewGV];
ArgGVIDMap[&Arg] = VMap[&Arg];
- // Emit a descriptor map entry. Handle the case where we've clustered POD
- // arguments.
- {
- auto *ArgTy = Arg.getType();
- bool was_podargs = false;
- if (auto *StructTy = dyn_cast<StructType>(ArgTy)) {
- std::string TypeName = StructTy->getName();
- if (StructTy->getName().endswith(".podargs")) {
- // The uses are extractvalue instructions.
- const StructLayout *structLayout =
- dataLayout.getStructLayout(StructTy);
- for (auto use_iter = Arg.use_begin(); use_iter != Arg.use_end();
- ++use_iter) {
- const Value *user = use_iter->getUser();
- if (auto *ExtractInst = dyn_cast<ExtractValueInst>(user)) {
- // There is only one index.
- unsigned member = ExtractInst->getIndices()[0];
- unsigned offset = structLayout->getElementOffset(member);
-
- descriptorMapOut << "kernel," << F.getName() << ",arg,"
- << ExtractInst->getName() << ",descriptorSet,"
- << DescriptorSetIdx << ",binding,"
- << BindingIdx << ",offset," << offset << "\n";
- }
- }
- was_podargs = true;
- }
- }
- if (!was_podargs) {
- descriptorMapOut << "kernel," << F.getName() << ",arg,"
- << Arg.getName() << ",descriptorSet,"
- << DescriptorSetIdx << ",binding," << BindingIdx
- << ",offset,0\n";
- }
+ // Emit a descriptor map entry for this arg, in case there was no explicit
+ // kernel arg mapping metadata.
+ if (!ArgMap) {
+ descriptorMapOut << "kernel," << F.getName() << ",arg," << Arg.getName()
+ << ",argOrdinal," << BindingIdx << ",descriptorSet,"
+ << DescriptorSetIdx << ",binding," << BindingIdx
+ << ",offset,0\n";
}
// Ops[0] = Target ID