Complete support for the step and smoothstep built-in functions (#231)
* Complete support for the step and smoothstep built-in functions
This commit adds support for the missing variants that take scalar edge
arguments and a vector x argument. We splat each scalar edge into a
vector, convert the call to the all-vector variant, and let the
SPIRVProducer pass map that variant to a GLSL extended instruction.
With this implementation, the OpenCL conformance tests (stepf,
smoothstepf) pass via clvk on NVIDIA hardware.
Fixes #199.
Signed-off-by: Kévin Petit <kpet@free.fr>
diff --git a/lib/ReplaceOpenCLBuiltinPass.cpp b/lib/ReplaceOpenCLBuiltinPass.cpp
index 2dc81b5..c2d3964 100644
--- a/lib/ReplaceOpenCLBuiltinPass.cpp
+++ b/lib/ReplaceOpenCLBuiltinPass.cpp
@@ -78,6 +78,7 @@
bool replaceIsInfAndIsNan(Module &M);
bool replaceAllAndAny(Module &M);
bool replaceSelect(Module &M);
+ bool replaceStepSmoothStep(Module &M);
bool replaceSignbit(Module &M);
bool replaceMadandMad24andMul24(Module &M);
bool replaceVloadHalf(Module &M);
@@ -120,6 +121,7 @@
Changed |= replaceIsInfAndIsNan(M);
Changed |= replaceAllAndAny(M);
Changed |= replaceSelect(M);
+ Changed |= replaceStepSmoothStep(M);
Changed |= replaceSignbit(M);
Changed |= replaceMadandMad24andMul24(M);
Changed |= replaceVloadHalf(M);
@@ -888,6 +890,102 @@
   return Changed;
 }

+// Rewrites calls to the step/smoothstep overloads that take scalar edge(s)
+// and a float vector into calls to the all-vector overloads, splatting each
+// scalar edge across a vector of the same type as the vector argument.
+// Returns true if the module was modified.
+bool ReplaceOpenCLBuiltinPass::replaceStepSmoothStep(Module &M) {
+  bool Changed = false;
+
+  // Mangled names of the scalar-edge overloads, each paired with the mangled
+  // name of the all-vector overload that replaces it:
+  //   step(float, floatN)              -> step(floatN, floatN)
+  //   smoothstep(float, float, floatN) -> smoothstep(floatN, floatN, floatN)
+  const struct {
+    const char *From;
+    const char *To;
+  } Replacements[] = {
+      {"_Z4stepfDv2_f", "_Z4stepDv2_fS_"},
+      {"_Z4stepfDv3_f", "_Z4stepDv3_fS_"},
+      {"_Z4stepfDv4_f", "_Z4stepDv4_fS_"},
+      {"_Z10smoothstepffDv2_f", "_Z10smoothstepDv2_fS_S_"},
+      {"_Z10smoothstepffDv3_f", "_Z10smoothstepDv3_fS_S_"},
+      {"_Z10smoothstepffDv4_f", "_Z10smoothstepDv4_fS_S_"},
+  };
+
+  auto Int32Ty = Type::getInt32Ty(M.getContext());
+
+  for (const auto &Entry : Replacements) {
+    auto F = M.getFunction(Entry.From);
+    if (!F) {
+      continue;
+    }
+
+    // step takes one scalar edge, smoothstep two; the vector argument always
+    // comes last.
+    const bool IsSmoothStep = F->getName().startswith("_Z10smoothstep");
+    const unsigned NumScalarArgs = IsSmoothStep ? 2 : 1;
+
+    SmallVector<Instruction *, 4> ToRemoves;
+
+    for (auto &U : F->uses()) {
+      auto CI = dyn_cast<CallInst>(U.getUser());
+      // Only direct calls are rewritten; any other use keeps F alive below.
+      if (!CI || CI->getCalledFunction() != F) {
+        continue;
+      }
+
+      Value *VectorArg = CI->getArgOperand(NumScalarArgs);
+      auto VecType = VectorArg->getType();
+      const unsigned NumElements = VecType->getVectorNumElements();
+
+      // Splat each scalar edge into a VecType vector, then append the
+      // original vector argument.
+      SmallVector<Value *, 3> NewArgs;
+      for (unsigned ArgIdx = 0; ArgIdx < NumScalarArgs; ++ArgIdx) {
+        Value *Scalar = CI->getArgOperand(ArgIdx);
+        Value *Splat = UndefValue::get(VecType);
+        for (unsigned i = 0; i < NumElements; ++i) {
+          auto Index = ConstantInt::get(Int32Ty, i);
+          Splat = InsertElementInst::Create(Splat, Scalar, Index, "", CI);
+        }
+        NewArgs.push_back(Splat);
+      }
+      NewArgs.push_back(VectorArg);
+
+      // Call the all-vector overload instead; the SPIRVProducer pass maps
+      // that overload to a GLSL extended instruction.
+      SmallVector<Type *, 3> NewArgTypes(NewArgs.size(), VecType);
+      auto NewFType = FunctionType::get(CI->getType(), NewArgTypes, false);
+      auto NewF = M.getOrInsertFunction(Entry.To, NewFType);
+      auto NewCI = CallInst::Create(NewF, NewArgs, "", CI);
+
+      CI->replaceAllUsesWith(NewCI);
+
+      // Erased after the walk: erasing now would mutate F's use list while
+      // it is being iterated.
+      ToRemoves.push_back(CI);
+    }
+
+    if (!ToRemoves.empty()) {
+      Changed = true;
+    }
+
+    for (auto V : ToRemoves) {
+      V->eraseFromParent();
+    }
+
+    // Drop the scalar-edge declaration only once nothing refers to it;
+    // erasing a function that still has uses would leave dangling uses.
+    if (F->use_empty()) {
+      F->eraseFromParent();
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
bool ReplaceOpenCLBuiltinPass::replaceSignbit(Module &M) {
bool Changed = false;