Implement stp (post-indexed) for aarch64 assembler

PiperOrigin-RevId: 423432717
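
For reference (not part of the patch itself): a minimal, standalone sketch of
the bit layout Assembler::stp emits below. Post-index STP of a Q-register pair
uses base opcode 0xAC800000 with a signed 7-bit offset scaled by 16 at bit 15,
rt2 at bit 10, rn at bit 5 and rt at bit 0. The helper name
encode_stp_q_post_index is invented here purely for illustration; it only
mirrors the emit32 expression in the diff and checks it against the expected
test encodings.

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for Assembler::stp(QRegister, QRegister, MemOperand, int32_t).
uint32_t encode_stp_q_post_index(uint32_t qt1, uint32_t qt2, uint32_t xn, int32_t imm) {
  // Offset must be a multiple of 16 in [-1024, 1008] (imm7 scaled by 16).
  assert(imm >= -1024 && imm <= 1008 && (imm & 0xF) == 0);
  const uint32_t offset = static_cast<uint32_t>(imm >> 4) & 0x7F;
  return 0xAC800000u | offset << 15 | qt2 << 10 | xn << 5 | qt1;
}

int main() {
  assert(encode_stp_q_post_index(16, 17, 6, 32) == 0xAC8144D0u);     // stp q16, q17, [x6], 32
  assert(encode_stp_q_post_index(16, 17, 6, 1008) == 0xAC9FC4D0u);   // largest valid offset
  assert(encode_stp_q_post_index(16, 17, 6, -1024) == 0xACA044D0u);  // most negative valid offset
  return 0;
}
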
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index 46c1b34..9e2dcd3 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -13,6 +13,7 @@
 // Min and max values for the imm7 for ldp, will be shifted right by 3 when encoding.
 constexpr int32_t kImm7Min = -512;
 constexpr int32_t kImm7Max = 504;
+constexpr uint32_t kImm7Mask = 0x7F;
 // Max value for imm12, will be shifted right by 3 when encoding.
 constexpr int32_t kImm12Max = 32760;
 constexpr uint32_t kUint12Max = 4095;
@@ -37,7 +38,9 @@
 
 inline uint32_t rd(VRegister vn) { return vn.code; }
 inline uint32_t rd(XRegister xn) { return xn.code; }
+inline uint32_t rt(QRegister qn) { return qn.code; }
 inline uint32_t rt(VRegister vn) { return vn.code; }
+inline uint32_t rt2(QRegister qn) { return qn.code << 10; }
 inline uint32_t rm(XRegister xn) { return xn.code << 16; }
 inline uint32_t rm(VRegister vn) { return vn.code << 16; }
 inline uint32_t rm(VRegisterLane vn) { return vn.code << 16; }
@@ -192,6 +195,14 @@
   }
 }
 
+inline bool imm7_offset_valid(int32_t imm, XRegister xn) {
+  return imm >= kImm7Min && imm <= kImm7Max && (imm & 0x7) == 0;
+}
+
+inline bool imm7_offset_valid(int32_t imm, QRegister qn) {
+  return imm >= (kImm7Min * 2) && imm <= (kImm7Max * 2) && (imm & 0xF) == 0;
+}
+
 // Base instructions.
 
 Assembler& Assembler::add(XRegister xd, XRegister xn, uint16_t imm12) {
@@ -209,13 +220,13 @@
 }
 
 Assembler& Assembler::ldp(XRegister xt1, XRegister xt2, MemOperand xn) {
-  if (xn.offset < kImm7Min || xn.offset > kImm7Max || std::abs(xn.offset) % 8 != 0) {
+  if (!imm7_offset_valid(xn.offset, xt1)) {
     error_ = Error::kInvalidOperand;
     return *this;
   }
 
   const uint32_t mode = xn.mode == AddressingMode::kOffset ? 2 : 1;
-  const uint32_t offset = (xn.offset >> 3) & 0x7F;
+  const uint32_t offset = (xn.offset >> 3) & kImm7Mask;
 
   return emit32(0xA8400000 | mode << 23 | offset << 15 | xt2.code << 10 | rn(xn.base) | xt1.code);
 }
@@ -342,13 +353,13 @@
 }
 
 Assembler& Assembler::ldp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
-  if (imm < -1024 || imm > 1008 || (imm & 0xF) != 0) {
+  if (!imm7_offset_valid(imm, qt1)) {
     error_ = Error::kInvalidOperand;
     return *this;
   }
-  const uint32_t offset = (imm >> 4) & 0x7F;
+  const uint32_t offset = (imm >> 4) & kImm7Mask;
 
-  return emit32(0xACC00000 | offset << 15 | qt2.code << 10 | rn(xn.base) | qt1.code);
+  return emit32(0xACC00000 | offset << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
 }
 
 Assembler& Assembler::ldr(QRegister qt, MemOperand xn, int32_t imm) {
@@ -397,6 +408,16 @@
   return emit32(0x0C800000 | q(vt) | rm(xm) | opcode << 12 | size(vt) | rn(xn.base) | rt(vt));
 }
 
+Assembler& Assembler::stp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
+  if (!imm7_offset_valid(imm, qt1)) {
+    error_ = Error::kInvalidOperand;
+    return *this;
+  }
+
+  const uint32_t offset = (imm >> 4) & kImm7Mask;
+  return emit32(0xAC800000 | offset << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
+}
+
 Assembler& Assembler::emit32(uint32_t value) {
   if (error_ != Error::kNoError) {
     return *this;
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index e283631..7a91be3 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -268,6 +268,7 @@
   Assembler& ldr(QRegister qt, MemOperand xn, int32_t imm);
   Assembler& movi(VRegister vd, uint8_t imm);
   Assembler& st1(VRegisterList vs, MemOperand xn, XRegister xm);
+  Assembler& stp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm);
 
   // Binds Label l to the current location in the code buffer.
   Assembler& bind(Label& l);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index 05c5f63..3b7ebd3 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -134,6 +134,13 @@
   EXPECT_ERROR(Error::kInvalidOperand, a.st1({v20.v2d(), v21.v2d(), v22.v2d(), v23.v2s()}, mem[x29], x1));
   EXPECT_ERROR(Error::kInvalidOperand, a.st1({v20.v2d(), v21.v2d(), v22.v2d(), v27.v2d()}, mem[x29], x1));
 
+  CHECK_ENCODING(0xAC8144D0, a.stp(q16, q17, mem[x6], 32));
+  CHECK_ENCODING(0xAC9FC4D0, a.stp(q16, q17, mem[x6], 1008));
+  CHECK_ENCODING(0xACA044D0, a.stp(q16, q17, mem[x6], -1024));
+  EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], 34));
+  EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], 1024));
+  EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], -1040));
+
   ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&b));
 }
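
Aside (not part of the change): the imm7 limits checked above follow from the
offset being a signed 7-bit field scaled by the access size S, so valid
offsets are multiples of S in [-64*S, 63*S]. A compile-time sketch under that
assumption; the name imm7_in_range and the size parameter are illustrative
only, standing in for the imm7_offset_valid overloads in the diff.

#include <cstdint>

constexpr bool imm7_in_range(int32_t imm, int32_t size) {
  // Signed 7-bit immediate scaled by the access size: multiples of size in [-64*size, 63*size].
  return imm >= -64 * size && imm <= 63 * size && imm % size == 0;
}

// XRegister pairs move 8-byte elements: [-512, 504], i.e. kImm7Min/kImm7Max.
static_assert(imm7_in_range(-512, 8) && imm7_in_range(504, 8), "x-pair range");
// QRegister pairs move 16-byte elements: [-1024, 1008], i.e. the *2 variants.
static_assert(imm7_in_range(-1024, 16) && imm7_in_range(1008, 16), "q-pair range");
// The offsets rejected in the tests are misaligned or out of range.
static_assert(!imm7_in_range(34, 16) && !imm7_in_range(1024, 16) && !imm7_in_range(-1040, 16), "invalid offsets");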