Implement ldp (D registers) with offset, pre-index, and post-index addressing for the aarch64 assembler
PiperOrigin-RevId: 424203062
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index 73d3fdf..f0b50d3 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -46,6 +46,8 @@
inline uint32_t q(VRegister vt) { return vt.q << 30; }
inline uint32_t size(VRegister vt) { return vt.size << 10; }
inline uint32_t fp_sz(VRegister vn) { return vn.is_s() ? 0 : 1 << 22; }
+// Bit 24 of the ldp encodings, derived from the operand's addressing mode.
+// Despite the name, the bit is left clear for post-index operands and is set
+// for every other mode.
+inline uint32_t postindex(MemOperand op) {
+  const bool is_post_index = op.mode == AddressingMode::kPostIndex;
+  return is_post_index ? 0 : 1 << 24;
+}
+// Bit 23 of the ldp/stp encodings ("wb" presumably stands for write-back):
+// clear for plain offset operands, set for any other addressing mode.
+inline uint32_t wb(MemOperand op) {
+  if (op.mode == AddressingMode::kOffset) {
+    return 0;
+  }
+  return 1 << 23;
+}
inline bool is_same_shape(VRegister vt1, VRegister vt2) {
return vt1.size == vt2.size && vt1.q == vt2.q;
@@ -242,10 +244,9 @@
return *this;
}
- const uint32_t mode = xn.mode == AddressingMode::kOffset ? 2 : 1;
const uint32_t offset = (xn.offset >> 3) & kImm7Mask;
- return emit32(0xA8400000 | mode << 23 | offset << 15 | xt2.code << 10 | rn(xn.base) | xt1.code);
+ return emit32(0xA8400000 | postindex(xn) | wb(xn) | offset << 15 | rt2(xt2) | rn(xn.base) | xt1.code);
}
Assembler& Assembler::ldp(XRegister xt1, XRegister xt2, MemOperand xn, int32_t imm) {
@@ -391,6 +392,20 @@
return emit32(0x0D60C000 | q(xs.vt1) | size(xs.vt1) | rn(xn.base) | xs.vt1.code);
}
+// Emits an ldp of two D registers from the address described by xn; the
+// addressing-mode bits come from xn.mode via postindex() and wb().
+// If imm7_offset_valid() rejects xn.offset, nothing is emitted and error_ is
+// set to Error::kInvalidOperand.
+Assembler& Assembler::ldp(DRegister dt1, DRegister dt2, MemOperand xn) {
+  if (imm7_offset_valid(xn.offset, dt1)) {
+    // The byte offset is shifted right by 3 and masked with kImm7Mask before
+    // being placed at bit 15 of the instruction word.
+    const uint32_t imm7 = (xn.offset >> 3) & kImm7Mask;
+    const uint32_t addressing_bits = postindex(xn) | wb(xn);
+    return emit32(0x6C400000 | addressing_bits | imm7 << 15 | rt2(dt2) | rn(xn.base) | rt(dt1));
+  }
+
+  error_ = Error::kInvalidOperand;
+  return *this;
+}
+
+// Post-index convenience overload: loads dt1/dt2 through xn.base with imm as
+// the post-index offset. Only xn.base is read; xn.offset and xn.mode are
+// replaced by imm and AddressingMode::kPostIndex.
+Assembler& Assembler::ldp(DRegister dt1, DRegister dt2, MemOperand xn, int32_t imm) {
+  const MemOperand post_indexed = {xn.base, imm, AddressingMode::kPostIndex};
+  return ldp(dt1, dt2, post_indexed);
+}
+
Assembler& Assembler::ldp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
if (!imm7_offset_valid(imm, qt1)) {
error_ = Error::kInvalidOperand;
@@ -464,9 +479,8 @@
return *this;
}
- const uint32_t wb = (xn.mode == AddressingMode::kPreIndex ? 1 : 0) << 23;
const uint32_t offset = (xn.offset >> 3) & kImm7Mask;
- return emit32(0x6D000000 | wb | offset << 15 | rt2(dt2) | rn(xn.base) | rt(dt1));
+ return emit32(0x6D000000 | wb(xn) | offset << 15 | rt2(dt2) | rn(xn.base) | rt(dt1));
}
Assembler& Assembler::stp(QRegister qt1, QRegister qt2, MemOperand xn) {
@@ -475,9 +489,8 @@
return *this;
}
- const uint32_t wb = (xn.mode == AddressingMode::kPreIndex ? 1 : 0) << 23;
const uint32_t offset = (xn.offset >> 4) & kImm7Mask;
- return emit32(0xAD000000 | wb | offset << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
+ return emit32(0xAD000000 | wb(xn) | offset << 15 | rt2(qt2) | rn(xn.base) | rt(qt1));
}
Assembler& Assembler::stp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm) {
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index 14d287c..7cf7a49 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -352,6 +352,8 @@
Assembler& fmla(VRegister vd, VRegister vn, VRegisterLane vm);
Assembler& ld1(VRegisterList vs, MemOperand xn, int32_t imm);
Assembler& ld2r(VRegisterList xs, MemOperand xn);
+ Assembler& ldp(DRegister dt1, DRegister dt2, MemOperand xn);  // Load pair of D registers; addressing mode is taken from xn.mode.
+ Assembler& ldp(DRegister dt1, DRegister dt2, MemOperand xn, int32_t imm);  // Post-index form: uses xn.base with imm as the offset.
Assembler& ldp(QRegister qt1, QRegister qt2, MemOperand xn, int32_t imm);
Assembler& ldr(DRegister dt, MemOperand xn, int32_t imm);
Assembler& ldr(QRegister qt, MemOperand xn, int32_t imm);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index 957751e..fbf8332 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -112,6 +112,11 @@
EXPECT_ERROR(Error::kInvalidOperand, a.ld1({v16.v16b(), v17.v16b(), v18.v16b()}, mem[x15], 24));
EXPECT_ERROR(Error::kInvalidOperand, a.ld1({v16.v8b(), v17.v8b(), v18.v8b()}, mem[x15], 48));
+ CHECK_ENCODING(0x6D433FEE, a.ldp(d14, d15, mem[sp, 48]));
+ CHECK_ENCODING(0x6DC33FEE, a.ldp(d14, d15, mem[sp, 48]++));
+ CHECK_ENCODING(0x6CC427E8, a.ldp(d8, d9, mem[sp], 64));
+ EXPECT_ERROR(Error::kInvalidOperand, a.ldp(d14, d15, mem[sp, 7]));
+
CHECK_ENCODING(0xACC154B4, a.ldp(q20, q21, mem[x5], 32));
CHECK_ENCODING(0xACE054B4, a.ldp(q20, q21, mem[x5], -1024));
CHECK_ENCODING(0xACDFD4B4, a.ldp(q20, q21, mem[x5], 1008));