Complete support for the step and smoothstep built-in functions (#231)

* Complete support for the step and smoothstep built-in functions

This commit adds support for the missing variants that take scalar edge
arguments and a vector x: step(float, floatn) and
smoothstep(float, float, floatn). The scalar arguments are splatted into
vectors, the calls are rewritten to the all-vector variant, and the
SPIRVProducer pass then maps that variant to a GLSL extended instruction.

This implementation can pass the OpenCL conformance tests
(stepf, smoothstepf) when run via clvk on NVIDIA hardware.

Fixes #199.

Signed-off-by: Kévin Petit <kpet@free.fr>
diff --git a/lib/ReplaceOpenCLBuiltinPass.cpp b/lib/ReplaceOpenCLBuiltinPass.cpp
index 2dc81b5..c2d3964 100644
--- a/lib/ReplaceOpenCLBuiltinPass.cpp
+++ b/lib/ReplaceOpenCLBuiltinPass.cpp
@@ -78,6 +78,7 @@
   bool replaceIsInfAndIsNan(Module &M);
   bool replaceAllAndAny(Module &M);
   bool replaceSelect(Module &M);
+  bool replaceStepSmoothStep(Module &M);
   bool replaceSignbit(Module &M);
   bool replaceMadandMad24andMul24(Module &M);
   bool replaceVloadHalf(Module &M);
@@ -120,6 +121,7 @@
   Changed |= replaceIsInfAndIsNan(M);
   Changed |= replaceAllAndAny(M);
   Changed |= replaceSelect(M);
+  Changed |= replaceStepSmoothStep(M);
   Changed |= replaceSignbit(M);
   Changed |= replaceMadandMad24andMul24(M);
   Changed |= replaceVloadHalf(M);
@@ -888,6 +890,89 @@
   return Changed;
 }
 
+// Rewrites calls to step(float, floatn) and smoothstep(float, float, floatn)
+// into calls to the all-vector variants by splatting each scalar edge into a
+// vector of x's type; the SPIRVProducer pass then maps the vector/vector
+// variant. Returns true if the module was modified.
+bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
+  bool Changed = false;
+
+  // Scalar-edge variant -> all-vector variant (mangled names). A plain array
+  // iterates in a fixed order, unlike a std::map keyed on pointer values.
+  const std::pair<const char *, const char *> Map[] = {
+    { "_Z4stepfDv2_f", "_Z4stepDv2_fS_" },
+    { "_Z4stepfDv3_f", "_Z4stepDv3_fS_" },
+    { "_Z4stepfDv4_f", "_Z4stepDv4_fS_" },
+    { "_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_" },
+    { "_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_" },
+    { "_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_" },
+  };
+
+  for (const auto &Pair : Map) {
+    auto F = M.getFunction(Pair.first);
+    if (!F) {
+      continue;
+    }
+
+    // smoothstep(edge0, edge1, x) has two scalar edges; step(edge, x) has one.
+    const bool IsSmoothStep = F->getName().startswith("_Z10smoothstep");
+    SmallVector<Instruction *, 4> ToRemoves;
+
+    // Only rewrite calls whose callee is F (F may also be a call argument).
+    for (auto &U : F->uses()) {
+      auto CI = dyn_cast<CallInst>(U.getUser());
+      if (!CI || CI->getCalledFunction() != F) {
+        continue;
+      }
+
+      // The trailing argument is the vector x; the leading ones are scalars.
+      SmallVector<Value *, 2> ArgsToSplat = {CI->getArgOperand(0)};
+      if (IsSmoothStep) {
+        ArgsToSplat.push_back(CI->getArgOperand(1));
+      }
+      Value *VectorArg = CI->getArgOperand(IsSmoothStep ? 2 : 1);
+      auto VecType = VectorArg->getType();
+      auto Int32Ty = Type::getInt32Ty(M.getContext());
+
+      // Splat each scalar argument into a vector of x's type.
+      SmallVector<Value *, 3> NewArgs;
+      for (auto Arg : ArgsToSplat) {
+        Value *Splat = UndefValue::get(VecType);
+        for (unsigned i = 0, e = VecType->getVectorNumElements(); i < e; ++i) {
+          auto Index = ConstantInt::get(Int32Ty, i);
+          Splat = InsertElementInst::Create(Splat, Arg, Index, "", CI);
+        }
+        NewArgs.push_back(Splat);
+      }
+      NewArgs.push_back(VectorArg);
+
+      // Replace the call with one to the vector/vector flavour.
+      SmallVector<Type *, 3> NewArgTypes(NewArgs.size(), VecType);
+      const auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
+      const auto NewF = M.getOrInsertFunction(Pair.second, NewFType);
+      const auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
+      CI->replaceAllUsesWith(NewCI);
+
+      // Erase the old call after the walk so F's use list stays intact.
+      ToRemoves.push_back(CI);
+    }
+
+    for (auto V : ToRemoves) {
+      V->eraseFromParent();
+    }
+    Changed |= !ToRemoves.empty();
+
+    // Remove the declaration only once nothing refers to it anymore;
+    // erasing a function that still has uses trips an LLVM assertion.
+    if (F->use_empty()) {
+      F->eraseFromParent();
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
 bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
   bool Changed = false;