Revert an earlier change to SIGN_EXTEND_INREG for vectors. The VTSDNode
really does need to be a vector type whenever the operand type is a
vector, because TargetLowering::getOperationAction for SIGN_EXTEND_INREG
uses that type and must be able to distinguish between vectors and scalars.

Also, fix some more issues with legalization of vector casts.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@93043 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index ca8c17b..cb1a0d6 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1741,7 +1741,7 @@
     return;
   case ISD::SIGN_EXTEND_INREG: {
     EVT EVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
-    unsigned EBits = EVT.getSizeInBits();
+    unsigned EBits = EVT.getScalarType().getSizeInBits();
 
     // Sign extension.  Compute the demanded bits in the result that are not
     // present in the input.
@@ -1786,7 +1786,7 @@
     if (ISD::isZEXTLoad(Op.getNode())) {
       LoadSDNode *LD = cast<LoadSDNode>(Op);
       EVT VT = LD->getMemoryVT();
-      unsigned MemBits = VT.getSizeInBits();
+      unsigned MemBits = VT.getScalarType().getSizeInBits();
       KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits) & Mask;
     }
     return;
@@ -2025,7 +2025,8 @@
 
   case ISD::SIGN_EXTEND_INREG:
     // Max of the input and what this extends.
-    Tmp = cast<VTSDNode>(Op.getOperand(1))->getVT().getSizeInBits();
+    Tmp =
+      cast<VTSDNode>(Op.getOperand(1))->getVT().getScalarType().getSizeInBits();
     Tmp = VTBits-Tmp+1;
 
     Tmp2 = ComputeNumSignBits(Op.getOperand(0), Depth+1);
@@ -2169,10 +2170,10 @@
     switch (ExtType) {
     default: break;
     case ISD::SEXTLOAD:    // '17' bits known
-      Tmp = LD->getMemoryVT().getSizeInBits();
+      Tmp = LD->getMemoryVT().getScalarType().getSizeInBits();
       return VTBits-Tmp+1;
     case ISD::ZEXTLOAD:    // '16' bits known
-      Tmp = LD->getMemoryVT().getSizeInBits();
+      Tmp = LD->getMemoryVT().getScalarType().getSizeInBits();
       return VTBits-Tmp;
     }
   }
@@ -2664,6 +2665,12 @@
     assert(VT == N1.getValueType() && "Not an inreg round!");
     assert(VT.isFloatingPoint() && EVT.isFloatingPoint() &&
            "Cannot FP_ROUND_INREG integer types");
+    assert(EVT.isVector() == VT.isVector() &&
+           "FP_ROUND_INREG type should be vector iff the operand "
+           "type is vector!");
+    assert((!EVT.isVector() ||
+            EVT.getVectorNumElements() == VT.getVectorNumElements()) &&
+           "Vector element counts must match in FP_ROUND_INREG");
     assert(EVT.bitsLE(VT) && "Not rounding down!");
     if (cast<VTSDNode>(N2)->getVT() == VT) return N1;  // Not actually rounding.
     break;
@@ -2693,15 +2700,18 @@
     assert(VT == N1.getValueType() && "Not an inreg extend!");
     assert(VT.isInteger() && EVT.isInteger() &&
            "Cannot *_EXTEND_INREG FP types");
-    assert(!EVT.isVector() &&
-           "SIGN_EXTEND_INREG type should be the vector element type rather "
-           "than the vector type!");
-    assert(EVT.bitsLE(VT.getScalarType()) && "Not extending!");
+    assert(EVT.isVector() == VT.isVector() &&
+           "SIGN_EXTEND_INREG type should be vector iff the operand "
+           "type is vector!");
+    assert((!EVT.isVector() ||
+            EVT.getVectorNumElements() == VT.getVectorNumElements()) &&
+           "Vector element counts must match in SIGN_EXTEND_INREG");
+    assert(EVT.bitsLE(VT) && "Not extending!");
     if (EVT == VT) return N1;  // Not actually extending
 
     if (N1C) {
       APInt Val = N1C->getAPIntValue();
-      unsigned FromBits = EVT.getSizeInBits();
+      unsigned FromBits = EVT.getScalarType().getSizeInBits();
       Val <<= Val.getBitWidth()-FromBits;
       Val = Val.ashr(Val.getBitWidth()-FromBits);
       return getConstant(Val, VT);
@@ -4109,7 +4119,7 @@
       if (ConstantSDNode *AndRHS = dyn_cast<ConstantSDNode>(N3.getOperand(1))) {
         // If the and is only masking out bits that cannot effect the shift,
         // eliminate the and.
-        unsigned NumBits = VT.getSizeInBits()*2;
+        unsigned NumBits = VT.getScalarType().getSizeInBits()*2;
         if ((AndRHS->getValue() & (NumBits-1)) == NumBits-1)
           return getNode(Opcode, DL, VT, N1, N2, N3.getOperand(0));
       }
@@ -5946,6 +5956,13 @@
       Scalars.push_back(getNode(N->getOpcode(), dl, EltVT, Operands[0],
                                 getShiftAmountOperand(Operands[1])));
       break;
+    case ISD::SIGN_EXTEND_INREG:
+    case ISD::FP_ROUND_INREG: {
+      EVT ExtVT = cast<VTSDNode>(Operands[1])->getVT().getVectorElementType();
+      Scalars.push_back(getNode(N->getOpcode(), dl, EltVT,
+                                Operands[0],
+                                getValueType(ExtVT)));
+    }
     }
   }