Add support for the select built-in function (#227)

Passes the OpenCL conformance tests via clvk on NVIDIA
hardware for all data types except 8-bit ones.

Signed-off-by: Kévin Petit <kpet@free.fr>
diff --git a/lib/ReplaceOpenCLBuiltinPass.cpp b/lib/ReplaceOpenCLBuiltinPass.cpp
index cd285e1..251e4c0 100644
--- a/lib/ReplaceOpenCLBuiltinPass.cpp
+++ b/lib/ReplaceOpenCLBuiltinPass.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
@@ -76,6 +77,7 @@
   bool replaceRelational(Module &M);
   bool replaceIsInfAndIsNan(Module &M);
   bool replaceAllAndAny(Module &M);
+  bool replaceSelect(Module &M);
   bool replaceSignbit(Module &M);
   bool replaceMadandMad24andMul24(Module &M);
   bool replaceVloadHalf(Module &M);
@@ -117,6 +119,7 @@
   Changed |= replaceRelational(M);
   Changed |= replaceIsInfAndIsNan(M);
   Changed |= replaceAllAndAny(M);
+  Changed |= replaceSelect(M);
   Changed |= replaceSignbit(M);
   Changed |= replaceMadandMad24andMul24(M);
   Changed |= replaceVloadHalf(M);
@@ -780,6 +783,111 @@
   return Changed;
 }
 
+// Replaces calls to the OpenCL C select() built-in with LLVM select
+// instructions. Overloads that don't match an OpenCL C signature are left
+// untouched, and a declaration is only erased once it has no uses left.
+// Returns true if the module was modified.
+bool ReplaceOpenCLBuiltinPass::replaceSelect(Module &M) {
+  bool Changed = false;
+
+  // Gather candidates first: erasing a Function removes its entry from the
+  // symbol table, so we must not erase while iterating over that table.
+  SmallVector<Function *, 8> Candidates;
+  for (auto const &SymVal : M.getValueSymbolTable()) {
+    // Skip symbols whose (mangled) name isn't an overload of select().
+    if (!SymVal.getKey().startswith("_Z6select")) {
+      continue;
+    }
+    if (auto F = dyn_cast<Function>(SymVal.getValue())) {
+      Candidates.push_back(F);
+    }
+  }
+
+  for (auto F : Candidates) {
+    SmallVector<Instruction *, 4> ToRemoves;
+
+    // Walk the users of the function, only rewriting direct calls to it.
+    for (auto &U : F->uses()) {
+      auto CI = dyn_cast<CallInst>(U.getUser());
+      if (!CI || CI->getCalledFunction() != F) {
+        continue;
+      }
+
+      // select(a, b, c) yields b where c is "true" and a elsewhere.
+      auto FalseValue = CI->getOperand(0);
+      auto TrueValue = CI->getOperand(1);
+      auto PredicateValue = CI->getOperand(2);
+
+      // Don't touch overloads that aren't in OpenCL C
+      auto FalseType = FalseValue->getType();
+      auto TrueType = TrueValue->getType();
+      auto PredicateType = PredicateValue->getType();
+
+      if ((FalseType != TrueType) || !PredicateType->isIntOrIntVectorTy()) {
+        continue;
+      }
+
+      if (!FalseType->isIntOrIntVectorTy() &&
+          !FalseType->getScalarType()->isFloatingPointTy()) {
+        continue;
+      }
+
+      // Values and predicate must be both scalars or both vectors; a vector
+      // predicate with scalar values would produce an invalid select.
+      if (FalseType->isVectorTy() != PredicateType->isVectorTy()) {
+        continue;
+      }
+
+      if (FalseType->getScalarSizeInBits() !=
+          PredicateType->getScalarSizeInBits()) {
+        continue;
+      }
+
+      if (FalseType->isVectorTy()) {
+        const auto NumElements = FalseType->getVectorNumElements();
+        if (NumElements != PredicateType->getVectorNumElements()) {
+          continue;
+        }
+
+        if ((NumElements != 2) && (NumElements != 3) && (NumElements != 4) &&
+            (NumElements != 8) && (NumElements != 16)) {
+          continue;
+        }
+      }
+
+      // Scalar predicates select when non-zero; vector predicates select
+      // when the element's most significant bit is set (i.e. it's negative).
+      const auto ZeroValue = Constant::getNullValue(PredicateType);
+      const auto Pred = PredicateType->isVectorTy() ? CmpInst::ICMP_SLT
+                                                    : CmpInst::ICMP_NE;
+      auto Cmp = CmpInst::Create(Instruction::ICmp, Pred, PredicateValue,
+                                 ZeroValue, "", CI);
+
+      // Replace the call with an equivalent select instruction.
+      Value *V = SelectInst::Create(Cmp, TrueValue, FalseValue, "", CI);
+      CI->replaceAllUsesWith(V);
+
+      // Lastly, remember to remove the call.
+      ToRemoves.push_back(CI);
+    }
+
+    // And cleanup the calls we don't use anymore.
+    for (auto V : ToRemoves) {
+      V->eraseFromParent();
+      Changed = true;
+    }
+
+    // Only remove the function once nothing refers to it anymore: calls to
+    // overloads skipped above may still be using it.
+    if (F->use_empty()) {
+      F->eraseFromParent();
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
 bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
   bool Changed = false;