blob: 830200fceff981bc5ec8ce38bb6d40ed6427b6f9 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <xnnpack/allocator.h>
#include <xnnpack/assembler.h>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
namespace xnnpack {
namespace aarch32 {
// Special floating-point system registers, readable via VMRS (see
// Assembler::vmrs below).
enum class SpecialFPRegister {
  // Floating-point Status and Control Register.
  kFPSCR = 1,
};
constexpr SpecialFPRegister FPSCR = SpecialFPRegister::kFPSCR;
// One of the 16 AArch32 core (general-purpose) registers, r0-r15.
struct CoreRegister {
  uint8_t code;  // Register number, 0-15.
};
// The 16 core registers.
constexpr CoreRegister r0{0};
constexpr CoreRegister r1{1};
constexpr CoreRegister r2{2};
constexpr CoreRegister r3{3};
constexpr CoreRegister r4{4};
constexpr CoreRegister r5{5};
constexpr CoreRegister r6{6};
constexpr CoreRegister r7{7};
constexpr CoreRegister r8{8};
constexpr CoreRegister r9{9};
constexpr CoreRegister r10{10};
constexpr CoreRegister r11{11};
constexpr CoreRegister r12{12};
constexpr CoreRegister r13{13};
constexpr CoreRegister r14{14};
constexpr CoreRegister r15{15};
// Aliases for registers with dedicated roles.
constexpr CoreRegister sp = r13;  // Stack pointer.
constexpr CoreRegister lr = r14;  // Link register.
constexpr CoreRegister pc = r15;  // Program counter.
// In VMRS, encoding the Rt field as 15 transfers the FP status flags to the
// APSR, hence this alias of r15.
constexpr CoreRegister APSR_nzcv = r15;
// Two core registers compare equal iff they name the same register number.
static inline bool operator==(const CoreRegister lhs, const CoreRegister rhs) {
  const bool same_code = (lhs.code == rhs.code);
  return same_code;
}
// A set of core registers, held as the bitmask form used by PUSH/POP-style
// encodings.
struct CoreRegisterList {
  CoreRegisterList(std::initializer_list<CoreRegister> rs) {
    for (auto r : rs) {
      list |= 1 << r.code;
    }
  }
  // True iff at least two registers are in the list. Clearing the lowest set
  // bit (list & (list - 1)) leaves a non-zero value exactly when more than
  // one bit was set. Const: pure accessor, does not modify the list.
  bool has_more_than_one_register() const { return (list & (list - 1)) != 0; }
  // Bit i is set if CoreRegister i is in the list.
  uint16_t list = 0;
};
// Compares a raw bitmask against a register list's encoding; used by tests
// that check expected PUSH/POP masks.
static inline bool operator==(int i, CoreRegisterList registers) {
  const uint16_t encoded = registers.list;
  return encoded == i;
}
// VFP single-precision register, s0-s31.
struct SRegister {
  uint8_t code;
  // Encoding bits: S registers split the register number as Vd:D, so d() is
  // the low bit and vd() the upper four bits of the register number.
  uint8_t d() const { return code & 0x1; }
  uint8_t vd() const { return (code & 0x1e) >> 1; }
};
// Two S registers compare equal iff they name the same register number.
static inline bool operator==(const SRegister lhs, const SRegister rhs) {
  const bool same_code = (lhs.code == rhs.code);
  return same_code;
}
// The 32 single-precision VFP registers.
constexpr SRegister s0{0};
constexpr SRegister s1{1};
constexpr SRegister s2{2};
constexpr SRegister s3{3};
constexpr SRegister s4{4};
constexpr SRegister s5{5};
constexpr SRegister s6{6};
constexpr SRegister s7{7};
constexpr SRegister s8{8};
constexpr SRegister s9{9};
constexpr SRegister s10{10};
constexpr SRegister s11{11};
constexpr SRegister s12{12};
constexpr SRegister s13{13};
constexpr SRegister s14{14};
constexpr SRegister s15{15};
constexpr SRegister s16{16};
constexpr SRegister s17{17};
constexpr SRegister s18{18};
constexpr SRegister s19{19};
constexpr SRegister s20{20};
constexpr SRegister s21{21};
constexpr SRegister s22{22};
constexpr SRegister s23{23};
constexpr SRegister s24{24};
constexpr SRegister s25{25};
constexpr SRegister s26{26};
constexpr SRegister s27{27};
constexpr SRegister s28{28};
constexpr SRegister s29{29};
constexpr SRegister s30{30};
constexpr SRegister s31{31};
// Define DRegisterLane before DRegister so that we can have the operator[] overloading for nice syntax.
// A single lane of a D register, e.g. d1[0].
struct DRegisterLane {
  uint8_t code;  // D register number, 0-31.
  uint8_t lane;  // Lane index within the register.
  // Encoding bits: d() is the top bit of the register number, vd() the low
  // four bits (register number split as D:Vd).
  uint8_t d() const { return (code & 0x10) >> 4; }
  uint8_t vd() const { return code & 0xf; }
};
// Lanes compare equal iff both the register number and the lane index match.
static inline bool operator==(const DRegisterLane lhs, const DRegisterLane rhs) {
  const bool same_register = (lhs.code == rhs.code);
  const bool same_lane = (lhs.lane == rhs.lane);
  return same_register && same_lane;
}
// NEON/VFP double-precision (64-bit) register, d0-d31.
struct DRegister {
  uint8_t code;
  // Encoding bits: d() is the top bit of the register number, vd() the low
  // four bits (register number split as D:Vd).
  uint8_t d() const { return (code & 0x10) >> 4; }
  uint8_t vd() const { return code & 0xf; }
  // Lane access, e.g. d1[0]. Returns by value; the previous const-qualified
  // value return type served no purpose (readability-const-return-type).
  DRegisterLane operator[](std::size_t pos) const {
    return DRegisterLane{code, static_cast<uint8_t>(pos)};
  }
};
// Two D registers compare equal iff they name the same register number.
static inline bool operator==(const DRegister lhs, const DRegister rhs) {
  const bool same_code = (lhs.code == rhs.code);
  return same_code;
}
// The 32 double-precision NEON/VFP registers.
constexpr DRegister d0{0};
constexpr DRegister d1{1};
constexpr DRegister d2{2};
constexpr DRegister d3{3};
constexpr DRegister d4{4};
constexpr DRegister d5{5};
constexpr DRegister d6{6};
constexpr DRegister d7{7};
constexpr DRegister d8{8};
constexpr DRegister d9{9};
constexpr DRegister d10{10};
constexpr DRegister d11{11};
constexpr DRegister d12{12};
constexpr DRegister d13{13};
constexpr DRegister d14{14};
constexpr DRegister d15{15};
constexpr DRegister d16{16};
constexpr DRegister d17{17};
constexpr DRegister d18{18};
constexpr DRegister d19{19};
constexpr DRegister d20{20};
constexpr DRegister d21{21};
constexpr DRegister d22{22};
constexpr DRegister d23{23};
constexpr DRegister d24{24};
constexpr DRegister d25{25};
constexpr DRegister d26{26};
constexpr DRegister d27{27};
constexpr DRegister d28{28};
constexpr DRegister d29{29};
constexpr DRegister d30{30};
constexpr DRegister d31{31};
// NEON quad (128-bit) register, q0-q15. In instruction encodings a Q register
// is represented by the even-numbered D register it overlaps, i.e. code * 2.
struct QRegister {
  uint8_t code;
  // Encode code * 2.
  uint8_t d() const { return (code & 0x8) >> 3; }
  uint8_t vd() const { return (code & 0x7) << 1; }
};
// Two Q registers compare equal iff they name the same register number.
static inline bool operator==(const QRegister lhs, const QRegister rhs) {
  const bool same_code = (lhs.code == rhs.code);
  return same_code;
}
// The 16 quad NEON registers.
constexpr QRegister q0{0};
constexpr QRegister q1{1};
constexpr QRegister q2{2};
constexpr QRegister q3{3};
constexpr QRegister q4{4};
constexpr QRegister q5{5};
constexpr QRegister q6{6};
constexpr QRegister q7{7};
constexpr QRegister q8{8};
constexpr QRegister q9{9};
constexpr QRegister q10{10};
constexpr QRegister q11{11};
constexpr QRegister q12{12};
constexpr QRegister q13{13};
constexpr QRegister q14{14};
constexpr QRegister q15{15};
// SIMD register lists are used in a more restrictive way, compared to core
// registers, only consecutive registers are used as an operand to instruction.
template <typename RegType>
struct ConsecutiveRegisterList {
  // Inclusive range [s, end]. End must be >= start.
  ConsecutiveRegisterList(RegType s, RegType end)
      : start(s),
        length(end.code - s.code + 1) {}
  // `len` consecutive registers starting at `s`.
  explicit ConsecutiveRegisterList(RegType s, int len)
      : start(s),
        length(len) {}
  // Implicit conversion from a single register to a one-element list, so a
  // bare register can be passed where a list is expected.
  ConsecutiveRegisterList(RegType start)
      : ConsecutiveRegisterList(start, start) {}
  RegType start;
  uint8_t length;
};
// Consecutive-register lists for single- and double-precision registers.
using SRegisterList = ConsecutiveRegisterList<SRegister>;
using DRegisterList = ConsecutiveRegisterList<DRegister>;
// Lets `s0-s3` in C++ build the inclusive register range, mirroring the
// assembly list syntax {s0-s3}.
static inline SRegisterList operator-(const SRegister lhs, const SRegister rhs) {
  SRegisterList range(lhs, rhs);
  return range;
}
// Lets `d0-d3` in C++ build the inclusive register range, mirroring the
// assembly list syntax {d0-d3}.
static inline DRegisterList operator-(const DRegister lhs, const DRegister rhs) {
  DRegisterList range(lhs, rhs);
  return range;
}
// List of consecutive Q registers. Convertible to the equivalent
// DRegisterList, since each Q register overlaps two D registers.
struct QRegisterList {
  // Single-register list; implicit so a bare Q register can be passed.
  QRegisterList(QRegister s) : start(s), length(1) {}
  // Inclusive range [s, end]; end must be >= s.
  QRegisterList(QRegister s, QRegister end) : start(s), length(end.code - s.code + 1) {}
  // Explicit conversion to DRegisterList.
  explicit operator DRegisterList() const {
    return DRegisterList({static_cast<uint8_t>(start.code * 2)}, length * 2);
  }
  QRegister start;
  uint8_t length;
};
// Lets `q0-q3` in C++ build the inclusive register range, mirroring the
// assembly list syntax {q0-q3}.
static inline QRegisterList operator-(const QRegister lhs, const QRegister rhs) {
  QRegisterList range(lhs, rhs);
  return range;
}
// A8.5 Addressing modes for memory access.
enum class AddressingMode {
  // [<Rn>, <offset>], offset applied to address in Rn; Rn is not modified.
  kOffset,
  // Pre-indexed not used, so not implemented.
  // [<Rn>], <offset>, address from Rn, offset applied, written back to Rn.
  kPostIndexed,
};
// Memory operands, operands for memory access instructions. See
// "MemOperandHelper mem" for a nicer syntax that is closer to assembly.
class MemOperand {
 public:
  // Offset addressing: [<Rn>, <offset>].
  MemOperand(CoreRegister rn, int32_t offset)
      : mode_(AddressingMode::kOffset),
        rn_(rn),
        offset_(offset) {}
  // Explicitly chosen addressing mode.
  MemOperand(CoreRegister rn, int32_t offset, AddressingMode mode)
      : mode_(mode),
        rn_(rn),
        offset_(offset) {}

