Implement stp (post-indexed) for aarch64 assembler
PiperOrigin-RevId: 423432717
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index 46c1b34..9e2dcd3 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -13,6 +13,7 @@
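-// Min and max values for the imm7 for ldp, will be shifted right by 3 when encoding.
+// Min and max byte offsets for the imm7 of X-register ldp; the offset is shifted right by 3
+// when encoding. Q-register ldp/stp accept twice this range and shift right by 4 instead.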
constexpr int32_t kImm7Min = -512;
constexpr int32_t kImm7Max = 504;
+constexpr uint32_t kImm7Mask = 0x7F;  // Keeps the low 7 bits of a scaled offset (the imm7 value).
// Max value for imm12, will be shifted right by 3 when encoding.
constexpr int32_t kImm12Max = 32760;
constexpr uint32_t kUint12Max = 4095;
@@ -37,7 +38,9 @@
inline uint32_t rd(VRegister vn) { return vn.code; }
inline uint32_t rd(XRegister xn) { return xn.code; }
+inline uint32_t rt(QRegister qn) { return qn.code; }  // Register code, unshifted.
inline uint32_t rt(VRegister vn) { return vn.code; }
+inline uint32_t rt2(QRegister qn) { return qn.code << 10; }  // Register code placed at bit 10.
inline uint32_t rm(XRegister xn) { return xn.code << 16; }
inline uint32_t rm(VRegister vn) { return vn.code << 16; }
inline uint32_t rm(VRegisterLane vn) { return vn.code << 16; }
@@ -192,6 +195,14 @@
}
}
+// Returns true when imm is a byte offset usable with an X-register pair instruction:
+// it lies within [kImm7Min, kImm7Max] and is a multiple of 8. The register argument is
+// left unnamed because it is never read; it only selects this overload.
+inline bool imm7_offset_valid(int32_t imm, XRegister /*xn*/) {
+  const bool in_range = imm >= kImm7Min && imm <= kImm7Max;
+  const bool aligned = (imm & 0x7) == 0;
+  return in_range && aligned;
+}
+
+// Returns true when imm is a byte offset usable with a Q-register pair instruction:
+// it lies within [kImm7Min * 2, kImm7Max * 2] and is a multiple of 16. The register
+// argument is left unnamed because it is never read; it only selects this overload.
+inline bool imm7_offset_valid(int32_t imm, QRegister /*qn*/) {
+  const bool in_range = imm >= (kImm7Min * 2) && imm <= (kImm7Max * 2);
+  const bool aligned = (imm & 0xF) == 0;
+  return in_range && aligned;
+}
+
// Base instructions.
Assembler& Assembler::add(XRegister xd, XRegister xn, uint16_t imm12) {
@@ -209,13 +220,13 @@
}
Assembler& Assembler::ldp(XRegister xt1, XRegister xt2, MemOperand xn) {
- if (xn.offset < kImm7Min || xn.offset > kImm7Max || std::abs(xn.offset) % 8 != 0) {
+ if (!imm7_offset_valid(xn.offset, xt1)) {  // xt1 only selects the X-register overload of the check.
error_ = Error::kInvalidOperand;
return *this;
}
const uint32_t mode = xn.mode == AddressingMode::kOffset ? 2 : 1;
- const uint32_t offset = (xn.offset >> 3) & 0x7F;
+ const uint32_t offset = (xn.offset >> 3) & kImm7Mask;  // Byte offset divided by 8, kept to 7 bits.
return emit32(0xA8400000 | mode << 23 | offset << 15 | xt2.code << 10 | rn(xn.base) | xt1.code);
}
@@ -342,13 +353,13 @@
}
+// Encodes a Q-register pair load addressed through xn.base with byte offset imm.
+// imm must be a multiple of 16 within [kImm7Min * 2, kImm7Max * 2]; otherwise error_ is set
+// to Error::kInvalidOperand and nothing is emitted. Only xn.base is read from the memory
+// operand; the offset comes solely from imm.
Assembler& Assembler::ldp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
- if (imm < -1024 || imm > 1008 || (imm & 0xF) != 0) {
+ if (!imm7_offset_valid(imm, qt1)) {
error_ = Error::kInvalidOperand;
return *this;
}
- const uint32_t offset = (imm >> 4) & 0x7F;
+ const uint32_t offset = (imm >> 4) & kImm7Mask;
- return emit32(0xACC00000 | offset << 15 | qt2.code << 10 | rn(xn.base) | qt1.code);
+ return emit32(0xACC00000 | offset << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
}
Assembler& Assembler::ldr(QRegister qt, MemOperand xn, int32_t imm) {
@@ -397,6 +408,16 @@
return emit32(0x0C800000 | q(vt) | rm(xm) | opcode << 12 | size(vt) | rn(xn.base) | rt(vt));
}
+// Encodes a post-indexed store of the Q-register pair (qt1, qt2) through xn.base.
+// imm is the byte offset; it must pass imm7_offset_valid for Q registers, otherwise
+// error_ becomes Error::kInvalidOperand and no instruction is emitted. Only xn.base is
+// read from the memory operand.
+Assembler& Assembler::stp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
+  if (imm7_offset_valid(imm, qt1)) {
+    // The byte offset is stored divided by 16 and truncated to the 7-bit field.
+    const uint32_t scaled_imm7 = (imm >> 4) & kImm7Mask;
+    return emit32(0xAC800000 | scaled_imm7 << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
+  }
+
+  error_ = Error::kInvalidOperand;
+  return *this;
+}
+
+
Assembler& Assembler::emit32(uint32_t value) {
if (error_ != Error::kNoError) {
return *this;
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index e283631..7a91be3 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -268,6 +268,7 @@
Assembler& ldr(QRegister qt, MemOperand xn, int32_t imm);
Assembler& movi(VRegister vd, uint8_t imm);
Assembler& st1(VRegisterList vs, MemOperand xn, XRegister xm);
+ Assembler& stp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm);  // Post-indexed; imm is a multiple of 16 in [-1024, 1008].
// Binds Label l to the current location in the code buffer.
Assembler& bind(Label& l);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index 05c5f63..3b7ebd3 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -134,6 +134,13 @@
EXPECT_ERROR(Error::kInvalidOperand, a.st1({v20.v2d(), v21.v2d(), v22.v2d(), v23.v2s()}, mem[x29], x1));
EXPECT_ERROR(Error::kInvalidOperand, a.st1({v20.v2d(), v21.v2d(), v22.v2d(), v27.v2d()}, mem[x29], x1));
+ CHECK_ENCODING(0xAC8144D0, a.stp(q16, q17, mem[x6], 32));
+ CHECK_ENCODING(0xAC9FC4D0, a.stp(q16, q17, mem[x6], 1008));
+ CHECK_ENCODING(0xACA044D0, a.stp(q16, q17, mem[x6], -1024));
+ EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], 34));
+ EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], 1024));
+ EXPECT_ERROR(Error::kInvalidOperand, a.stp(q16, q17, mem[x6], -1040));
+
ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&b));
}