nir/lower_bit_size: Pass a nir_instr to the callback
This way we can start supporting more than just ALU ops.
Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7482>
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index c917b11..f480a57 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -2948,11 +2948,15 @@
}
static unsigned
-lower_bit_size_callback(const nir_alu_instr *alu, void *_)
+lower_bit_size_callback(const nir_instr *instr, void *_)
{
struct radv_device *device = _;
enum chip_class chip = device->physical_device->rad_info.chip_class;
+ if (instr->type != nir_instr_type_alu)
+ return 0;
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+
if (alu->dest.dest.ssa.bit_size & (8 | 16)) {
unsigned bit_size = alu->dest.dest.ssa.bit_size;
switch (alu->op) {
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 389c585..9365c16 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4948,7 +4948,7 @@
bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);
-typedef unsigned (*nir_lower_bit_size_callback)(const nir_alu_instr *, void *);
+typedef unsigned (*nir_lower_bit_size_callback)(const nir_instr *, void *);
bool nir_lower_bit_size(nir_shader *shader,
nir_lower_bit_size_callback callback,
diff --git a/src/compiler/nir/nir_lower_bit_size.c b/src/compiler/nir/nir_lower_bit_size.c
index 0508bdd..a53090a 100644
--- a/src/compiler/nir/nir_lower_bit_size.c
+++ b/src/compiler/nir/nir_lower_bit_size.c
@@ -46,7 +46,7 @@
}
static void
-lower_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size)
+lower_alu_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size)
{
const nir_op op = alu->op;
unsigned dst_bit_size = alu->dest.dest.ssa.bit_size;
@@ -109,14 +109,11 @@
if (instr->type != nir_instr_type_alu)
continue;
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- assert(alu->dest.dest.is_ssa);
-
- unsigned lower_bit_size = callback(alu, callback_data);
+ unsigned lower_bit_size = callback(instr, callback_data);
if (lower_bit_size == 0)
continue;
- lower_instr(&b, alu, lower_bit_size);
+ lower_alu_instr(&b, nir_instr_as_alu(instr), lower_bit_size);
progress = true;
}
}
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index 245003c..61c1ef9 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -633,8 +633,12 @@
}
static unsigned
-lower_bit_size_callback(const nir_alu_instr *alu, UNUSED void *data)
+lower_bit_size_callback(const nir_instr *instr, UNUSED void *data)
{
+ if (instr->type != nir_instr_type_alu)
+ return 0;
+
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
assert(alu->dest.dest.is_ssa);
if (alu->dest.dest.ssa.bit_size >= 32)
return 0;