Refactor checks about pointers in structs (#470)
Fixes #468
* Move the check about pointers in structs to only happen for kernel
arguments
* Add a new check to prevent recursive struct definitions
* New and updated tests
diff --git a/lib/FrontendPlugin.cpp b/lib/FrontendPlugin.cpp
index 4c72a1f..1ad55da 100644
--- a/lib/FrontendPlugin.cpp
+++ b/lib/FrontendPlugin.cpp
@@ -55,6 +55,7 @@
CustomDiagnosticSSBOUnalignedStruct = 15,
CustomDiagnosticOverloadedKernel = 16,
CustomDiagnosticStructContainsPointer = 17,
+ CustomDiagnosticRecursiveStruct = 18,
CustomDiagnosticTotal
};
std::vector<unsigned> CustomDiagnosticsIDMap;
@@ -75,6 +76,28 @@
return false;
}
+  // Returns true if |QT| is, or refers to, a record type whose definition
+  // reaches itself again through its fields, looking through pointers and
+  // arrays along the way.
+  //
+  // |seen| holds the record types on the current traversal path: a record is
+  // inserted on entry and erased again once all of its fields were found to
+  // be non-recursive, so the same struct may be used by several sibling
+  // fields without being reported. When true is returned |seen| is left
+  // as-is (entries on the offending path are not erased).
+  bool IsRecursiveType(QualType QT, llvm::DenseSet<const Type *> *seen) {
+    auto canonical = QT.getCanonicalType();
+    if (canonical->isRecordType() &&
+        !seen->insert(canonical.getTypePtr()).second) {
+      // Already on the current path: the definition loops back on itself.
+      return true;
+    }
+
+    // Look through pointers and arrays to the underlying type. Neither is a
+    // record type, so nothing was inserted into |seen| for them and there is
+    // nothing to erase before returning.
+    if (auto *PT = dyn_cast<PointerType>(canonical))
+      return IsRecursiveType(PT->getPointeeType(), seen);
+    if (auto *AT = dyn_cast<ArrayType>(canonical))
+      return IsRecursiveType(AT->getElementType(), seen);
+
+    if (auto *RT = dyn_cast<RecordType>(canonical)) {
+      for (const auto *field_decl : RT->getDecl()->fields()) {
+        if (IsRecursiveType(field_decl->getType(), seen))
+          return true;
+      }
+    }
+
+    // Done with this type: take it off the path so later, unrelated uses of
+    // the same record are not mistaken for recursion. For non-record types
+    // this erase is a no-op since they were never inserted.
+    seen->erase(canonical.getTypePtr());
+    return false;
+  }
+
bool IsSupportedType(QualType QT, SourceRange SR) {
auto *Ty = QT.getTypePtr();
@@ -101,11 +124,12 @@
return false;
}
} else if (canonicalType->isRecordType()) {
- // Structures should not contain pointers.
- if (ContainsPointerType(canonicalType)) {
+ // Do not allow recursive struct definitions.
+ llvm::DenseSet<const Type *> seen;
+ if (IsRecursiveType(canonicalType, &seen)) {
Instance.getDiagnostics().Report(
SR.getBegin(),
- CustomDiagnosticsIDMap[CustomDiagnosticStructContainsPointer]);
+ CustomDiagnosticsIDMap[CustomDiagnosticRecursiveStruct]);
return false;
}
}
@@ -437,6 +461,9 @@
CustomDiagnosticsIDMap[CustomDiagnosticStructContainsPointer] =
DE.getCustomDiagID(DiagnosticsEngine::Error,
"structures may not contain pointers");
+ CustomDiagnosticsIDMap[CustomDiagnosticRecursiveStruct] =
+ DE.getCustomDiagID(DiagnosticsEngine::Error,
+ "recursive structures are not supported");
}
virtual bool HandleTopLevelDecl(DeclGroupRef DG) override {
@@ -498,6 +525,17 @@
}
}
+ if (is_opencl_kernel && type->isPointerType()) {
+ auto pointee_type = type->getPointeeType().getCanonicalType();
+ if (ContainsPointerType(pointee_type)) {
+ Instance.getDiagnostics().Report(
+ P->getSourceRange().getBegin(),
+ CustomDiagnosticsIDMap
+ [CustomDiagnosticStructContainsPointer]);
+ return false;
+ }
+ }
+
if (is_opencl_kernel && !type->isPointerType()) {
Layout layout = SSBO;
if (clspv::Option::PodArgsInUniformBuffer() &&