Add fine-grained storage restrictions (#580)
Contributes to #529
* Remove -f16bit_storage in favour of new options
* Update existing uses of the removed option
* Add new options to restrict which storage classes support 8- and 16-bit
types
* -no-storage-16bit and -no-storage-8bit
* Default behaviour assumes support for all storage classes
* Add frontend diagnostics that check 8- and 16-bit storage support for
kernel arguments
* Make POD argument determination check storage class support
* Do not generate push constant POD args if the POD args contain an
array
* Add tests
* Fix layout check to support more floating-point types
* Document the new options and their usage
diff --git a/lib/FrontendPlugin.cpp b/lib/FrontendPlugin.cpp
index 2758d07..9a7cbb9 100644
--- a/lib/FrontendPlugin.cpp
+++ b/lib/FrontendPlugin.cpp
@@ -61,10 +61,59 @@
CustomDiagnosticRecursiveStruct = 18,
CustomDiagnosticPushConstantSizeExceeded = 19,
CustomDiagnosticPushConstantContainsArray = 20,
+ CustomDiagnosticUnsupported16BitStorage = 21,
+ CustomDiagnosticUnsupported8BitStorage = 22,
CustomDiagnosticTotal
};
std::vector<unsigned> CustomDiagnosticsIDMap;
+  // Maps a pointee's address space to the clspv storage class assumed to
+  // back it when checking 8- and 16-bit storage support.
+  clspv::Option::StorageClass ConvertToStorageClass(clang::LangAS aspace) {
+    using SC = clspv::Option::StorageClass;
+    if (aspace == LangAS::opencl_constant) {
+      // __constant data is placed in a UBO only when that option is on.
+      return clspv::Option::ConstantArgsInUniformBuffer() ? SC::kUBO
+                                                          : SC::kSSBO;
+    }
+    // __global and every other address space are treated as an SSBO.
+    // NOTE(review): this includes __local pointees; confirm that is intended.
+    return SC::kSSBO;
+  }
+
+  // Returns true if QT is, or holds through arrays, vectors, record fields
+  // and (only when follow_ptr is set) one outermost pointer, a scalar of
+  // `width` bits: 16 = (u)short/half/_Float16, 8 = the char types. Nested
+  // pointers are never followed (pointers in a pointee are diagnosed
+  // separately), so `struct S { struct S *next; };` cannot recurse forever.
+  bool ContainsSizedType(QualType QT, uint32_t width, bool follow_ptr = true) {
+    auto canonical = QT.getCanonicalType();
+    if (auto *BT = dyn_cast<BuiltinType>(canonical)) {
+      switch (BT->getKind()) {
+      case BuiltinType::UShort: case BuiltinType::Short:
+      case BuiltinType::Half: case BuiltinType::Float16:
+        return width == 16;
+      case BuiltinType::UChar: case BuiltinType::Char_U:
+      case BuiltinType::SChar: case BuiltinType::Char_S:
+        return width == 8;
+      default:
+        return false;
+      }
+    }
+    if (auto *PT = dyn_cast<PointerType>(canonical))
+      return follow_ptr &&
+             ContainsSizedType(PT->getPointeeType(), width, false);
+    if (auto *AT = dyn_cast<ArrayType>(canonical))
+      return ContainsSizedType(AT->getElementType(), width, false);
+    if (auto *VT = dyn_cast<VectorType>(canonical))
+      return ContainsSizedType(VT->getElementType(), width, false);
+    if (auto *RT = dyn_cast<RecordType>(canonical))
+      for (auto *field_decl : RT->getDecl()->fields())
+        if (ContainsSizedType(field_decl->getType(), width, false))
+          return true;
+    return false;
+  }
+
bool ContainsPointerType(QualType QT) {
auto canonical = QT.getCanonicalType();
if (canonical->isPointerType()) {
@@ -491,6 +540,14 @@
DE.getCustomDiagID(
DiagnosticsEngine::Error,
"arrays are not supported in push constants currently");
+ CustomDiagnosticsIDMap[CustomDiagnosticUnsupported16BitStorage] =
+ DE.getCustomDiagID(DiagnosticsEngine::Error,
+ "16-bit storage is not supported for "
+ "%select{SSBOs|UBOs|push constants}0");
+ CustomDiagnosticsIDMap[CustomDiagnosticUnsupported8BitStorage] =
+ DE.getCustomDiagID(DiagnosticsEngine::Error,
+ "8-bit storage is not supported for "
+ "%select{SSBOs|UBOs|push constants}0");
}
virtual bool HandleTopLevelDecl(DeclGroupRef DG) override {
@@ -558,6 +615,39 @@
}
}
+ // Check if storage capabilities are supported.
+ if (is_opencl_kernel) {
+ bool contains_16bit =
+ ContainsSizedType(type.getCanonicalType(), 16);
+ bool contains_8bit =
+ ContainsSizedType(type.getCanonicalType(), 8);
+ auto sc = clspv::Option::StorageClass::kSSBO;
+ if (type->isPointerType()) {
+ sc = ConvertToStorageClass(
+ type->getPointeeType().getAddressSpace());
+ } else if (clspv::Option::PodArgsInUniformBuffer()) {
+ sc = clspv::Option::StorageClass::kUBO;
+ } else if (clspv::Option::PodArgsInPushConstants()) {
+ sc = clspv::Option::StorageClass::kPushConstant;
+ }
+ if (contains_16bit &&
+ !clspv::Option::Supports16BitStorageClass(sc)) {
+ Instance.getDiagnostics().Report(
+ P->getSourceRange().getBegin(),
+ CustomDiagnosticsIDMap
+ [CustomDiagnosticUnsupported16BitStorage])
+ << static_cast<int>(sc);
+ }
+ if (contains_8bit &&
+ !clspv::Option::Supports8BitStorageClass(sc)) {
+ Instance.getDiagnostics().Report(
+ P->getSourceRange().getBegin(),
+ CustomDiagnosticsIDMap
+ [CustomDiagnosticUnsupported8BitStorage])
+ << static_cast<int>(sc);
+ }
+ }
+
if (is_opencl_kernel && type->isPointerType()) {
auto pointee_type = type->getPointeeType().getCanonicalType();
if (ContainsPointerType(pointee_type)) {