Implement fadd (vector) for aarch64 assembler
PiperOrigin-RevId: 423404270
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index cae664f..757585a 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -36,10 +36,14 @@
kTbnz,
};
+inline uint32_t rd(VRegister vd) { return vd.code; }
+inline uint32_t rm(VRegister vm) { return vm.code << 16; }
+inline uint32_t rm(VRegisterLane vm) { return vm.code << 16; }
inline uint32_t rn(XRegister xn) { return xn.code << 5; }
inline uint32_t rn(VRegister vn) { return vn.code << 5; }
inline uint32_t q(VRegister vt) { return vt.q << 30; }
inline uint32_t size(VRegister vt) { return vt.size << 10; }
+inline uint32_t fp_sz(VRegister vn) { return vn.is_s() ? 0 : 1 << 22; }
inline bool is_same_shape(VRegister vt1, VRegister vt2) {
return vt1.size == vt2.size && vt1.q == vt2.q;
@@ -125,9 +129,9 @@
}
}
-inline uint32_t encode_hl(VRegisterLane vl) {
+inline uint32_t hl(VRegisterLane vl) {
if (vl.is_s()) {
- return (vl.lane & 1) << 21 | ((vl.lane & 2) >> 1 << 11);
+ return (vl.lane & 1) << 21 | ((vl.lane & 2) << 10);
} else {
return (vl.lane & 1) << 11;
}
@@ -226,6 +230,15 @@
// SIMD instructions.
+Assembler& Assembler::fadd(VRegister vd, VRegister vn, VRegister vm) {
+  if (!is_same_shape(vd, vn) || !is_same_shape(vn, vm)) {
+ error_ = Error::kInvalidOperand;
+ return *this;
+ }
+
+ return emit32(0x0E20D400 | q(vd) | fp_sz(vn) | rm(vm) | rn(vn) | rd(vd));
+}
+
Assembler& Assembler::fmla(VRegister vd, VRegister vn, VRegisterLane vm) {
if (!is_same_shape(vd, vn) || !is_same_data_type(vd, vm)) {
error_ = Error::kInvalidOperand;
@@ -236,8 +249,7 @@
return *this;
}
- uint32_t sz = vm.is_s() ? 0 : 1;
- return emit32(0x0F801000 | q(vd) | sz << 22 | encode_hl(vm) | vm.code << 16 | rn(vn) | vd.code);
+ return emit32(0x0F801000 | q(vd) | fp_sz(vd) | hl(vm) | rm(vm) | rn(vn) | rd(vd));
}
Assembler& Assembler::ld1(VRegisterList vs, MemOperand xn, int32_t imm) {
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index 439ea2e..355fb27 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -79,6 +79,8 @@
VRegister v2d() const { return {code, 3, 1}; }
ScalarVRegister s() const { return {code, 2}; }
+
+  bool is_s() const { return size == 2; }
};
constexpr VRegister v0{0};
@@ -244,6 +246,7 @@
Assembler& tbnz(XRegister xd, uint8_t bit, Label& l);
// SIMD instructions
+ Assembler& fadd(VRegister vd, VRegister vn, VRegister vm);
Assembler& fmla(VRegister vd, VRegister vn, VRegisterLane vm);
Assembler& ld1(VRegisterList vs, MemOperand xn, int32_t imm);
Assembler& ld2r(VRegisterList xs, MemOperand xn);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index 2a0b1f1..affdd83 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -59,6 +59,9 @@
xnn_allocate_code_memory(&b, XNN_DEFAULT_CODE_BUFFER_SIZE);
Assembler a(&b);
+ CHECK_ENCODING(0x4E25D690, a.fadd(v16.v4s(), v20.v4s(), v5.v4s()));
+ EXPECT_ERROR(Error::kInvalidOperand, a.fadd(v16.v4s(), v20.v4s(), v5.v2s()));
+
CHECK_ENCODING(0x4F801290, a.fmla(v16.v4s(), v20.v4s(), v0.s()[0]));
EXPECT_ERROR(Error::kInvalidOperand, a.fmla(v16.v4s(), v20.v2s(), v0.s()[0]));
EXPECT_ERROR(Error::kInvalidOperand, a.fmla(v16.v2d(), v20.v2d(), v0.s()[0]));