aco: refactor split_store_data() to always split into evenly sized elements

This fixes a couple of issues on GFX6/7 and
has no negative impact on newer hardware.
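
For illustration only (not part of this patch): the new splitting strategy
derives a single element size as the largest power of two that divides all
requested store sizes, capped at 8 bytes (two dwords). Below is a minimal
standalone sketch of that computation; the helper name and the test values
are made up, and ffs() is the POSIX function the code already relies on:

  #include <cassert>
  #include <functional>
  #include <numeric>
  #include <strings.h> /* ffs() */

  /* Largest power-of-two size (at most 8 bytes) that divides every entry of
   * bytes[0..count). This mirrors the std::accumulate()/ffs() expression
   * introduced in split_store_data() below. */
  static unsigned common_elem_size_bytes(const unsigned *bytes, unsigned count)
  {
     unsigned combined = std::accumulate(bytes, bytes + count, 8u, std::bit_or<>{});
     return 1u << (ffs(combined) - 1);
  }

  int main()
  {
     const unsigned sizes[] = {12, 8, 4};
     assert(common_elem_size_bytes(sizes, 3) == 4);   /* 12|8|4|8 = 12 -> lowest set bit is 4 */
     const unsigned subdword[] = {2, 6};
     assert(common_elem_size_bytes(subdword, 2) == 2); /* subdword element size */
     return 0;
  }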

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7105>
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index f88ce66..8c7e3e7 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -24,6 +24,7 @@
  */
 
 #include <algorithm>
+#include <numeric>
 #include <array>
 #include <stack>
 #include <map>
@@ -3464,19 +3465,13 @@
    return dst;
 }
 
-void split_store_data(isel_context *ctx, RegType dst_type, unsigned count, Temp *dst, unsigned *offsets, Temp src)
+void split_store_data(isel_context *ctx, RegType dst_type, unsigned count, Temp *dst, unsigned *bytes, Temp src)
 {
    if (!count)
       return;
 
    Builder bld(ctx->program, ctx->block);
 
-   ASSERTED bool is_subdword = false;
-   for (unsigned i = 0; i < count; i++)
-      is_subdword |= offsets[i] % 4;
-   is_subdword |= (src.bytes() - offsets[count - 1]) % 4;
-   assert(!is_subdword || dst_type == RegType::vgpr);
-
    /* count == 1 fast path */
    if (count == 1) {
       if (dst_type == RegType::sgpr)
@@ -3486,67 +3481,76 @@
       return;
    }
 
-   for (unsigned i = 0; i < count - 1; i++)
-      dst[i] = bld.tmp(RegClass::get(dst_type, offsets[i + 1] - offsets[i]));
-   dst[count - 1] = bld.tmp(RegClass::get(dst_type, src.bytes() - offsets[count - 1]));
+   /* elem_size_bytes is the greatest power-of-two common divisor of all sizes (at most 8) */
+   unsigned elem_size_bytes = 1u << (ffs(std::accumulate(bytes, bytes + count, 8, std::bit_or<>{})) - 1);
 
-   if (is_subdword && src.type() == RegType::sgpr) {
-      src = as_vgpr(ctx, src);
-   } else {
-      /* use allocated_vec if possible */
-      auto it = ctx->allocated_vec.find(src.id());
-      if (it != ctx->allocated_vec.end()) {
-         if (!it->second[0].id())
+   ASSERTED bool is_subdword = elem_size_bytes < 4;
+   assert(!is_subdword || dst_type == RegType::vgpr);
+
+   for (unsigned i = 0; i < count; i++)
+      dst[i] = bld.tmp(RegClass::get(dst_type, bytes[i]));
+
+   std::vector<Temp> temps;
+   /* use allocated_vec if possible */
+   auto it = ctx->allocated_vec.find(src.id());
+   if (it != ctx->allocated_vec.end()) {
+      if (!it->second[0].id())
+         goto split;
+      unsigned elem_size = it->second[0].bytes();
+      assert(src.bytes() % elem_size == 0);
+
+      for (unsigned i = 0; i < src.bytes() / elem_size; i++) {
+         if (!it->second[i].id())
             goto split;
-         unsigned elem_size = it->second[0].bytes();
-         assert(src.bytes() % elem_size == 0);
-
-         for (unsigned i = 0; i < src.bytes() / elem_size; i++) {
-            if (!it->second[i].id())
-               goto split;
-         }
-
-         for (unsigned i = 0; i < count; i++) {
-            if (offsets[i] % elem_size || dst[i].bytes() % elem_size)
-               goto split;
-         }
-
-         for (unsigned i = 0; i < count; i++) {
-            unsigned start_idx = offsets[i] / elem_size;
-            unsigned op_count = dst[i].bytes() / elem_size;
-            if (op_count == 1) {
-               if (dst_type == RegType::sgpr)
-                  dst[i] = bld.as_uniform(it->second[start_idx]);
-               else
-                  dst[i] = as_vgpr(ctx, it->second[start_idx]);
-               continue;
-            }
-
-            aco_ptr<Instruction> vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, op_count, 1)};
-            for (unsigned j = 0; j < op_count; j++) {
-               Temp tmp = it->second[start_idx + j];
-               if (dst_type == RegType::sgpr)
-                  tmp = bld.as_uniform(tmp);
-               vec->operands[j] = Operand(tmp);
-            }
-            vec->definitions[0] = Definition(dst[i]);
-            bld.insert(std::move(vec));
-         }
-         return;
       }
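+      /* only reuse allocated_vec if its element size divides the common element size */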
+      if (elem_size_bytes % elem_size)
+         goto split;
+
+      temps.insert(temps.end(), it->second.begin(),
+                   it->second.begin() + src.bytes() / elem_size);
+      elem_size_bytes = elem_size;
    }
 
    split:
+   /* split src if necessary */
+   if (temps.empty()) {
+      if (is_subdword && src.type() == RegType::sgpr)
+         src = as_vgpr(ctx, src);
+      if (dst_type == RegType::sgpr)
+         src = bld.as_uniform(src);
 
-   if (dst_type == RegType::sgpr)
-      src = bld.as_uniform(src);
+      unsigned num_elems = src.bytes() / elem_size_bytes;
+      aco_ptr<Instruction> split{create_instruction<Pseudo_instruction>(aco_opcode::p_split_vector, Format::PSEUDO, 1, num_elems)};
+      split->operands[0] = Operand(src);
+      for (unsigned i = 0; i < num_elems; i++) {
+         temps.emplace_back(bld.tmp(RegClass::get(dst_type, elem_size_bytes)));
+         split->definitions[i] = Definition(temps.back());
+      }
+      bld.insert(std::move(split));
+   }
 
-   /* just split it */
-   aco_ptr<Instruction> split{create_instruction<Pseudo_instruction>(aco_opcode::p_split_vector, Format::PSEUDO, 1, count)};
-   split->operands[0] = Operand(src);
-   for (unsigned i = 0; i < count; i++)
-      split->definitions[i] = Definition(dst[i]);
-   bld.insert(std::move(split));
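+   /* recombine the evenly sized pieces into the requested store elements */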
+   unsigned idx = 0;
+   for (unsigned i = 0; i < count; i++) {
+      unsigned op_count = dst[i].bytes() / elem_size_bytes;
+      if (op_count == 1) {
+         if (dst_type == RegType::sgpr)
+            dst[i] = bld.as_uniform(temps[idx++]);
+         else
+            dst[i] = as_vgpr(ctx, temps[idx++]);
+         continue;
+      }
+
+      aco_ptr<Instruction> vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, op_count, 1)};
+      for (unsigned j = 0; j < op_count; j++) {
+         Temp tmp = temps[idx++];
+         if (dst_type == RegType::sgpr)
+            tmp = bld.as_uniform(tmp);
+         vec->operands[j] = Operand(tmp);
+      }
+      vec->definitions[0] = Definition(dst[i]);
+      bld.insert(std::move(vec));
+   }
+   return;
 }
 
 bool scan_write_mask(uint32_t mask, uint32_t todo_mask,
@@ -3582,18 +3586,20 @@
    unsigned write_count = 0;
    Temp write_datas[32];
    unsigned offsets[32];
+   unsigned bytes[32];
    aco_opcode opcodes[32];
 
    wrmask = widen_mask(wrmask, elem_size_bytes);
 
    uint32_t todo = u_bit_consecutive(0, data.bytes());
    while (todo) {
-      int offset, bytes;
-      if (!scan_write_mask(wrmask, todo, &offset, &bytes)) {
+      int offset, byte;
+      if (!scan_write_mask(wrmask, todo, &offset, &byte)) {
          offsets[write_count] = offset;
+         bytes[write_count] = byte;
          opcodes[write_count] = aco_opcode::num_opcodes;
          write_count++;
-         advance_write_mask(&todo, offset, bytes);
+         advance_write_mask(&todo, offset, byte);
          continue;
       }
 
@@ -3604,37 +3610,38 @@
 
       //TODO: use ds_write_b8_d16_hi/ds_write_b16_d16_hi if beneficial
       aco_opcode op = aco_opcode::num_opcodes;
-      if (bytes >= 16 && aligned16 && large_ds_write) {
+      if (byte >= 16 && aligned16 && large_ds_write) {
          op = aco_opcode::ds_write_b128;
-         bytes = 16;
-      } else if (bytes >= 12 && aligned16 && large_ds_write) {
+         byte = 16;
+      } else if (byte >= 12 && aligned16 && large_ds_write) {
          op = aco_opcode::ds_write_b96;
-         bytes = 12;
-      } else if (bytes >= 8 && aligned8) {
+         byte = 12;
+      } else if (byte >= 8 && aligned8) {
          op = aco_opcode::ds_write_b64;
-         bytes = 8;
-      } else if (bytes >= 4 && aligned4) {
+         byte = 8;
+      } else if (byte >= 4 && aligned4) {
          op = aco_opcode::ds_write_b32;
-         bytes = 4;
-      } else if (bytes >= 2 && aligned2) {
+         byte = 4;
+      } else if (byte >= 2 && aligned2) {
          op = aco_opcode::ds_write_b16;
-         bytes = 2;
-      } else if (bytes >= 1) {
+         byte = 2;
+      } else if (byte >= 1) {
          op = aco_opcode::ds_write_b8;
-         bytes = 1;
+         byte = 1;
       } else {
          assert(false);
       }
 
       offsets[write_count] = offset;
+      bytes[write_count] = byte;
       opcodes[write_count] = op;
       write_count++;
-      advance_write_mask(&todo, offset, bytes);
+      advance_write_mask(&todo, offset, byte);
    }
 
    Operand m = load_lds_size_m0(bld);
 
-   split_store_data(ctx, RegType::vgpr, write_count, write_datas, offsets, data);
+   split_store_data(ctx, RegType::vgpr, write_count, write_datas, bytes, data);
 
    for (unsigned i = 0; i < write_count; i++) {
       aco_opcode op = opcodes[i];
@@ -3718,42 +3725,45 @@
 {
    unsigned write_count_with_skips = 0;
    bool skips[16];
+   unsigned bytes[16];
 
    /* determine how to split the data */
    unsigned todo = u_bit_consecutive(0, data.bytes());
    while (todo) {
-      int offset, bytes;
-      skips[write_count_with_skips] = !scan_write_mask(writemask, todo, &offset, &bytes);
+      int offset, byte;
+      skips[write_count_with_skips] = !scan_write_mask(writemask, todo, &offset, &byte);
       offsets[write_count_with_skips] = offset;
       if (skips[write_count_with_skips]) {
-         advance_write_mask(&todo, offset, bytes);
+         bytes[write_count_with_skips] = byte;
+         advance_write_mask(&todo, offset, byte);
          write_count_with_skips++;
          continue;
       }
 
       /* only supported sizes are 1, 2, 4, 8, 12 and 16 bytes and can't be
        * larger than swizzle_element_size */
-      bytes = MIN2(bytes, swizzle_element_size);
-      if (bytes % 4)
-         bytes = bytes > 4 ? bytes & ~0x3 : MIN2(bytes, 2);
+      byte = MIN2(byte, swizzle_element_size);
+      if (byte % 4)
+         byte = byte > 4 ? byte & ~0x3 : MIN2(byte, 2);
 
       /* SMEM and GFX6 VMEM can't emit 12-byte stores */
-      if ((ctx->program->chip_class == GFX6 || smem) && bytes == 12)
-         bytes = 8;
+      if ((ctx->program->chip_class == GFX6 || smem) && byte == 12)
+         byte = 8;
 
       /* dword or larger stores have to be dword-aligned */
       unsigned align_mul = instr ? nir_intrinsic_align_mul(instr) : 4;
       unsigned align_offset = (instr ? nir_intrinsic_align_offset(instr) : 0) + offset;
       bool dword_aligned = align_offset % 4 == 0 && align_mul % 4 == 0;
       if (!dword_aligned)
-         bytes = MIN2(bytes, (align_offset % 2 == 0 && align_mul % 2 == 0) ? 2 : 1);
+         byte = MIN2(byte, (align_offset % 2 == 0 && align_mul % 2 == 0) ? 2 : 1);
 
-      advance_write_mask(&todo, offset, bytes);
+      bytes[write_count_with_skips] = byte;
+      advance_write_mask(&todo, offset, byte);
       write_count_with_skips++;
    }
 
    /* actually split data */
-   split_store_data(ctx, dst_type, write_count_with_skips, write_datas, offsets, data);
+   split_store_data(ctx, dst_type, write_count_with_skips, write_datas, bytes, data);
 
    /* remove skips */
    for (unsigned i = 0; i < write_count_with_skips; i++) {