spirv,nir: lower frexp_exp/frexp_sig inside a new NIR pass

This lowering isn't needed for RADV because AMDGCN has two
instructions. It will be disabled for RADV in an upcoming series.

While we are at it, factorize a little bit.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index eecbc6a..19a807d 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -305,6 +305,7 @@
 
 		NIR_PASS_V(nir, nir_lower_system_values);
 		NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
+		NIR_PASS_V(nir, nir_lower_frexp);
 	}
 
 	/* Vulkan uses the separate-shader linking model */
diff --git a/src/compiler/Makefile.sources b/src/compiler/Makefile.sources
index 722cfbb..5ac4d0d 100644
--- a/src/compiler/Makefile.sources
+++ b/src/compiler/Makefile.sources
@@ -242,6 +242,7 @@
 	nir/nir_lower_constant_initializers.c \
 	nir/nir_lower_double_ops.c \
 	nir/nir_lower_drawpixels.c \
+	nir/nir_lower_frexp.c \
 	nir/nir_lower_global_vars_to_local.c \
 	nir/nir_lower_gs_intrinsics.c \
 	nir/nir_lower_load_const_to_scalar.c \
diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build
index 4f1efb5..510e99c 100644
--- a/src/compiler/nir/meson.build
+++ b/src/compiler/nir/meson.build
@@ -123,6 +123,7 @@
   'nir_lower_constant_initializers.c',
   'nir_lower_double_ops.c',
   'nir_lower_drawpixels.c',
+  'nir_lower_frexp.c',
   'nir_lower_global_vars_to_local.c',
   'nir_lower_gs_intrinsics.c',
   'nir_lower_load_const_to_scalar.c',
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 1da9874..b6a2ba7 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3253,6 +3253,8 @@
 bool nir_lower_clip_fs(nir_shader *shader, unsigned ucp_enables);
 bool nir_lower_clip_cull_distance_arrays(nir_shader *nir);
 
+bool nir_lower_frexp(nir_shader *nir);
+
 void nir_lower_two_sided_color(nir_shader *shader);
 
 bool nir_lower_clamp_color_outputs(nir_shader *shader);
diff --git a/src/compiler/nir/nir_lower_frexp.c b/src/compiler/nir/nir_lower_frexp.c
new file mode 100644
index 0000000..3b95661
--- /dev/null
+++ b/src/compiler/nir/nir_lower_frexp.c
@@ -0,0 +1,208 @@
+/*
+ * Copyright © 2015 Intel Corporation
+ * Copyright © 2019 Valve Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ *
+ * Authors:
+ *    Jason Ekstrand (jason@jlekstrand.net)
+ *    Samuel Pitoiset (samuel.pitoiset@gmail.com>
+ */
+
+#include "nir.h"
+#include "nir_builder.h"
+
+static nir_ssa_def *
+lower_frexp_sig(nir_builder *b, nir_ssa_def *x)
+{
+   nir_ssa_def *abs_x = nir_fabs(b, x);
+   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
+   nir_ssa_def *sign_mantissa_mask, *exponent_value;
+   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
+
+   switch (x->bit_size) {
+   case 16:
+      /* Half-precision floating-point values are stored as
+       *   1 sign bit;
+       *   5 exponent bits;
+       *   10 mantissa bits.
+       *
+       * An exponent shift of 10 will shift the mantissa out, leaving only the
+       * exponent and sign bit (which itself may be zero, if the absolute value
+       * was taken before the bitcast and shift).
+       */
+      sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
+      /* Exponent of floating-point values in the range [0.5, 1.0). */
+      exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
+      break;
+   case 32:
+      /* Single-precision floating-point values are stored as
+       *   1 sign bit;
+       *   8 exponent bits;
+       *   23 mantissa bits.
+       *
+       * An exponent shift of 23 will shift the mantissa out, leaving only the
+       * exponent and sign bit (which itself may be zero, if the absolute value
+       * was taken before the bitcast and shift.
+       */
+      sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
+      /* Exponent of floating-point values in the range [0.5, 1.0). */
+      exponent_value = nir_imm_int(b, 0x3f000000u);
+      break;
+   case 64:
+      /* Double-precision floating-point values are stored as
+       *   1 sign bit;
+       *   11 exponent bits;
+       *   52 mantissa bits.
+       *
+       * An exponent shift of 20 will shift the remaining mantissa bits out,
+       * leaving only the exponent and sign bit (which itself may be zero, if
+       * the absolute value was taken before the bitcast and shift.
+       */
+      sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
+      /* Exponent of floating-point values in the range [0.5, 1.0). */
+      exponent_value = nir_imm_int(b, 0x3fe00000u);
+      break;
+   default:
+      unreachable("Invalid bitsize");
+   }
+
+   if (x->bit_size == 64) {
+      /* We only need to deal with the exponent so first we extract the upper
+       * 32 bits using nir_unpack_64_2x32_split_y.
+       */
+      nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
+      nir_ssa_def *zero32 = nir_imm_int(b, 0);
+
+      nir_ssa_def *new_upper =
+         nir_ior(b, nir_iand(b, upper_x, sign_mantissa_mask),
+                    nir_bcsel(b, is_not_zero, exponent_value, zero32));
+
+      nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
+
+      return nir_pack_64_2x32_split(b, lower_x, new_upper);
+   } else {
+      return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
+                        nir_bcsel(b, is_not_zero, exponent_value, zero));
+   }
+}
+
+static nir_ssa_def *
+lower_frexp_exp(nir_builder *b, nir_ssa_def *x)
+{
+   nir_ssa_def *abs_x = nir_fabs(b, x);
+   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
+   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
+   nir_ssa_def *exponent;
+
+   switch (x->bit_size) {
+   case 16: {
+      nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
+      nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
+
+      /* Significand return must be of the same type as the input, but the
+       * exponent must be a 32-bit integer.
+       */
+      exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
+                              nir_bcsel(b, is_not_zero, exponent_bias, zero)));
+      break;
+   }
+   case 32: {
+      nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
+      nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
+
+      exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
+                             nir_bcsel(b, is_not_zero, exponent_bias, zero));
+      break;
+   }
+   case 64: {
+      nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
+      nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
+
+      nir_ssa_def *zero32 = nir_imm_int(b, 0);
+      nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
+
+      exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
+                             nir_bcsel(b, is_not_zero, exponent_bias, zero32));
+      break;
+   }
+   default:
+      unreachable("Invalid bitsize");
+   }
+
+   return exponent;
+}
+
+static bool
+lower_frexp_impl(nir_function_impl *impl)
+{
+   bool progress = false;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_alu)
+            continue;
+
+         nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
+         nir_ssa_def *lower;
+
+         b.cursor = nir_before_instr(instr);
+
+         switch (alu_instr->op) {
+         case nir_op_frexp_sig:
+            lower = lower_frexp_sig(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
+            break;
+         case nir_op_frexp_exp:
+            lower = lower_frexp_exp(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
+            break;
+         default:
+            continue;
+         }
+
+         nir_ssa_def_rewrite_uses(&alu_instr->dest.dest.ssa,
+                                  nir_src_for_ssa(lower));
+         nir_instr_remove(instr);
+         progress = true;
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   }
+
+   return progress;
+}
+
+bool
+nir_lower_frexp(nir_shader *shader)
+{
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
+      if (function->impl)
+         progress |= lower_frexp_impl(function->impl);
+   }
+
+   return progress;
+}
diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index 59ff4b8..ead2aff 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -385,123 +385,6 @@
                     nir_fneg(b, arc), arc);
 }
 
