Allow alternative form for refract().

Bug: 27266906
Change-Id: Ib24c8d065d922a987feb0271f397d83b837ae936
diff --git a/modules/glshared/glsBuiltinPrecisionTests.cpp b/modules/glshared/glsBuiltinPrecisionTests.cpp
index d32a373..fee0c53 100644
--- a/modules/glshared/glsBuiltinPrecisionTests.cpp
+++ b/modules/glshared/glsBuiltinPrecisionTests.cpp
@@ -3127,6 +3127,49 @@
 	}
 };
 
+template<int Size, typename Ret, typename Arg0, typename Arg1>
+struct ApplyRefract
+{
+	static ExprP<Ret> apply	(ExpandContext&			ctx,
+							 const ExprP<Arg0>&		i,
+							 const ExprP<Arg1>&		n,
+							 const ExprP<float>&	eta)
+	{
+		const ExprP<float>	dotNI	= bindExpression("dotNI", ctx, dot(n, i));
+		const ExprP<float>	k		= bindExpression("k", ctx, constant(1.0f) - eta * eta *
+												 (constant(1.0f) - dotNI * dotNI));
+
+		return cond(k < constant(0.0f),
+					genXType<float, Size>(constant(0.0f)),
+					i * eta - n * (eta * dotNI + sqrt(k)));
+	};
+};
+
+template<typename Ret, typename Arg0, typename Arg1>
+struct ApplyRefract<1, Ret, Arg0, Arg1>
+{
+	static ExprP<Ret> apply	(ExpandContext&			ctx,
+							 const ExprP<Arg0>&		i,
+							 const ExprP<Arg1>&		n,
+							 const ExprP<float>&	eta)
+	{
+		const ExprP<float>	dotNI	= bindExpression("dotNI", ctx, dot(n, i));
+		const ExprP<float>	k1		= bindExpression("k1", ctx, constant(1.0f) - eta * eta *
+												(constant(1.0f) - dotNI * dotNI));
+
+		const ExprP<float>	k2		= bindExpression("k2", ctx,
+												(((dotNI * (-dotNI)) + constant(1.0f)) * eta)
+												* (-eta) + constant(1.0f));
+
+		return alternatives(cond(k1 < constant(0.0f),
+								genXType<float, 1>(constant(0.0f)),
+								i * eta - n * (eta * dotNI + sqrt(k1))),
+							cond(k2 < constant(0.0f),
+								genXType<float, 1>(constant(0.0f)),
+								i * eta - n * (eta * dotNI + sqrt(k2))));
+	};
+};
+
 template <int Size>
 class Refract : public DerivedFunc<
 	Signature<typename ContainerOf<float, Size>::Container,
@@ -3151,13 +3194,8 @@
 		const ExprP<Arg0>&	i		= args.a;
 		const ExprP<Arg1>&	n		= args.b;
 		const ExprP<float>&	eta		= args.c;
-		const ExprP<float>	dotNI	= bindExpression("dotNI", ctx, dot(n, i));
-		const ExprP<float>	k		= bindExpression("k", ctx, constant(1.0f) - eta * eta *
-												 (constant(1.0f) - dotNI * dotNI));
 
-		return cond(k < constant(0.0f),
-					genXType<float, Size>(constant(0.0f)),
-					i * eta - n * (eta * dotNI + sqrt(k)));
+		return ApplyRefract<Size, Ret, Arg0, Arg1>::apply(ctx, i, n, eta);
 	}
 };