[RISCV] Add support for the various RISC-V FMA instruction variants
Adds support for the various RISC-V FMA instructions (fmadd, fmsub, fnmsub, fnmadd).
Which variant is selected (fused add or subtract, with or without a negated
product) depends on which of the arguments to the llvm.fma.* intrinsic are
negated. In the tests, extraneous fadd instructions were added so that the
negation is not performed using an xor trick, which would prevent the proper
FMA forms from being selected and thus tested.
The FMA instruction patterns might seem incorrect (e.g., fnmadd: -rs1 * rs2 -
rs3), but they should be correct. The misleading names were inherited from
MIPS, where the negation happens after computing the sum.
The llvm.fmuladd.* intrinsics still do not generate RISC-V FMA instructions,
as that depends on TargetLowering::isFMAFasterThanFMulAndFAdd.
Some comments in the test files describing which kinds of instructions are
tested there were updated to better reflect the current content of those test
files.
Differential Revision: https://reviews.llvm.org/D54205
Patch by Luís Marques.
llvm-svn: 349023
diff --git a/llvm/test/CodeGen/RISCV/float-arith.ll b/llvm/test/CodeGen/RISCV/float-arith.ll
index f3ec61b..ab87447 100644
--- a/llvm/test/CodeGen/RISCV/float-arith.ll
+++ b/llvm/test/CodeGen/RISCV/float-arith.ll
@@ -2,6 +2,10 @@
; RUN: llc -mtriple=riscv32 -mattr=+f -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=RV32IF %s
+; These tests are each targeted at a particular RISC-V FPU instruction. Most
+; other files in this folder exercise LLVM IR instructions that don't directly
+; match a RISC-V instruction.
+
define float @fadd_s(float %a, float %b) nounwind {
; RV32IF-LABEL: fadd_s:
; RV32IF: # %bb.0:
@@ -186,3 +190,78 @@
%2 = zext i1 %1 to i32
ret i32 %2
}
+
+declare float @llvm.fma.f32(float, float, float)
+
+define float @fmadd_s(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fmadd_s:
+; RV32IF: # %bb.0:
+; RV32IF-NEXT: fmv.w.x ft0, a2
+; RV32IF-NEXT: fmv.w.x ft1, a1
+; RV32IF-NEXT: fmv.w.x ft2, a0
+; RV32IF-NEXT: fmadd.s ft0, ft2, ft1, ft0
+; RV32IF-NEXT: fmv.x.w a0, ft0
+; RV32IF-NEXT: ret
+ %1 = call float @llvm.fma.f32(float %a, float %b, float %c) ; a * b + c with no negated operand; fmadd.s is expected
+ ret float %1
+}
+
+define float @fmsub_s(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fmsub_s:
+; RV32IF: # %bb.0:
+; RV32IF-NEXT: fmv.w.x ft0, a2
+; RV32IF-NEXT: lui a2, %hi(.LCPI15_0)
+; RV32IF-NEXT: addi a2, a2, %lo(.LCPI15_0)
+; RV32IF-NEXT: flw ft1, 0(a2)
+; RV32IF-NEXT: fadd.s ft0, ft0, ft1
+; RV32IF-NEXT: fmv.w.x ft1, a1
+; RV32IF-NEXT: fmv.w.x ft2, a0
+; RV32IF-NEXT: fmsub.s ft0, ft2, ft1, ft0
+; RV32IF-NEXT: fmv.x.w a0, ft0
+; RV32IF-NEXT: ret
+ %c_ = fadd float 0.0, %c ; extra fadd so the negation below is not done with an xor, which would hide the FMA form
+ %negc = fsub float -0.0, %c_ ; negc = -c_
+ %1 = call float @llvm.fma.f32(float %a, float %b, float %negc) ; a * b - c_: only the addend is negated; fmsub.s is expected
+ ret float %1
+}
+
+define float @fnmadd_s(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fnmadd_s:
+; RV32IF: # %bb.0:
+; RV32IF-NEXT: fmv.w.x ft0, a2
+; RV32IF-NEXT: lui a2, %hi(.LCPI16_0)
+; RV32IF-NEXT: addi a2, a2, %lo(.LCPI16_0)
+; RV32IF-NEXT: flw ft1, 0(a2)
+; RV32IF-NEXT: fadd.s ft0, ft0, ft1
+; RV32IF-NEXT: fmv.w.x ft2, a0
+; RV32IF-NEXT: fadd.s ft1, ft2, ft1
+; RV32IF-NEXT: fmv.w.x ft2, a1
+; RV32IF-NEXT: fnmadd.s ft0, ft1, ft2, ft0
+; RV32IF-NEXT: fmv.x.w a0, ft0
+; RV32IF-NEXT: ret
+ %a_ = fadd float 0.0, %a ; extra fadd so the negation of a is not done with an xor
+ %c_ = fadd float 0.0, %c ; extra fadd so the negation of c is not done with an xor
+ %nega = fsub float -0.0, %a_ ; nega = -a_
+ %negc = fsub float -0.0, %c_ ; negc = -c_
+ %1 = call float @llvm.fma.f32(float %nega, float %b, float %negc) ; -a_ * b - c_: multiplicand and addend negated; fnmadd.s is expected
+ ret float %1
+}
+
+define float @fnmsub_s(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fnmsub_s:
+; RV32IF: # %bb.0:
+; RV32IF-NEXT: fmv.w.x ft0, a0
+; RV32IF-NEXT: lui a0, %hi(.LCPI17_0)
+; RV32IF-NEXT: addi a0, a0, %lo(.LCPI17_0)
+; RV32IF-NEXT: flw ft1, 0(a0)
+; RV32IF-NEXT: fadd.s ft0, ft0, ft1
+; RV32IF-NEXT: fmv.w.x ft1, a2
+; RV32IF-NEXT: fmv.w.x ft2, a1
+; RV32IF-NEXT: fnmsub.s ft0, ft0, ft2, ft1
+; RV32IF-NEXT: fmv.x.w a0, ft0
+; RV32IF-NEXT: ret
+ %a_ = fadd float 0.0, %a ; extra fadd so the negation below is not done with an xor
+ %nega = fsub float -0.0, %a_ ; nega = -a_
+ %1 = call float @llvm.fma.f32(float %nega, float %b, float %c) ; -a_ * b + c: only the multiplicand is negated; fnmsub.s is expected
+ ret float %1
+}