-static nir_ssa_def *
-build_frexp16(nir_builder *b, nir_ssa_def *x, nir_ssa_def **exponent)
-{
-   assert(x->bit_size == 16);
-
-   nir_ssa_def *abs_x = nir_fabs(b, x);
-   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, 16);
-
-   /* Half-precision floating-point values are stored as
-    *   1 sign bit;
-    *   5 exponent bits;
-    *   10 mantissa bits.
-    *
-    * An exponent shift of 10 will shift the mantissa out, leaving only the
-    * exponent and sign bit (which itself may be zero, if the absolute value
-    * was taken before the bitcast and shift).
-    */
-   nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
-   nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
-
-   nir_ssa_def *sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
-
-   /* Exponent of floating-point values in the range [0.5, 1.0). */
-   nir_ssa_def *exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
-
-   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
-
-   /* Significand return must be of the same type as the input, but the
-    * exponent must be a 32-bit integer.
-    */
-   *exponent =
-      nir_i2i32(b,
-                nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
-                            nir_bcsel(b, is_not_zero, exponent_bias, zero)));
-
-   return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
-                     nir_bcsel(b, is_not_zero, exponent_value, zero));
-}
-
-static nir_ssa_def *
-build_frexp32(nir_builder *b, nir_ssa_def *x, nir_ssa_def **exponent)
-{
-   nir_ssa_def *abs_x = nir_fabs(b, x);
-   nir_ssa_def *zero = nir_imm_float(b, 0.0f);
-
-   /* Single-precision floating-point values are stored as
-    *   1 sign bit;
-    *   8 exponent bits;
-    *   23 mantissa bits.
-    *
-    * An exponent shift of 23 will shift the mantissa out, leaving only the
-    * exponent and sign bit (which itself may be zero, if the absolute value
-    * was taken before the bitcast and shift.
-    */
-   nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
-   nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
-
-   nir_ssa_def *sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
-
-   /* Exponent of floating-point values in the range [0.5, 1.0). */
-   nir_ssa_def *exponent_value = nir_imm_int(b, 0x3f000000u);
-
-   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
-
-   *exponent =
-      nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
-                  nir_bcsel(b, is_not_zero, exponent_bias, zero));
-
-   return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
-                     nir_bcsel(b, is_not_zero, exponent_value, zero));
-}
-
-static nir_ssa_def *
-build_frexp64(nir_builder *b, nir_ssa_def *x, nir_ssa_def **exponent)
-{
-   nir_ssa_def *abs_x = nir_fabs(b, x);
-   nir_ssa_def *zero = nir_imm_double(b, 0.0);
-   nir_ssa_def *zero32 = nir_imm_float(b, 0.0f);
-
-   /* Double-precision floating-point values are stored as
-    *   1 sign bit;
-    *   11 exponent bits;
-    *   52 mantissa bits.
-    *
-    * We only need to deal with the exponent so first we extract the upper 32
-    * bits using nir_unpack_64_2x32_split_y.
-    */
-   nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
-   nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
-
-   /* An exponent shift of 20 will shift the remaining mantissa bits out,
-    * leaving only the exponent and sign bit (which itself may be zero, if the
-    * absolute value was taken before the bitcast and shift.
-    */
-   nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
-   nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
-
-   nir_ssa_def *sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
-
-   /* Exponent of floating-point values in the range [0.5, 1.0). */
-   nir_ssa_def *exponent_value = nir_imm_int(b, 0x3fe00000u);
-
-   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
-
-   *exponent =
-      nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
-                  nir_bcsel(b, is_not_zero, exponent_bias, zero32));
-
-   nir_ssa_def *new_upper =
-      nir_ior(b, nir_iand(b, upper_x, sign_mantissa_mask),
-                 nir_bcsel(b, is_not_zero, exponent_value, zero32));
-
-   nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
-
-   return nir_pack_64_2x32_split(b, lower_x, new_upper);
-}
-
 static nir_op
 vtn_nir_alu_op_for_spirv_glsl_opcode(struct vtn_builder *b,
                                      enum GLSLstd450 opcode)