  CoreRegister base() const { return rn_; }
  int32_t offset() const { return offset_; }
  AddressingMode mode() const { return mode_; }

  // These are bits used for encoding, named based on the encoding description.
  // Declared const (they are pure accessors, consistent with base()/offset()/
  // mode() above).
  int32_t u() const { return offset_ >= 0; }
  int32_t p() const { return mode_ != AddressingMode::kPostIndexed; }
  // Note, kPostIndexed will write back, but doesn't need to set bit w.
  int32_t w() const { return 0; }

  // Overload postfix increment to indicate a post-indexed addressing mode for
  // load/stores. Deliberately returns the updated operand, not the
  // pre-increment value: it is DSL sugar, not arithmetic.
  MemOperand operator++(int) {
    mode_ = AddressingMode::kPostIndexed;
    return *this;
  }

 private:
  AddressingMode mode_;
  CoreRegister rn_;
  int32_t offset_;
};
// Operands compare equal when mode, base register, and offset all match.
static inline bool operator==(const MemOperand lhs, const MemOperand rhs) {
  if (lhs.mode() != rhs.mode()) {
    return false;
  }
  if (!(lhs.base() == rhs.base())) {
    return false;
  }
  return lhs.offset() == rhs.offset();
}
// Comma operator overload so that `mem[rn, offset]` parses: inside the
// subscript, `rn, offset` invokes this and yields an offset-mode operand.
static inline MemOperand operator,(CoreRegister r, int32_t offset) {
  MemOperand operand(r, offset);
  return operand;
}
// Helper struct for some syntax sugar to look like native assembly, see mem.
struct MemOperandHelper {
  // Pass-through for a fully formed operand, e.g. mem[rn, offset]. Returns by
  // value; the const qualifier on the return type was meaningless and
  // inconsistent with the overload below (readability-const-return-type).
  MemOperand operator[](MemOperand op) const { return op; }
  // Base register only, zero offset, e.g. mem[rn].
  MemOperand operator[](CoreRegister r) const { return MemOperand(r, 0); }
};
// Use "mem" (and its overload of array subscript operator) to get some syntax
// that looks closer to native assembly when accessing memory. For example:
// - ldr(r0, mem[rn, offset]); // offset
// - ldr(r0, mem[rn], offset); // post-indexed
constexpr MemOperandHelper mem;
// Conditional execution, only support AL (always) for now.
// Values occupy the top nibble (the cond field, bits 31:28) of the
// instruction encoding, so they can be OR-ed directly into an opcode.
// Hex digit casing normalized to uppercase for consistency.
enum Condition : uint32_t {
  kEQ = 0x00000000,
  kNE = 0x10000000,
  kCS = 0x20000000,
  kCC = 0x30000000,
  kMI = 0x40000000,
  kPL = 0x50000000,
  kVS = 0x60000000,
  kVC = 0x70000000,
  kHI = 0x80000000,
  kLS = 0x90000000,
  kGE = 0xA0000000,
  kLT = 0xB0000000,
  kGT = 0xC0000000,
  kLE = 0xD0000000,
  kAL = 0xE0000000,
  // Aliases: HS (unsigned higher-or-same) is CS, LO (unsigned lower) is CC.
  kHS = kCS,
  kLO = kCC,
};
// Element size for SIMD loads/stores/duplicates. Values are log2 of the
// element width in bytes (the 2-bit size field used in encodings).
enum DataSize {
  k8 = 0,
  k16 = 1,
  k32 = 2,
};
// A simple AArch32 assembler.
class Assembler : public AssemblerBase {
 public:
  using AssemblerBase::AssemblerBase;

