Refactor the VectorTargetTransformInfo interface.

Add get*InstrCost methods for different families of opcodes, such as arithmetic, casts, compares/selects, control flow, and vector insert/extract.

Port the LoopVectorizer to the new API.

The LoopVectorizer now finds instructions that will remain uniform after vectorization, and uses this information when calculating the cost of those instructions.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@166836 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/TargetTransformImpl.cpp b/lib/Target/TargetTransformImpl.cpp
index 40184ed..d3ab105 100644
--- a/lib/Target/TargetTransformImpl.cpp
+++ b/lib/Target/TargetTransformImpl.cpp
@@ -126,7 +126,7 @@
 
 std::pair<unsigned, EVT>
 VectorTargetTransformImpl::getTypeLegalizationCost(LLVMContext &C,
-                                                         EVT Ty) const {
+                                                   EVT Ty) const {
   unsigned Cost = 1;
   // We keep legalizing the type until we find a legal kind. We assume that
   // the only operation that costs anything is the split. After splitting
@@ -135,7 +135,7 @@
     TargetLowering::LegalizeKind LK = TLI->getTypeConversion(C, Ty);
 
     if (LK.first == TargetLowering::TypeLegal)
-      return std::make_pair(Cost, LK.second);
+      return std::make_pair(Cost, Ty); // Ty itself is already legal here.
 
     if (LK.first == TargetLowering::TypeSplitVector)
       Cost *= 2;
@@ -146,44 +146,144 @@
 }
 
 unsigned
