Add support for the select built-in function (#227)
Passes the OpenCL conformance tests via clvk on NVIDIA
hardware for all data types except 8-bit.
Signed-off-by: Kévin Petit <kpet@free.fr>
diff --git a/lib/ReplaceOpenCLBuiltinPass.cpp b/lib/ReplaceOpenCLBuiltinPass.cpp
index cd285e1..251e4c0 100644
--- a/lib/ReplaceOpenCLBuiltinPass.cpp
+++ b/lib/ReplaceOpenCLBuiltinPass.cpp
@@ -20,6 +20,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
@@ -76,6 +77,7 @@
bool replaceRelational(Module &M);
bool replaceIsInfAndIsNan(Module &M);
bool replaceAllAndAny(Module &M);
+ bool replaceSelect(Module &M);
bool replaceSignbit(Module &M);
bool replaceMadandMad24andMul24(Module &M);
bool replaceVloadHalf(Module &M);
@@ -117,6 +119,7 @@
Changed |= replaceRelational(M);
Changed |= replaceIsInfAndIsNan(M);
Changed |= replaceAllAndAny(M);
+ Changed |= replaceSelect(M);
Changed |= replaceSignbit(M);
Changed |= replaceMadandMad24andMul24(M);
Changed |= replaceVloadHalf(M);
@@ -780,6 +783,111 @@
return Changed;
}
+// Lowers calls to the OpenCL C select() built-in to an LLVM compare feeding a
+// select instruction.
+//
+// Every function whose mangled name starts with "_Z6select" is examined.
+// Calls using one of the OpenCL C overloads (identical integer or
+// floating-point value types, integer predicate with the same element width
+// and element count) are rewritten; the result is the second argument where
+// the predicate is set and the first argument otherwise. A vector predicate
+// element counts as set when its sign bit is set (signed < 0), a scalar
+// predicate when it is non-zero.
+//
+// Returns true if the module was modified.
+bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
+  bool Changed = false;
+
+  // Functions are erased only once the symbol table walk is over, so that
+  // the table is never modified while it is being iterated.
+  SmallVector<Function *, 4> DeadFunctions;
+
+  for (auto const &SymVal : M.getValueSymbolTable()) {
+    // Skip symbols whose name doesn't match
+    if (!SymVal.getKey().startswith("_Z6select")) {
+      continue;
+    }
+
+    // Is there a three-parameter function going by that name?
+    auto F = dyn_cast<Function>(SymVal.getValue());
+    if (!F || F->arg_size() != 3) {
+      continue;
+    }
+
+    SmallVector<Instruction *, 4> ToRemoves;
+
+    // Walk the users of the function.
+    for (auto &U : F->uses()) {
+      // Only direct calls are rewritten; any other use keeps the function.
+      auto CI = dyn_cast<CallInst>(U.getUser());
+      if (!CI || CI->getCalledFunction() != F) {
+        continue;
+      }
+
+      // Get arguments
+      auto FalseValue = CI->getArgOperand(0);
+      auto TrueValue = CI->getArgOperand(1);
+      auto PredicateValue = CI->getArgOperand(2);
+
+      // Don't touch overloads that aren't in OpenCL C
+      auto ValueType = FalseValue->getType();
+      auto PredicateType = PredicateValue->getType();
+
+      if (ValueType != TrueValue->getType() ||
+          !PredicateType->isIntOrIntVectorTy()) {
+        continue;
+      }
+
+      if (!ValueType->isIntOrIntVectorTy() &&
+          !ValueType->getScalarType()->isFloatingPointTy()) {
+        continue;
+      }
+
+      // Reject vector/scalar mixes in either direction: a vector condition
+      // cannot drive a select between scalar values, nor the reverse here.
+      if (ValueType->isVectorTy() != PredicateType->isVectorTy()) {
+        continue;
+      }
+
+      if (ValueType->getScalarSizeInBits() !=
+          PredicateType->getScalarSizeInBits()) {
+        continue;
+      }
+
+      if (ValueType->isVectorTy()) {
+        const auto NumElements = ValueType->getVectorNumElements();
+        if (NumElements != PredicateType->getVectorNumElements()) {
+          continue;
+        }
+
+        // Only vectors of 2, 3, 4, 8 and 16 elements are accepted.
+        if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
+            (NumElements != 8) && (NumElements != 16)) {
+          continue;
+        }
+      }
+
+      // Create constant
+      const auto ZeroValue = Constant::getNullValue(PredicateType);
+
+      // Vector predicates select on the sign bit of each element, scalar
+      // predicates on being non-zero.
+      const CmpInst::Predicate Pred =
+          PredicateType->isVectorTy() ? CmpInst::ICMP_SLT : CmpInst::ICMP_NE;
+
+      // Create comparison instruction
+      auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
+                                 ZeroValue, "", CI);
+
+      // Create select
+      Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
+
+      // Replace call with the selection
+      CI->replaceAllUsesWith(V);
+
+      // Lastly, remember to remove the user.
+      ToRemoves.push_back(CI);
+    }
+
+    // Accumulate: a later overload without calls must not reset the flag.
+    Changed |= !ToRemoves.empty();
+
+    // And cleanup the calls we don't use anymore.
+    for (auto V : ToRemoves) {
+      V->eraseFromParent();
+    }
+
+    // The function can only go if nothing refers to it anymore; calls to
+    // unsupported overloads are left in place and still need it.
+    if (F->use_empty()) {
+      DeadFunctions.push_back(F);
+    }
+  }
+
+  // Removing a function modifies the module even if it had no calls.
+  for (auto F : DeadFunctions) {
+    F->eraseFromParent();
+    Changed = true;
+  }
+
+  return Changed;
+}
+
bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
bool Changed = false;