  // Core (scalar) instructions.
  void add(CoreRegister rn, CoreRegister rm) { add(rn, rn, rm); }
  void add(CoreRegister rd, CoreRegister rn, CoreRegister rm);
  // Only support uint8_t immediates for now, it simplifies encoding.
  void add(CoreRegister rd, CoreRegister rn, uint8_t imm);
  void adds(CoreRegister rd, CoreRegister rn, uint8_t imm);
  // Trailing underscore because `and` is a reserved alternative token in C++.
  void and_(CoreRegister rd, CoreRegister rn, uint8_t imm);
  // Branches; the target Label is attached with bind().
  void b(Label& l) { b(kAL, l); }
  void beq(Label& l) { b(kEQ, l); }
  void bne(Label& l) { b(kNE, l); }
  void bhi(Label& l) { b(kHI, l); }
  void bhs(Label& l) { b(kHS, l); }
  void blo(Label& l) { b(kLO, l); }
  void bic(CoreRegister rd, CoreRegister rn, uint8_t imm);
  void bx(CoreRegister rm);
  // Cmp supports a subset of uint32_t offsets, see "A5.2.4 Modified immediate
  // constants in ARM instructions", for simplicity we start with uint8_t, which
  // is fully representable using a "rotation" of 0.
  void cmp(CoreRegister rn, uint8_t imm);
  void cmp(CoreRegister rn, CoreRegister rm);
  void ldr(CoreRegister rt, MemOperand operand, int32_t offset);
  void ldr(CoreRegister rt, MemOperand operand);
  // LDRD <Rt>, <Rt2>, [<Rn>{, #+/-<imm>}].
  void ldrd(CoreRegister rt, CoreRegister rt2, MemOperand op);
  void mov(CoreRegister rd, CoreRegister rm);
  void moveq(CoreRegister rd, CoreRegister rm) { mov(kEQ, rd, rm); }
  void movlo(CoreRegister rd, CoreRegister rm) { mov(kLO, rd, rm); }
  void movls(CoreRegister rd, CoreRegister rm) { mov(kLS, rd, rm); }
  void nop();
  void pld(MemOperand operand);
  void pop(CoreRegisterList regs);
  void push(CoreRegisterList regs);
  void str(CoreRegister rt, MemOperand op);
  void sub(CoreRegister rd, CoreRegister rn, uint8_t imm);
  void sub(CoreRegister rd, CoreRegister rn, CoreRegister rm);
  // Only support uint8_t immediates for now, it simplifies encoding.
  void subs(CoreRegister rd, CoreRegister rn, uint8_t imm);
  void tst(CoreRegister rn, uint8_t imm);

