[SLP] Support for horizontal min/max reduction.
The SLP vectorizer supports horizontal reductions only for Add/FAdd binary
operations. This patch adds support for horizontal min/max reductions.
The function getReductionCost() is split into getArithmeticReductionCost() for
binary operation reductions and getMinMaxReductionCost() for min/max
reductions.
This patch fixes PR26956.
Differential revision: https://reviews.llvm.org/D27846
llvm-svn: 312791
diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp
index 071e23e..47513f3 100644
--- a/llvm/lib/Analysis/CostModel.cpp
+++ b/llvm/lib/Analysis/CostModel.cpp
@@ -186,26 +186,56 @@
}
namespace {
+/// Kind of the reduction data.
+enum ReductionKind {
+ RK_None, ///< Not a reduction.
+ RK_Arithmetic, ///< Binary reduction data.
+ RK_MinMax, ///< Min/max reduction data.
+ RK_UnsignedMinMax, ///< Unsigned min/max reduction data.
+};
/// Contains opcode + LHS/RHS parts of the reduction operations.
struct ReductionData {
- explicit ReductionData() = default;
- ReductionData(unsigned Opcode, Value *LHS, Value *RHS)
- : Opcode(Opcode), LHS(LHS), RHS(RHS) {}
+ ReductionData() = delete;
+ /// \p IsMax is only meaningful for the min/max kinds: it records whether the
+ /// operation is a maximum (true) or a minimum (false), so that a tree that
+ /// mixes min and max levels is not accepted as one reduction.
+ ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS,
+ bool IsMax = false)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), IsMax(IsMax) {
+ assert(Kind != RK_None && "expected binary or min/max reduction only.");
+ }
unsigned Opcode = 0;
Value *LHS = nullptr;
Value *RHS = nullptr;
+ ReductionKind Kind = RK_None;
+ /// True for a max reduction; false for a min or an arithmetic reduction.
+ bool IsMax = false;
+ /// Returns true if \p RD has the same kind and opcode and, for min/max
+ /// reductions, the same min-vs-max direction.
+ bool hasSameData(const ReductionData &RD) const {
+ return Kind == RD.Kind && Opcode == RD.Opcode && IsMax == RD.IsMax;
+ }
};
} // namespace
+/// Classifies \p I as an arithmetic binary operation or a select-based min/max
+/// and returns its operands, or None if it is neither.
static Optional<ReductionData> getReductionData(Instruction *I) {
Value *L, *R;
if (m_BinOp(m_Value(L), m_Value(R)).match(I))
- return ReductionData(I->getOpcode(), L, R);
+ return ReductionData(RK_Arithmetic, I->getOpcode(), L, R);
+ if (auto *SI = dyn_cast<SelectInst>(I)) {
+ // Min and max are matched separately so the direction can be recorded;
+ // the compare opcode (ICmp/FCmp) separates integer from FP variants.
+ if (m_SMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMin(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMin(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_MinMax, CI->getOpcode(), L, R, /*IsMax=*/false);
+ }
+ if (m_SMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_OrdFMax(m_Value(L), m_Value(R)).match(SI) ||
+ m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_MinMax, CI->getOpcode(), L, R, /*IsMax=*/true);
+ }
+ if (m_UMin(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R,
+ /*IsMax=*/false);
+ }
+ if (m_UMax(m_Value(L), m_Value(R)).match(SI)) {
+ auto *CI = cast<CmpInst>(SI->getCondition());
+ return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R,
+ /*IsMax=*/true);
+ }
+ }
return llvm::None;
}
-static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
- unsigned NumLevels) {
+static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
+ unsigned Level,
+ unsigned NumLevels) {
// Match one level of pairwise operations.
// %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
// <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
@@ -213,24 +243,24 @@
// <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
// %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
if (!I)
- return false;
+ return RK_None;
assert(I->getType()->isVectorTy() && "Expecting a vector type");
Optional<ReductionData> RD = getReductionData(I);
if (!RD)
- return false;
+ return RK_None;
ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
if (!LS && Level)
- return false;
+ return RK_None;
ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
if (!RS && Level)
- return false;
+ return RK_None;
// On level 0 we can omit one shufflevector instruction.
if (!Level && !RS && !LS)
- return false;
+ return RK_None;
// Shuffle inputs must match.
Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
@@ -239,7 +269,7 @@
if (NextLevelOpR && NextLevelOpL) {
// If we have two shuffles their operands must match.
if (NextLevelOpL != NextLevelOpR)
- return false;
+ return RK_None;
NextLevelOp = NextLevelOpL;
} else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
@@ -250,45 +280,47 @@
// %NextLevelOpL = shufflevector %R, <1, undef ...>
// %BinOp = fadd %NextLevelOpL, %R
if (NextLevelOpL && NextLevelOpL != RD->RHS)
- return false;
+ return RK_None;
else if (NextLevelOpR && NextLevelOpR != RD->LHS)
- return false;
+ return RK_None;
NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
- } else
- return false;
+ } else {
+ return RK_None;
+ }
// Check that the next levels binary operation exists and matches with the
// current one.
if (Level + 1 != NumLevels) {
Optional<ReductionData> NextLevelRD =
getReductionData(cast<Instruction>(NextLevelOp));
- if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode)
- return false;
+ if (!NextLevelRD || !RD->hasSameData(*NextLevelRD))
+ return RK_None;
}
// Shuffle mask for pairwise operation must match.
if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) {
if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level))
- return false;
+ return RK_None;
} else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) {
if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level))
- return false;
- } else
- return false;
+ return RK_None;
+ } else {
+ return RK_None;
+ }
if (++Level == NumLevels)
- return true;
+ return RD->Kind;
// Match next level.
return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level,
NumLevels);
}
-static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
- unsigned &Opcode, Type *&Ty) {
+static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
if (!EnableReduxCost)
- return false;
+ return RK_None;
// Need to extract the first element.
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -296,19 +328,19 @@
if (CI)
Idx = CI->getZExtValue();
if (Idx != 0)
- return false;
+ return RK_None;
auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
if (!RdxStart)
- return false;
+ return RK_None;
Optional<ReductionData> RD = getReductionData(RdxStart);
if (!RD)
- return false;
+ return RK_None;
Type *VecTy = RdxStart->getType();
unsigned NumVecElems = VecTy->getVectorNumElements();
if (!isPowerOf2_32(NumVecElems))
- return false;
+ return RK_None;
// We look for a sequence of shuffle,shuffle,add triples like the following
// that builds a pairwise reduction tree.
@@ -328,13 +360,14 @@
// <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
// %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
// %r = extractelement <4 x float> %bin.rdx8, i32 0
- if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)))
- return false;
+ if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) ==
+ RK_None)
+ return RK_None;
Opcode = RD->Opcode;
Ty = VecTy;
- return true;
+ return RD->Kind;
}
static std::pair<Value *, ShuffleVectorInst *>
@@ -348,10 +381,11 @@
return std::make_pair(L, S);
}
-static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
- unsigned &Opcode, Type *&Ty) {
+static ReductionKind
+matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
+ unsigned &Opcode, Type *&Ty) {
if (!EnableReduxCost)
- return false;
+ return RK_None;
// Need to extract the first element.
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -359,19 +393,19 @@
if (CI)
Idx = CI->getZExtValue();
if (Idx != 0)
- return false;
+ return RK_None;
auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
if (!RdxStart)
- return false;
+ return RK_None;
Optional<ReductionData> RD = getReductionData(RdxStart);
if (!RD)
- return false;
+ return RK_None;
Type *VecTy = ReduxRoot->getOperand(0)->getType();
unsigned NumVecElems = VecTy->getVectorNumElements();
if (!isPowerOf2_32(NumVecElems))
- return false;
+ return RK_None;
// We look for a sequence of shuffles and adds like the following matching one
// fadd, shuffle vector pair at a time.
@@ -391,10 +425,10 @@
while (NumVecElemsRemain - 1) {
// Check for the right reduction operation.
if (!RdxOp)
- return false;
+ return RK_None;
Optional<ReductionData> RDLevel = getReductionData(RdxOp);
- if (!RDLevel || RDLevel->Opcode != RD->Opcode)
- return false;
+ if (!RDLevel || !RDLevel->hasSameData(*RD))
+ return RK_None;
Value *NextRdxOp;
ShuffleVectorInst *Shuffle;
@@ -403,9 +437,9 @@
// Check the current reduction operation and the shuffle use the same value.
if (Shuffle == nullptr)
- return false;
+ return RK_None;
if (Shuffle->getOperand(0) != NextRdxOp)
- return false;
+ return RK_None;
// Check that shuffle masks matches.
for (unsigned j = 0; j != MaskStart; ++j)
@@ -415,7 +449,7 @@
SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
if (ShuffleMask != Mask)
- return false;
+ return RK_None;
RdxOp = dyn_cast<Instruction>(NextRdxOp);
NumVecElemsRemain /= 2;
@@ -424,7 +458,7 @@
Opcode = RD->Opcode;
Ty = VecTy;
- return true;
+ return RD->Kind;
}
unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
@@ -519,13 +553,36 @@
unsigned ReduxOpCode;
Type *ReduxType;
- if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+ switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
/*IsPairwiseForm=*/false);
+ case RK_MinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/false, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
}
- if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+
+ switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+ case RK_Arithmetic:
return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
/*IsPairwiseForm=*/true);
+ case RK_MinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/false);
+ case RK_UnsignedMinMax:
+ return TTI->getMinMaxReductionCost(
+ ReduxType, CmpInst::makeCmpResultType(ReduxType),
+ /*IsPairwiseForm=*/true, /*IsUnsigned=*/true);
+ case RK_None:
+ break;
}
return TTI->getVectorInstrCost(I->getOpcode(),
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index e091381..8673b1b 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -484,6 +484,15 @@
return Cost;
}
+/// Forwards a min/max reduction cost query to the target implementation.
+int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy,
+ bool IsPairwiseForm,
+ bool IsUnsigned) const {
+ // Delegate, then sanity-check that the target never reports a negative cost.
+ const int MinMaxCost =
+ TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned);
+ assert(MinMaxCost >= 0 && "TTI should not produce negative costs!");
+ return MinMaxCost;
+}
+
unsigned
TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const {
return TTIImpl->getCostOfKeepingLiveOverCall(Tys);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 871a38d..79f192c 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1999,6 +1999,152 @@
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
}
+int X86TTIImpl::getMinMaxReductionCost(Type *ValTy, Type *CondTy,
+ bool IsPairwise, bool IsUnsigned) {
+ // LT.second is the legalized type used as the table key; LT.first scales
+ // the per-entry cost below.
+ std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+ MVT MTy = LT.second;
+
+ // The tables are keyed by the "min" opcode only (this hook does not
+ // distinguish min from max): SMIN/UMIN for integers, FMINNUM for FP.
+ int ISDOpcode;
+ if (ValTy->isIntOrIntVectorTy()) {
+ ISDOpcode = IsUnsigned ? ISD::UMIN : ISD::SMIN;
+ } else {
+ assert(ValTy->isFPOrFPVectorTy() &&
+ "Expected float point or integer vector type.");
+ ISDOpcode = ISD::FMINNUM;
+ }
+
+ // Costs are throughputs measured with the Intel Architecture Code Analyzer
+ // (IACA).
+
+ static const CostTblEntry SSE42CostTblPairWise[] = {
+ {ISD::FMINNUM, MVT::v2f64, 3},
+ {ISD::FMINNUM, MVT::v4f32, 2},
+ {ISD::SMIN, MVT::v2i64, 7}, // The data reported by the IACA is "6.8"
+ {ISD::UMIN, MVT::v2i64, 8}, // The data reported by the IACA is "8.6"
+ {ISD::SMIN, MVT::v4i32, 1}, // The data reported by the IACA is "1.5"
+ {ISD::UMIN, MVT::v4i32, 2}, // The data reported by the IACA is "1.8"
+ {ISD::SMIN, MVT::v8i16, 2},
+ {ISD::UMIN, MVT::v8i16, 2},
+ };
+
+ static const CostTblEntry AVX1CostTblPairWise[] = {
+ {ISD::FMINNUM, MVT::v4f32, 1},
+ {ISD::FMINNUM, MVT::v4f64, 1},
+ {ISD::FMINNUM, MVT::v8f32, 2},
+ {ISD::SMIN, MVT::v2i64, 3},
+ {ISD::UMIN, MVT::v2i64, 3},
+ {ISD::SMIN, MVT::v4i32, 1},
+ {ISD::UMIN, MVT::v4i32, 1},
+ {ISD::SMIN, MVT::v8i16, 1},
+ {ISD::UMIN, MVT::v8i16, 1},
+ {ISD::SMIN, MVT::v8i32, 3},
+ {ISD::UMIN, MVT::v8i32, 3},
+ };
+
+ static const CostTblEntry AVX2CostTblPairWise[] = {
+ {ISD::SMIN, MVT::v4i64, 2},
+ {ISD::UMIN, MVT::v4i64, 2},
+ {ISD::SMIN, MVT::v8i32, 1},
+ {ISD::UMIN, MVT::v8i32, 1},
+ {ISD::SMIN, MVT::v16i16, 1},
+ {ISD::UMIN, MVT::v16i16, 1},
+ {ISD::SMIN, MVT::v32i8, 2},
+ {ISD::UMIN, MVT::v32i8, 2},
+ };
+
+ static const CostTblEntry AVX512CostTblPairWise[] = {
+ {ISD::FMINNUM, MVT::v8f64, 1},
+ {ISD::FMINNUM, MVT::v16f32, 2},
+ {ISD::SMIN, MVT::v8i64, 2},
+ {ISD::UMIN, MVT::v8i64, 2},
+ {ISD::SMIN, MVT::v16i32, 1},
+ {ISD::UMIN, MVT::v16i32, 1},
+ };
+
+ // NOTE(review): UMIN v2i64 is 9 here but 8 in the pairwise table for the
+ // same IACA figure "8.6" -- confirm which rounding is intended.
+ static const CostTblEntry SSE42CostTblNoPairWise[] = {
+ {ISD::FMINNUM, MVT::v2f64, 3},
+ {ISD::FMINNUM, MVT::v4f32, 3},
+ {ISD::SMIN, MVT::v2i64, 7}, // The data reported by the IACA is "6.8"
+ {ISD::UMIN, MVT::v2i64, 9}, // The data reported by the IACA is "8.6"
+ {ISD::SMIN, MVT::v4i32, 1}, // The data reported by the IACA is "1.5"
+ {ISD::UMIN, MVT::v4i32, 2}, // The data reported by the IACA is "1.8"
+ {ISD::SMIN, MVT::v8i16, 1}, // The data reported by the IACA is "1.5"
+ {ISD::UMIN, MVT::v8i16, 2}, // The data reported by the IACA is "1.8"
+ };
+
+ static const CostTblEntry AVX1CostTblNoPairWise[] = {
+ {ISD::FMINNUM, MVT::v4f32, 1},
+ {ISD::FMINNUM, MVT::v4f64, 1},
+ {ISD::FMINNUM, MVT::v8f32, 1},
+ {ISD::SMIN, MVT::v2i64, 3},
+ {ISD::UMIN, MVT::v2i64, 3},
+ {ISD::SMIN, MVT::v4i32, 1},
+ {ISD::UMIN, MVT::v4i32, 1},
+ {ISD::SMIN, MVT::v8i16, 1},
+ {ISD::UMIN, MVT::v8i16, 1},
+ {ISD::SMIN, MVT::v8i32, 2},
+ {ISD::UMIN, MVT::v8i32, 2},
+ };
+
+ static const CostTblEntry AVX2CostTblNoPairWise[] = {
+ {ISD::SMIN, MVT::v4i64, 1},
+ {ISD::UMIN, MVT::v4i64, 1},
+ {ISD::SMIN, MVT::v8i32, 1},
+ {ISD::UMIN, MVT::v8i32, 1},
+ {ISD::SMIN, MVT::v16i16, 1},
+ {ISD::UMIN, MVT::v16i16, 1},
+ {ISD::SMIN, MVT::v32i8, 1},
+ {ISD::UMIN, MVT::v32i8, 1},
+ };
+
+ static const CostTblEntry AVX512CostTblNoPairWise[] = {
+ {ISD::FMINNUM, MVT::v8f64, 1},
+ {ISD::FMINNUM, MVT::v16f32, 2},
+ {ISD::SMIN, MVT::v8i64, 1},
+ {ISD::UMIN, MVT::v8i64, 1},
+ {ISD::SMIN, MVT::v16i32, 1},
+ {ISD::UMIN, MVT::v16i32, 1},
+ };
+
+ // Pick the pairwise or splitting flavour of every per-ISA table up front,
+ // then walk the ISAs from the newest down to SSE4.2.
+ ArrayRef<CostTblEntry> AVX512Tbl =
+ IsPairwise ? ArrayRef<CostTblEntry>(AVX512CostTblPairWise)
+ : ArrayRef<CostTblEntry>(AVX512CostTblNoPairWise);
+ ArrayRef<CostTblEntry> AVX2Tbl =
+ IsPairwise ? ArrayRef<CostTblEntry>(AVX2CostTblPairWise)
+ : ArrayRef<CostTblEntry>(AVX2CostTblNoPairWise);
+ ArrayRef<CostTblEntry> AVX1Tbl =
+ IsPairwise ? ArrayRef<CostTblEntry>(AVX1CostTblPairWise)
+ : ArrayRef<CostTblEntry>(AVX1CostTblNoPairWise);
+ ArrayRef<CostTblEntry> SSE42Tbl =
+ IsPairwise ? ArrayRef<CostTblEntry>(SSE42CostTblPairWise)
+ : ArrayRef<CostTblEntry>(SSE42CostTblNoPairWise);
+
+ if (ST->hasAVX512())
+ if (const auto *Entry = CostTableLookup(AVX512Tbl, ISDOpcode, MTy))
+ return LT.first * Entry->Cost;
+
+ if (ST->hasAVX2())
+ if (const auto *Entry = CostTableLookup(AVX2Tbl, ISDOpcode, MTy))
+ return LT.first * Entry->Cost;
+
+ if (ST->hasAVX())
+ if (const auto *Entry = CostTableLookup(AVX1Tbl, ISDOpcode, MTy))
+ return LT.first * Entry->Cost;
+
+ if (ST->hasSSE42())
+ if (const auto *Entry = CostTableLookup(SSE42Tbl, ISDOpcode, MTy))
+ return LT.first * Entry->Cost;
+
+ // No table entry for this subtarget/type: use the generic estimate.
+ return BaseT::getMinMaxReductionCost(ValTy, CondTy, IsPairwise, IsUnsigned);
+}
+
/// \brief Calculate the cost of materializing a 64-bit value. This helper
/// method might only calculate a fraction of a larger immediate. Therefore it
/// is valid to return a cost of ZERO.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index a8edc46..a7f500d 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -96,6 +96,9 @@
int getArithmeticReductionCost(unsigned Opcode, Type *Ty,
bool IsPairwiseForm);
+ /// Cost of a horizontal min/max reduction over vector type \p Ty.
+ /// \p CondTy is the type of the comparison result, \p IsPairwiseForm selects
+ /// the pairwise (rather than vector-splitting) reduction shape, and
+ /// \p IsUnsigned selects unsigned integer min/max.
+ int getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwiseForm,
+ bool IsUnsigned);
+
int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy,
unsigned Factor, ArrayRef<unsigned> Indices,
unsigned Alignment, unsigned AddressSpace);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b147445..53b1d87 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4627,11 +4627,17 @@
// Use map vector to make stable output.
MapVector<Instruction *, Value *> ExtraArgs;
+ /// Kind of the reduction data.
+ enum ReductionKind {
+ RK_None, ///< Not a reduction.
+ RK_Arithmetic, ///< Binary reduction data.
+ RK_Min, ///< Minimum reduction data.
+ RK_UMin, ///< Unsigned minimum reduction data.
+ RK_Max, ///< Maximum reduction data.
+ RK_UMax, ///< Unsigned maximum reduction data.
+ };
/// Contains info about operation, like its opcode, left and right operands.
- struct OperationData {
- /// true if the operation is a reduced value, false if reduction operation.
- bool IsReducedValue = false;
-
+ class OperationData {
/// Opcode of the instruction.
unsigned Opcode = 0;
@@ -4640,12 +4646,21 @@
/// Right operand of the reduction operation.
Value *RHS = nullptr;
+ /// Kind of the reduction operation.
+ ReductionKind Kind = RK_None;
+ /// True if float point min/max reduction has no NaNs.
+ bool NoNaN = false;
/// Checks if the reduction operation can be vectorized.
bool isVectorizable() const {
- return LHS && RHS &&
- // We currently only support adds.
- (Opcode == Instruction::Add || Opcode == Instruction::FAdd);
+ // Both operands must be known before the operation can be vectorized.
+ if (!LHS || !RHS)
+ return false;
+ // We currently only support adds and min/max reductions.
+ switch (Kind) {
+ case RK_Arithmetic:
+ return Opcode == Instruction::Add || Opcode == Instruction::FAdd;
+ case RK_Min:
+ case RK_Max:
+ return Opcode == Instruction::ICmp || Opcode == Instruction::FCmp;
+ case RK_UMin:
+ case RK_UMax:
+ return Opcode == Instruction::ICmp;
+ case RK_None:
+ break;
+ }
+ return false;
}
public:
@@ -4653,43 +4668,90 @@
/// Construction for reduced values. They are identified by opcode only and
/// don't have associated LHS/RHS values.
- explicit OperationData(Value *V) : IsReducedValue(true) {
+ explicit OperationData(Value *V) : Kind(RK_None) {
if (auto *I = dyn_cast<Instruction>(V))
Opcode = I->getOpcode();
}
- /// Constructor for binary reduction operations with opcode and its left and
+ /// Constructor for reduction operations with opcode and its left and
/// right operands.
- OperationData(unsigned Opcode, Value *LHS, Value *RHS)
- : Opcode(Opcode), LHS(LHS), RHS(RHS) {}
-
+ OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind,
+ bool NoNaN = false)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) {
+ assert(Kind != RK_None && "One of the reduction operations is expected.");
+ }
explicit operator bool() const { return Opcode; }
/// Get the index of the first operand.
unsigned getFirstOperandIndex() const {
assert(!!*this && "The opcode is not set.");
+ // Min/max operations are emitted as a select (see createOp); its operand 0
+ // is the condition, so the reduced operands start at index 1.
+ switch (Kind) {
+ case RK_Min:
+ case RK_UMin:
+ case RK_Max:
+ case RK_UMax:
+ return 1;
+ case RK_Arithmetic:
+ case RK_None:
+ break;
+ }
return 0;
}
/// Total number of operands in the reduction operation.
unsigned getNumberOfOperands() const {
- assert(!IsReducedValue && !!*this && LHS && RHS &&
+ assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
- return 2;
+ switch (Kind) {
+ case RK_Arithmetic:
+ return 2;
+ case RK_Min:
+ case RK_UMin:
+ case RK_Max:
+ case RK_UMax:
+ return 3;
+ case RK_None:
+ break;
+ }
+ // Kept after the switch (as in isAssociative) so compilers that do not treat
+ // a fully covered enum switch as exhaustive (GCC -Wreturn-type) see that
+ // control never falls off the end of the function.
+ llvm_unreachable("Reduction kind is not set");
}
/// Expected number of uses for reduction operations/reduced values.
unsigned getRequiredNumberOfUses() const {
- assert(!IsReducedValue && !!*this && LHS && RHS &&
+ assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
- return 1;
+ switch (Kind) {
+ case RK_Arithmetic:
+ return 1;
+ case RK_Min:
+ case RK_UMin:
+ case RK_Max:
+ case RK_UMax:
+ return 2;
+ case RK_None:
+ break;
+ }
+ // See getNumberOfOperands: keeps GCC's -Wreturn-type quiet.
+ llvm_unreachable("Reduction kind is not set");
}
/// Checks if instruction is associative and can be vectorized.
bool isAssociative(Instruction *I) const {
- assert(!IsReducedValue && *this && LHS && RHS &&
+ assert(Kind != RK_None && *this && LHS && RHS &&
"Expected reduction operation.");
- return I->isAssociative();
+ switch (Kind) {
+ case RK_Arithmetic:
+ return I->isAssociative();
+ case RK_Min:
+ case RK_Max:
+ return Opcode == Instruction::ICmp ||
+ cast<Instruction>(I->getOperand(0))->hasUnsafeAlgebra();
+ case RK_UMin:
+ case RK_UMax:
+ assert(Opcode == Instruction::ICmp &&
+ "Only integer compare operation is expected.");
+ return true;
+ case RK_None:
+ break;
+ }
+ llvm_unreachable("Reduction kind is not set");
}
/// Checks if the reduction operation can be vectorized.
@@ -4700,18 +4762,17 @@
/// Checks if two operation data are both a reduction op or both a reduced
/// value.
- bool operator==(const OperationData &OD) {
+ bool operator==(const OperationData &OD) const {
- assert(((IsReducedValue != OD.IsReducedValue) ||
- ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
+ assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
"One of the comparing operations is incorrect.");
- return this == &OD ||
- (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode);
+ return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
}
- bool operator!=(const OperationData &OD) { return !(*this == OD); }
+ bool operator!=(const OperationData &OD) const { return !(*this == OD); }
void clear() {
- IsReducedValue = false;
Opcode = 0;
LHS = nullptr;
RHS = nullptr;
+ Kind = RK_None;
+ NoNaN = false;
}
/// Get the opcode of the reduction operation.
@@ -4720,16 +4781,81 @@
return Opcode;
}
+ /// Get kind of reduction data.
+ ReductionKind getKind() const { return Kind; }
Value *getLHS() const { return LHS; }
Value *getRHS() const { return RHS; }
+ /// Type of the compare feeding a min/max select (derived from the LHS
+ /// type), or nullptr for arithmetic reductions.
+ Type *getConditionType() const {
+ switch (Kind) {
+ case RK_Arithmetic:
+ return nullptr;
+ case RK_Min:
+ case RK_Max:
+ case RK_UMin:
+ case RK_UMax:
+ return CmpInst::makeCmpResultType(LHS->getType());
+ case RK_None:
+ break;
+ }
+ llvm_unreachable("Reduction kind is not set");
+ }
/// Creates reduction operation with the current opcode.
Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const {
- assert(!IsReducedValue &&
- (Opcode == Instruction::FAdd || Opcode == Instruction::Add) &&
- "Expected add|fadd reduction operation.");
- return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
- Name);
+ assert(isVectorizable() &&
+ "Expected add|fadd or min/max reduction operation.");
+ // Min/max is emitted as compare + select. Cmp is initialized so that no
+ // path (even an out-of-range Kind) reads an indeterminate value; every
+ // valid min/max case below overwrites it.
+ Value *Cmp = nullptr;
+ switch (Kind) {
+ case RK_Arithmetic:
+ return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
+ Name);
+ case RK_Min:
+ Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS)
+ : Builder.CreateFCmpOLT(LHS, RHS);
+ break;
+ case RK_Max:
+ Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS)
+ : Builder.CreateFCmpOGT(LHS, RHS);
+ break;
+ case RK_UMin:
+ assert(Opcode == Instruction::ICmp && "Expected integer types.");
+ Cmp = Builder.CreateICmpULT(LHS, RHS);
+ break;
+ case RK_UMax:
+ assert(Opcode == Instruction::ICmp && "Expected integer types.");
+ Cmp = Builder.CreateICmpUGT(LHS, RHS);
+ break;
+ case RK_None:
+ llvm_unreachable("Unknown reduction operation.");
+ }
+ return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+ }
+ /// Builds the TTI reduction flags for this operation: NoNaN always, plus
+ /// signedness and min-vs-max for the min/max kinds.
+ TargetTransformInfo::ReductionFlags getFlags() const {
+ TargetTransformInfo::ReductionFlags Flags;
+ Flags.NoNaN = NoNaN;
+ switch (Kind) {
+ case RK_Arithmetic:
+ break;
+ case RK_Min:
+ Flags.IsSigned = Opcode == Instruction::ICmp;
+ Flags.IsMaxOp = false;
+ break;
+ case RK_Max:
+ Flags.IsSigned = Opcode == Instruction::ICmp;
+ Flags.IsMaxOp = true;
+ break;
+ case RK_UMin:
+ Flags.IsSigned = false;
+ Flags.IsMaxOp = false;
+ break;
+ case RK_UMax:
+ Flags.IsSigned = false;
+ Flags.IsMaxOp = true;
+ break;
+ case RK_None:
+ llvm_unreachable("Reduction kind is not set");
+ }
+ return Flags;
}
};
@@ -4771,8 +4897,32 @@
Value *LHS;
Value *RHS;
- if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V))
- return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS);
+ if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) {
+ return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS,
+ RK_Arithmetic);
+ }
+ if (auto *Select = dyn_cast<SelectInst>(V)) {
+ // Look for a min/max pattern.
+ if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
+ } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(Instruction::ICmp, LHS, RHS, RK_Min);
+ } else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) ||
+ m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(
+ Instruction::FCmp, LHS, RHS, RK_Min,
+ cast<Instruction>(Select->getCondition())->hasNoNaNs());
+ } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
+ } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(Instruction::ICmp, LHS, RHS, RK_Max);
+ } else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) ||
+ m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
+ return OperationData(
+ Instruction::FCmp, LHS, RHS, RK_Max,
+ cast<Instruction>(Select->getCondition())->hasNoNaNs());
+ }
+ }
return OperationData(V);
}
@@ -4965,8 +5115,9 @@
if (VectorizedTree) {
Builder.SetCurrentDebugLocation(Loc);
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, ReducedSubTree);
- VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx");
+ VectorizedTree, ReducedSubTree,
+ ReductionData.getKind());
+ VectorizedTree = VectReductionData.createOp(Builder, "op.rdx");
propagateIRFlags(VectorizedTree, ReductionOps);
} else
VectorizedTree = ReducedSubTree;
@@ -4980,7 +5131,8 @@
auto *I = cast<Instruction>(ReducedVals[i]);
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, I);
+ VectorizedTree, I,
+ ReductionData.getKind());
VectorizedTree = VectReductionData.createOp(Builder);
propagateIRFlags(VectorizedTree, ReductionOps);
}
@@ -4991,8 +5143,9 @@
for (auto *I : Pair.second) {
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, Pair.first);
- VectorizedTree = VectReductionData.createOp(Builder, "bin.extra");
+ VectorizedTree, Pair.first,
+ ReductionData.getKind());
+ VectorizedTree = VectReductionData.createOp(Builder, "op.extra");
propagateIRFlags(VectorizedTree, I);
}
}
@@ -5013,19 +5166,58 @@
Type *ScalarTy = FirstReducedVal->getType();
Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
- int PairwiseRdxCost =
- TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
- /*IsPairwiseForm=*/true);
- int SplittingRdxCost =
- TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
- /*IsPairwiseForm=*/false);
+ int PairwiseRdxCost;
+ int SplittingRdxCost;
+ bool IsUnsigned = true;
+ switch (ReductionData.getKind()) {
+ case RK_Arithmetic:
+ PairwiseRdxCost =
+ TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+ /*IsPairwiseForm=*/true);
+ SplittingRdxCost =
+ TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+ /*IsPairwiseForm=*/false);
+ break;
+ case RK_Min:
+ case RK_Max:
+ IsUnsigned = false;
+ case RK_UMin:
+ case RK_UMax: {
+ Type *VecCondTy = CmpInst::makeCmpResultType(VecTy);
+ PairwiseRdxCost =
+ TTI->getMinMaxReductionCost(VecTy, VecCondTy,
+ /*IsPairwiseForm=*/true, IsUnsigned);
+ SplittingRdxCost =
+ TTI->getMinMaxReductionCost(VecTy, VecCondTy,
+ /*IsPairwiseForm=*/false, IsUnsigned);
+ break;
+ }
+ case RK_None:
+ llvm_unreachable("Expected arithmetic or min/max reduction operation");
+ }
IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost;
int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost;
- int ScalarReduxCost =
- (ReduxWidth - 1) *
- TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
+ int ScalarReduxCost;
+ switch (ReductionData.getKind()) {
+ case RK_Arithmetic:
+ ScalarReduxCost =
+ TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
+ break;
+ case RK_Min:
+ case RK_Max:
+ case RK_UMin:
+ case RK_UMax:
+ ScalarReduxCost =
+ TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) +
+ TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
+ CmpInst::makeCmpResultType(ScalarTy));
+ break;
+ case RK_None:
+ llvm_unreachable("Expected arithmetic or min/max reduction operation");
+ }
+ ScalarReduxCost *= (ReduxWidth - 1);
DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost
<< " for reduction that starts with " << *FirstReducedVal
@@ -5047,7 +5239,7 @@
if (!IsPairwiseReduction)
return createSimpleTargetReduction(
Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
- TargetTransformInfo::ReductionFlags(), RedOps);
+ ReductionData.getFlags(), RedOps);
Value *TmpVec = VectorizedValue;
for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
@@ -5062,8 +5254,8 @@
TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
"rdx.shuf.r");
OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
- RightShuf);
- TmpVec = VectReductionData.createOp(Builder, "bin.rdx");
+ RightShuf, ReductionData.getKind());
+ TmpVec = VectReductionData.createOp(Builder, "op.rdx");
propagateIRFlags(TmpVec, RedOps);
}
@@ -5224,9 +5416,11 @@
auto *Inst = dyn_cast<Instruction>(V);
if (!Inst)
continue;
- if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
+ auto *BI = dyn_cast<BinaryOperator>(Inst);
+ auto *SI = dyn_cast<SelectInst>(Inst);
+ if (BI || SI) {
HorizontalReduction HorRdx;
- if (HorRdx.matchAssociativeReduction(P, BI)) {
+ if (HorRdx.matchAssociativeReduction(P, Inst)) {
if (HorRdx.tryToReduce(R, TTI)) {
Res = true;
// Set P to nullptr to avoid re-analysis of phi node in
@@ -5235,7 +5429,7 @@
continue;
}
}
- if (P) {
+ if (P && BI) {
Inst = dyn_cast<Instruction>(BI->getOperand(0));
if (Inst == P)
Inst = dyn_cast<Instruction>(BI->getOperand(1));