@@ -782,28 +665,16 @@
       return;
 
    case GLSLstd450Frexp: {
-      nir_ssa_def *exponent;
-      if (src[0]->bit_size == 64)
-         val->ssa->def = build_frexp64(nb, src[0], &exponent);
-      else if (src[0]->bit_size == 32)
-         val->ssa->def = build_frexp32(nb, src[0], &exponent);
-      else
-         val->ssa->def = build_frexp16(nb, src[0], &exponent);
+      nir_ssa_def *exponent = nir_frexp_exp(nb, src[0]);
+      val->ssa->def = nir_frexp_sig(nb, src[0]);
       nir_store_deref(nb, vtn_nir_deref(b, w[6]), exponent, 0xf);
       return;
    }
 
    case GLSLstd450FrexpStruct: {
       vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
-      if (src[0]->bit_size == 64)
-         val->ssa->elems[0]->def = build_frexp64(nb, src[0],
-                                                 &val->ssa->elems[1]->def);
-      else if (src[0]->bit_size == 32)
-         val->ssa->elems[0]->def = build_frexp32(nb, src[0],
-                                                 &val->ssa->elems[1]->def);
-      else
-         val->ssa->elems[0]->def = build_frexp16(nb, src[0],
-                                                 &val->ssa->elems[1]->def);
+      val->ssa->elems[0]->def = nir_frexp_sig(nb, src[0]);
+      val->ssa->elems[1]->def = nir_frexp_exp(nb, src[0]);
       return;
    }
 
diff --git a/src/freedreno/vulkan/tu_shader.c b/src/freedreno/vulkan/tu_shader.c
index 2a70136..c2fdff9 100644
--- a/src/freedreno/vulkan/tu_shader.c
+++ b/src/freedreno/vulkan/tu_shader.c
@@ -173,6 +173,7 @@
                             ir3_glsl_type_size);
 
    NIR_PASS_V(nir, nir_lower_system_values);
+   NIR_PASS_V(nir, nir_lower_frexp);
    NIR_PASS_V(nir, nir_lower_io, nir_var_all, ir3_glsl_type_size, 0);
 
    nir_shader_gather_info(nir, entry_point->impl);
diff --git a/src/gallium/drivers/freedreno/ir3/ir3_cmdline.c b/src/gallium/drivers/freedreno/ir3/ir3_cmdline.c
index 2892e7c..1481c08 100644
--- a/src/gallium/drivers/freedreno/ir3/ir3_cmdline.c
+++ b/src/gallium/drivers/freedreno/ir3/ir3_cmdline.c
@@ -181,6 +181,7 @@
 			ir3_glsl_type_size);
 
 	NIR_PASS_V(nir, nir_lower_system_values);
+	NIR_PASS_V(nir, nir_lower_frexp);
 	NIR_PASS_V(nir, nir_lower_io, nir_var_all, ir3_glsl_type_size, 0);
 	NIR_PASS_V(nir, gl_nir_lower_samplers, prog);
 
diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c
index e9319f5..90942a4 100644
--- a/src/intel/vulkan/anv_pipeline.c
+++ b/src/intel/vulkan/anv_pipeline.c
@@ -226,6 +226,8 @@
    NIR_PASS_V(nir, nir_lower_io_to_temporaries,
               entry_point->impl, true, false);
 
+   NIR_PASS_V(nir, nir_lower_frexp);
+
    /* Vulkan uses the separate-shader linking model */
    nir->info.separate_shader = true;