  // SIMD instructions.
  void vcmpe_f32(SRegister sd, SRegister sm);
  void vcvt_f32_s32(QRegister qd, QRegister qm);
  void vcvt_s32_f32(QRegister qd, QRegister qm);
  void vcvtn_s32_f32(QRegister qd, QRegister qm);
  void vdup_8(QRegister qd, DRegisterLane dm) { vdup(k8, qd, dm); }
  void vdup_16(QRegister qd, DRegisterLane dm) { vdup(k16, qd, dm); }
  void vdup_32(QRegister qd, DRegisterLane dm) { vdup(k32, qd, dm); }
  void vext_8(QRegister qd, QRegister qn, QRegister qm, uint8_t imm4);
  // VLD1.8 <list>, [<Rn>]{!} (multiple single elements).
  void vld1_8(DRegisterList regs, MemOperand op) { vld1(k8, regs, op); }
  void vld1_8(DRegisterList regs, MemOperand op, CoreRegister rm) { vld1(k8, regs, op, rm); }
  void vld1_8(QRegisterList regs, MemOperand op) { vld1(k8, static_cast<DRegisterList>(regs), op); }
  // VLD1.32 <list>, [<Rn>]{!} (multiple single elements).
  void vld1_32(DRegisterList regs, MemOperand op) { vld1(k32, regs, op); }
  void vld1_32(QRegisterList regs, MemOperand op) { vld1(k32, static_cast<DRegisterList>(regs), op); }
  // VLD1.32 <list>, [<Rn>]{!} (single element to one lane).
  void vld1_32(DRegisterLane dd, MemOperand op);
  // VLD1.32 <list>, [<Rn>]{!} (single element to all lanes).
  // We cannot differentiate the register list in C++ syntax, so use an instruction name similar to AArch64 LD1R.
  void vld1r_32(DRegisterList regs, MemOperand op);
  // VLDM <Rn>{!}, <list> (IA).
  void vldm(MemOperand rn, SRegisterList regs);
  void vldm(MemOperand rn, DRegisterList regs);
  void vldr(SRegister sd, MemOperand op);
  void vldr(DRegister dd, MemOperand op);
  void vmax_f32(QRegister qd, QRegister qn, QRegister qm);
  void vmax_s8(QRegister qd, QRegister qn, QRegister qm);
  void vmin_f32(QRegister qd, QRegister qn, QRegister qm);
  void vmin_s8(QRegister qd, QRegister qn, QRegister qm);
  // VMLA.F32 <Sd>, <Sn>, <Sm>
  void vmla_f32(SRegister sd, SRegister sn, SRegister sm);
  // VMLA.F32 <Qd>, <Qn>, <Dm[x]>
  void vmla_f32(QRegister qd, QRegister qn, DRegisterLane dm);
  // VMLAL.S16 <Qd>, <Dn>, <Dm[x]>
  void vmlal_s16(QRegister qd, DRegister dn, DRegisterLane dm);
  // VMOV.F32 <Sd>, <Sm>; encoding A2.
  void vmov(SRegister sd, SRegister sm);
  // VMOV <Dm>, <Rt>, <Rt2>; encoding A1.
  void vmov(DRegister dm, CoreRegister rt, CoreRegister rt2);
  // VMOV <Dd>, <Dm>; encoding A1.
  void vmov(DRegister dd, DRegister dm);
  // VMOV <Qd>, <Qm>; encoding A1.
  void vmov(QRegister qd, QRegister qm);
  // VMOV_F32 <Sd>, <Sm>
  void vmov_f32(SRegister sd, SRegister sm) { vmov_f32(kAL, sd, sm); }
  void vmovpl_f32(SRegister sd, SRegister sm) { vmov_f32(kPL, sd, sm); }
  void vmovmi_f32(SRegister sd, SRegister sm) { vmov_f32(kMI, sd, sm); }
  // VMOV_F64 <Dd>, <Dm>
  void vmov_f64(DRegister dd, DRegister dm);
  // VMOVL.S8 <Qd>, <Dm>
  void vmovl_s8(QRegister qd, DRegister dm);
  void vmrs(CoreRegister rt, SpecialFPRegister spec_reg);
  void vmul_f32(QRegister qd, QRegister qn, QRegister qm);
  void vpop(DRegisterList regs);
  void vpush(DRegisterList regs);
  void vpush(SRegisterList regs);
  void vqadd_s16(QRegister qd, QRegister qn, QRegister qm);
  void vqdmulh_s32(QRegister qd, QRegister qn, DRegisterLane dm);
  void vqmovn_s16(DRegister dd, QRegister qm);
  void vqmovn_s32(DRegister dd, QRegister qm);
  void vqshl_s32(QRegister qd, QRegister qm, QRegister qn);
  void vrshl_s32(QRegister qd, QRegister qm, QRegister qn);
  void vsdot_s8(QRegister qd, QRegister qn, DRegisterLane dm);
  // VST1.8 <list>, [<Rn>]{!} (multiple single elements).
  void vst1_8(DRegisterList regs, MemOperand op) { vst1(k8, regs, op); }
  // VST1.8 <list>, [<Rn>]{!}, <Rm> (multiple single elements).
  void vst1_8(DRegisterList regs, MemOperand op, CoreRegister rm) { vst1(k8, regs, op, rm); }
  // VST1.8 <list>, [<Rn>]{!} (single element from one lane).
  void vst1_8(DRegisterLane dd, MemOperand op) { vst1(k8, dd, op); }
  // VST1.16 <list>, [<Rn>]{!} (multiple single elements).
  void vst1_16(DRegisterList regs, MemOperand op) { vst1(k16, regs, op); }
  // VST1.16 <list>, [<Rn>]{!}, <Rm> (multiple single elements).
  void vst1_16(DRegisterList regs, MemOperand op, CoreRegister rm) { vst1(k16, regs, op, rm); }
  // VST1.16 <list>, [<Rn>]{!} (single element from one lane).
  void vst1_16(DRegisterLane dd, MemOperand op) { vst1(k16, dd, op); }
  // VST1.32 <list>, [<Rn>]{!} (multiple single elements).
  void vst1_32(DRegisterList regs, MemOperand op) { vst1(k32, regs, op); }
  // VST1.32 <list>, [<Rn>]{!}, <Rm> (multiple single elements).
  void vst1_32(DRegisterList regs, MemOperand op, CoreRegister rm) { vst1(k32, regs, op, rm); }
  // VST1.32 <list>, [<Rn>]{!} (single element from one lane).
  void vst1_32(DRegisterLane dd, MemOperand op) { vst1(k32, dd, op); }
  // VSTM <Rn>{!}, <list>, consecutive 64-bit registers.
  void vstm(MemOperand rn, DRegisterList regs);
  // VSTR <Sd>, [Rn{, #+/-<imm>}], store single extension register to memory.
  void vstr(SRegister rn, MemOperand op);

  // Binds Label l to the current location in the code buffer.
  void bind(Label& l);
  // Align the cursor to specified number of bytes, `n` must be a power of 2.
  void align(uint8_t n);

 private:
  // Condition-taking implementations shared by the public mnemonics above.
  void mov(Condition c, CoreRegister rd, CoreRegister rm);
  void b(Condition c, Label& l);
  void vdup(DataSize size, QRegister qd, DRegisterLane dm);
  void vmov_f32(Condition c, SRegister sd, SRegister sm);
  void vld1(DataSize size, DRegisterList regs, MemOperand op);
  void vld1(DataSize size, DRegisterList regs, MemOperand op, CoreRegister rm);
  void vst1(DataSize size, DRegisterList regs, MemOperand op);
  void vst1(DataSize size, DRegisterList regs, MemOperand op, CoreRegister rm);
  void vst1(DataSize size, DRegisterLane dd, MemOperand op);
};
} // namespace aarch32
} // namespace xnnpack