-VectorTargetTransformImpl::getInstrCost(unsigned Opcode, Type *Ty1,
-                                        Type *Ty2) const {
-  // Check if any of the operands are vector operands.
-  int ISD = InstructionOpcodeToISD(Opcode);
+VectorTargetTransformImpl::getScalarizationOverhead(Type *Ty,
+                                                    bool Insert,
+                                                    bool Extract) const {
+  assert(Ty->isVectorTy() && "Can only scalarize vectors");
+  unsigned Cost = 0; // Sum of the per-lane insert/extract costs.
 
-  // If we don't have any information about this instruction assume it costs 1.
-  if (ISD == 0)
-    return 1;
-
-  // Selects on vectors are actually vector selects.
-  if (ISD == ISD::SELECT) {
-    assert(Ty2 && "Ty2 must hold the condition type");
-    if (Ty2->isVectorTy())
-    ISD = ISD::VSELECT;
+  for (unsigned i = 0, e = Ty->getVectorNumElements(); i < e; ++i) {
+    if (Insert)
+      Cost += getVectorInstrCost(Instruction::InsertElement, Ty, i);
+    if (Extract)
+      Cost += getVectorInstrCost(Instruction::ExtractElement, Ty, i);
   }
 
-  assert(Ty1 && "We need to have at least one type");
+  return Cost;
+}
 
-  // From this stage we look at the legalized type.
-  std::pair<unsigned, EVT>  LT =
-  getTypeLegalizationCost(Ty1->getContext(), TLI->getValueType(Ty1));
+unsigned VectorTargetTransformImpl::getArithmeticInstrCost(unsigned Opcode,
+                                                           Type *Ty) const {
+  // Map the IR opcode to the ISD node used to query the target.
+  int ISD = InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
 
-  if (TLI->isOperationLegalOrCustom(ISD, LT.second)) {
+  std::pair<unsigned, EVT> LT =
+  getTypeLegalizationCost(Ty->getContext(), TLI->getValueType(Ty));
+  // Query the target on the legalized type; LT.first is the split factor.
+  if (!TLI->isOperationExpand(ISD, LT.second)) {
     // The operation is legal. Assume it costs 1. Multiply
     // by the type-legalization overhead.
     return LT.first * 1;
   }
 
-  unsigned NumElem =
-    (LT.second.isVector() ? LT.second.getVectorNumElements() : 1);
+  // Else, assume that we need to scalarize this op.
+  if (Ty->isVectorTy()) {
+    unsigned Num = Ty->getVectorNumElements();
+    unsigned Cost = getArithmeticInstrCost(Opcode, Ty->getScalarType());
+    // Return the cost of Num scalar invocations plus the cost of inserting
+    // and extracting the values.
+    return getScalarizationOverhead(Ty, true, true) + Num * Cost;
+  }
 
-  // We will probably scalarize this instruction. Assume that the cost is the
-  // number of the vector elements.
-  return LT.first * NumElem * 1;
+  // We don't know anything about this scalar instruction.
+  return 1;
+}
+
+unsigned VectorTargetTransformImpl::getBroadcastCost(Type *Tp) const {
+  return 1; // Flat default: Tp is not inspected.
+}
+
+unsigned VectorTargetTransformImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
+                                                     Type *Src) const {
+  assert(Src->isVectorTy() == Dst->isVectorTy() && "Invalid input types");
+  int ISD = InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
+
+  // Legalize both sides: first = split factor, second = legal type.
+  std::pair<unsigned, EVT> SrcLT =
+  getTypeLegalizationCost(Src->getContext(), TLI->getValueType(Src));
+  std::pair<unsigned, EVT> DstLT =
+  getTypeLegalizationCost(Dst->getContext(), TLI->getValueType(Dst));
+
+  // If the cast is between same-sized registers, then the check is simple.
+  if (SrcLT.first == DstLT.first &&
+      SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
+    // Just check the op cost:
+    if (!TLI->isOperationExpand(ISD, DstLT.second)) {
+      // The operation is not expanded. Assume it costs 1. Multiply
+      // by the type-legalization overhead.
+      return SrcLT.first * 1;
+    }
+  }
+
+  // Otherwise, assume that the cast is scalarized.
+  if (Dst->isVectorTy()) {
+    unsigned Num = Dst->getVectorNumElements();
+    // Cost one scalar cast; this function takes (Dst, Src) in that order.
+    unsigned Cost = getCastInstrCost(Opcode, Dst->getScalarType(),
+                                     Src->getScalarType());
+    // Num scalar casts plus the cost of inserting/extracting the lanes.
+    return getScalarizationOverhead(Dst, true, true) + Num * Cost;
+  }
+
+  // Unknown scalar opcode.
+  return 1;
+}
+
+unsigned VectorTargetTransformImpl::getCFInstrCost(unsigned Opcode) const {
+  return 1; // Flat default: Opcode is not inspected.
+}
+
+unsigned VectorTargetTransformImpl::getCmpSelInstrCost(unsigned Opcode,
+                                                       Type *ValTy,
+                                                       Type *CondTy) const {
+  int ISD = InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
+
+  // Selects on vectors are actually vector selects.
+  if (ISD == ISD::SELECT) {
+    assert(CondTy && "CondTy must exist");
+    if (CondTy->isVectorTy())
+      ISD = ISD::VSELECT;
+  }
+
+  std::pair<unsigned, EVT> LT =
+  getTypeLegalizationCost(ValTy->getContext(), TLI->getValueType(ValTy));
+
+  if (!TLI->isOperationExpand(ISD, LT.second)) {
+    // Not expanded by the target. Assume it costs 1. Multiply
+    // by the type-legalization overhead.
+    return LT.first * 1;
+  }
+
+  // Otherwise, assume that the compare/select is scalarized.
+  if (ValTy->isVectorTy()) {
+    unsigned Num = ValTy->getVectorNumElements();
+    if (CondTy)
+      CondTy = CondTy->getScalarType();
+    unsigned Cost = getCmpSelInstrCost(Opcode, ValTy->getScalarType(),
+                                       CondTy);
+
+    // Return the cost of Num scalar invocations plus the insertelement
+    // overhead over ValTy's lanes (Extract is false: no extract cost).
+    return getScalarizationOverhead(ValTy, true, false) + Num * Cost;
+  }
+
+  // Unknown scalar opcode.
+  return 1;
+}
+
+/// Cost of one insertelement/extractelement at \p Index; always 1 here.
+unsigned VectorTargetTransformImpl::getVectorInstrCost(unsigned Opcode,
+                                                       Type *Val,
+                                                       unsigned Index) const {
+  return 1;
 }
 
 unsigned
-VectorTargetTransformImpl::getBroadcastCost(Type *Tp) const {
+VectorTargetTransformImpl::getInstrCost(unsigned Opcode, Type *Ty1,
+                                        Type *Ty2) const {
   return 1;
 }
 
@@ -191,17 +291,15 @@
 VectorTargetTransformImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                            unsigned Alignment,
                                            unsigned AddressSpace) const {
-  // From this stage we look at the legalized type.
-  std::pair<unsigned, EVT>  LT =
+  std::pair<unsigned, EVT> LT =
   getTypeLegalizationCost(Src->getContext(), TLI->getValueType(Src));
+  // LT.first doubles per vector split; Alignment/AddressSpace are unused.
   // Assume that all loads of legal types cost 1.
   return LT.first;
 }
 
 unsigned
 VectorTargetTransformImpl::getNumberOfParts(Type *Tp) const {
-  std::pair<unsigned, EVT>  LT =
-  getTypeLegalizationCost(Tp->getContext(), TLI->getValueType(Tp));
-  return LT.first;
+  return TLI->getNumRegisters(Tp->getContext(), TLI->getValueType(Tp));
 }