Automate pod arg implementation decision (#575)
Contributes to #529
* Decision made on a per-kernel basis
* New pass assigns metadata to kernels
* passes use that metadata through new utilities to get the right
ArgKinds
* currently only decides based on command line options (so NFC)
* Refactored layout validation into new files
* Refactored some global push constant methods
* new enum for how pod args are implemented
* kGlobalPushConstant will be implemented in the future
* added new metadata to tests that required it
* changed arg kind utility to take the arg instead of just the type
* Prioritizes per-kernel push constants, then UBO, then SSBO
* checks for compatibility
* Added new variant of isValidExplicitLayout that checks the whole
struct
diff --git a/lib/SPIRVProducerPass.cpp b/lib/SPIRVProducerPass.cpp
index d33baf1..68124e2 100644
--- a/lib/SPIRVProducerPass.cpp
+++ b/lib/SPIRVProducerPass.cpp
@@ -57,6 +57,7 @@
#include "ConstantEmitter.h"
#include "Constants.h"
#include "DescriptorCounter.h"
+#include "Layout.h"
#include "NormalizeGlobalVariable.h"
#include "Passes.h"
#include "SpecConstant.h"
@@ -2881,200 +2882,6 @@
}
}
-namespace {
-
-bool isScalarType(Type *type) {
- return type->isIntegerTy() || type->isFloatTy();
-}
-
-uint64_t structAlignment(StructType *type,
- std::function<uint64_t(Type *)> alignFn) {
- uint64_t maxAlign = 1;
- for (unsigned i = 0; i < type->getStructNumElements(); i++) {
- uint64_t align = alignFn(type->getStructElementType(i));
- maxAlign = std::max(align, maxAlign);
- }
- return maxAlign;
-}
-
-uint64_t scalarAlignment(Type *type) {
- // A scalar of size N has a scalar alignment of N.
- if (isScalarType(type)) {
- return type->getScalarSizeInBits() / 8;
- }
-
- // A vector or matrix type has a scalar alignment equal to that of its
- // component type.
- if (auto vec_type = dyn_cast<VectorType>(type)) {
- return scalarAlignment(vec_type->getElementType());
- }
-
- // An array type has a scalar alignment equal to that of its element type.
- if (type->isArrayTy()) {
- return scalarAlignment(type->getArrayElementType());
- }
-
- // A structure has a scalar alignment equal to the largest scalar alignment of
- // any of its members.
- if (type->isStructTy()) {
- return structAlignment(cast<StructType>(type), scalarAlignment);
- }
-
- llvm_unreachable("Unsupported type");
-}
-
-uint64_t baseAlignment(Type *type) {
- // A scalar has a base alignment equal to its scalar alignment.
- if (isScalarType(type)) {
- return scalarAlignment(type);
- }
-
- if (auto vec_type = dyn_cast<VectorType>(type)) {
- unsigned numElems = vec_type->getNumElements();
-
- // A two-component vector has a base alignment equal to twice its scalar
- // alignment.
- if (numElems == 2) {
- return 2 * scalarAlignment(type);
- }
- // A three- or four-component vector has a base alignment equal to four
- // times its scalar alignment.
- if ((numElems == 3) || (numElems == 4)) {
- return 4 * scalarAlignment(type);
- }
- }
-
- // An array has a base alignment equal to the base alignment of its element
- // type.
- if (type->isArrayTy()) {
- return baseAlignment(type->getArrayElementType());
- }
-
- // A structure has a base alignment equal to the largest base alignment of any
- // of its members.
- if (type->isStructTy()) {
- return structAlignment(cast<StructType>(type), baseAlignment);
- }
-
- // TODO A row-major matrix of C columns has a base alignment equal to the base
- // alignment of a vector of C matrix components.
- // TODO A column-major matrix has a base alignment equal to the base alignment
- // of the matrix column type.
-
- llvm_unreachable("Unsupported type");
-}
-
-uint64_t extendedAlignment(Type *type) {
- // A scalar, vector or matrix type has an extended alignment equal to its base
- // alignment.
- // TODO matrix type
- if (isScalarType(type) || type->isVectorTy()) {
- return baseAlignment(type);
- }
-
- // An array or structure type has an extended alignment equal to the largest
- // extended alignment of any of its members, rounded up to a multiple of 16
- if (type->isStructTy()) {
- auto salign = structAlignment(cast<StructType>(type), extendedAlignment);
- return alignTo(salign, 16);
- }
-
- if (type->isArrayTy()) {
- auto salign = extendedAlignment(type->getArrayElementType());
- return alignTo(salign, 16);
- }
-
- llvm_unreachable("Unsupported type");
-}
-
-uint64_t standardAlignment(Type *type, spv::StorageClass sclass) {
- // If the scalarBlockLayout feature is enabled on the device then every member
- // must be aligned according to its scalar alignment
- if (clspv::Option::ScalarBlockLayout()) {
- return scalarAlignment(type);
- }
-
- // All vectors must be aligned according to their scalar alignment
- if (type->isVectorTy()) {
- return scalarAlignment(type);
- }
-
- // If the uniformBufferStandardLayout feature is not enabled on the device,
- // then any member of an OpTypeStruct with a storage class of Uniform and a
- // decoration of Block must be aligned according to its extended alignment.
- if (!clspv::Option::Std430UniformBufferLayout() &&
- sclass == spv::StorageClassUniform) {
- return extendedAlignment(type);
- }
-
- // Every other member must be aligned according to its base alignment
- return baseAlignment(type);
-}
-
-bool improperlyStraddles(const DataLayout &DL, Type *type, unsigned offset) {
- assert(type->isVectorTy());
-
- auto size = DL.getTypeStoreSize(type);
-
- // It is a vector with total size less than or equal to 16 bytes, and has
- // Offset decorations placing its first byte at F and its last byte at L,
- // where floor(F / 16) != floor(L / 16).
- if ((size <= 16) && (offset % 16 + size > 16)) {
- return true;
- }
-
- // It is a vector with total size greater than 16 bytes and has its Offset
- // decorations placing its first byte at a non-integer multiple of 16
- if ((size > 16) && (offset % 16 != 0)) {
- return true;
- }
-
- return false;
-}
-
-// See 14.5 Shader Resource Interface in Vulkan spec
-bool isValidExplicitLayout(Module &M, StructType *STy, unsigned Member,
- spv::StorageClass SClass, unsigned Offset,
- unsigned PreviousMemberOffset) {
-
- auto MemberType = STy->getElementType(Member);
- auto Align = standardAlignment(MemberType, SClass);
- auto &DL = M.getDataLayout();
-
- // The Offset decoration of any member must be a multiple of its alignment
- if (Offset % Align != 0) {
- return false;
- }
-
- // TODO Any ArrayStride or MatrixStride decoration must be a multiple of the
- // alignment of the array or matrix as defined above
-
- if (!clspv::Option::ScalarBlockLayout()) {
- // Vectors must not improperly straddle, as defined above
- if (MemberType->isVectorTy() &&
- improperlyStraddles(DL, MemberType, Offset)) {
- return true;
- }
-
- // The Offset decoration of a member must not place it between the end
- // of a structure or an array and the next multiple of the alignment of that
- // structure or array
- if (Member > 0) {
- auto PType = STy->getElementType(Member - 1);
- if (PType->isStructTy() || PType->isArrayTy()) {
- auto PAlign = standardAlignment(PType, SClass);
- if (Offset - PreviousMemberOffset < PAlign) {
- return false;
- }
- }
- }
- }
-
- return true;
-}
-
-} // namespace
-
void SPIRVProducerPass::GeneratePushConstantDescriptorMapEntries() {
if (auto GV = module->getGlobalVariable(clspv::PushConstantsVariableName())) {
@@ -3433,16 +3240,6 @@
// Gather the list of resources that are used by this function's arguments.
auto &resource_var_at_index = FunctionToResourceVarsMap[&F];
- // TODO(alan-baker): This should become unnecessary by fixing the rest of the
- // flow to generate pod_ubo arguments earlier.
- auto remap_arg_kind = [](StringRef argKind) {
- std::string kind =
- clspv::Option::PodArgsInUniformBuffer() && argKind.equals("pod")
- ? "pod_ubo"
- : argKind.str();
- return GetArgKindFromName(kind);
- };
-
auto *fty = F.getType()->getPointerElementType();
auto *func_ty = dyn_cast<FunctionType>(fty);
@@ -3465,8 +3262,8 @@
dyn_extract<ConstantInt>(arg_node->getOperand(3))->getZExtValue();
const auto arg_size =
dyn_extract<ConstantInt>(arg_node->getOperand(4))->getZExtValue();
- const auto argKind = remap_arg_kind(
- dyn_cast<MDString>(arg_node->getOperand(5))->getString());
+ const auto argKind = clspv::GetArgKindFromName(
+ dyn_cast<MDString>(arg_node->getOperand(5))->getString().str());
const auto spec_id =
dyn_extract<ConstantInt>(arg_node->getOperand(6))->getSExtValue();
@@ -3517,7 +3314,7 @@
F.getName().str(),
arg->getName().str(),
arg_index,
- remap_arg_kind(clspv::GetArgKindName(info->arg_kind)),
+ info->arg_kind,
0,
0,
0,