[ConstantFold] Support vector index when factoring out GEP index into preceding dimensions

Follow-up to r316824. This patch adds support for constant vector indices, for
either the current or the previous index (a scalar is splatted to a vector when
the two differ), when factoring an out-of-range index into the preceding one.

Differential Revision: https://reviews.llvm.org/D39556

llvm-svn: 319683
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index c826f75..90b1030 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -2210,17 +2210,17 @@
   SmallVector<Constant *, 8> NewIdxs;
   Type *Ty = PointeeTy;
   Type *Prev = C->getType();
-  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  bool Unknown =
+      !isa<ConstantInt>(Idxs[0]) && !isa<ConstantDataVector>(Idxs[0]);
   for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
-    auto *CI = dyn_cast<ConstantInt>(Idxs[i]);
-    if (!CI) {
+    if (!isa<ConstantInt>(Idxs[i]) && !isa<ConstantDataVector>(Idxs[i])) {
       // We don't know if it's in range or not.
       Unknown = true;
       continue;
     }
-    if (!isa<ConstantInt>(Idxs[i - 1]))
-      // FIXME: add the support of cosntant vector index.
+    if (!isa<ConstantInt>(Idxs[i - 1]) && !isa<ConstantDataVector>(Idxs[i - 1]))
+      // Skip if the type of the previous index is not supported.
       continue;
     if (InRangeIndex && i == *InRangeIndex + 1) {
       // If an index is marked inrange, we cannot apply this canonicalization to
@@ -2238,46 +2238,91 @@
       Unknown = true;
       continue;
     }
-    if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
-      // It's in range, skip to the next index.
-      continue;
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
+      if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
+        // It's in range, skip to the next index.
+        continue;
+      if (CI->getSExtValue() < 0) {
+        // It's out of range and negative, don't try to factor it.
+        Unknown = true;
+        continue;
+      }
+    } else {
+      auto *CV = cast<ConstantDataVector>(Idxs[i]);
+      bool InRange = true;
+      for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) {
+        auto *CI = cast<ConstantInt>(CV->getElementAsConstant(I));
+        InRange &= isIndexInRangeOfArrayType(STy->getNumElements(), CI);
+        if (CI->getSExtValue() < 0) {
+          Unknown = true;
+          break;
+        }
+      }
+      if (InRange || Unknown)
+        // It's in range, skip to the next index.
+        // It's out of range and negative, don't try to factor it.
+        continue;
+    }
     if (isa<StructType>(Prev)) {
       // It's out of range, but the prior dimension is a struct
       // so we can't do anything about it.
       Unknown = true;
       continue;
     }
-    if (CI->getSExtValue() < 0) {
-      // It's out of range and negative, don't try to factor it.
-      Unknown = true;
-      continue;
-    }
     // It's out of range, but we can factor it into the prior
     // dimension.
     NewIdxs.resize(Idxs.size());
     // Determine the number of elements in our sequential type.
     uint64_t NumElements = STy->getArrayNumElements();
 
-    ConstantInt *Factor = ConstantInt::get(CI->getType(), NumElements);
-    NewIdxs[i] = ConstantExpr::getSRem(CI, Factor);
+    // Expand the current index or the previous index to a vector from a scalar
+    // if necessary.
+    Constant *CurrIdx = cast<Constant>(Idxs[i]);
+    auto *PrevIdx =
+        NewIdxs[i - 1] ? NewIdxs[i - 1] : cast<Constant>(Idxs[i - 1]);
+    bool IsCurrIdxVector = CurrIdx->getType()->isVectorTy();
+    bool IsPrevIdxVector = PrevIdx->getType()->isVectorTy();
+    bool UseVector = IsCurrIdxVector || IsPrevIdxVector;
 
-    Constant *PrevIdx = NewIdxs[i-1] ? NewIdxs[i-1] :
-                           cast<Constant>(Idxs[i - 1]);
-    Constant *Div = ConstantExpr::getSDiv(CI, Factor);
+    if (!IsCurrIdxVector && IsPrevIdxVector)
+      CurrIdx = ConstantDataVector::getSplat(
+          PrevIdx->getType()->getVectorNumElements(), CurrIdx);
+
+    if (!IsPrevIdxVector && IsCurrIdxVector)
+      PrevIdx = ConstantDataVector::getSplat(
+          CurrIdx->getType()->getVectorNumElements(), PrevIdx);
+
+    Constant *Factor =
+        ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
+    if (UseVector)
+      Factor = ConstantDataVector::getSplat(
+          IsPrevIdxVector ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements(),
+          Factor);
+
+    NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
+
+    Constant *Div = ConstantExpr::getSDiv(CurrIdx, Factor);
 
     unsigned CommonExtendedWidth =
-        std::max(PrevIdx->getType()->getIntegerBitWidth(),
-                 Div->getType()->getIntegerBitWidth());
+        std::max(PrevIdx->getType()->getScalarSizeInBits(),
+                 Div->getType()->getScalarSizeInBits());
     CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
 
     // Before adding, extend both operands to i64 to avoid
     // overflow trouble.
-    if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
-      PrevIdx = ConstantExpr::getSExt(
-          PrevIdx, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
-    if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
-      Div = ConstantExpr::getSExt(
-          Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+    Type *ExtendedTy = Type::getIntNTy(Div->getContext(), CommonExtendedWidth);
+    if (UseVector)
+      ExtendedTy = VectorType::get(
+          ExtendedTy, IsPrevIdxVector
+                          ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements());
+
+    if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);
+
+    if (!Div->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      Div = ConstantExpr::getSExt(Div, ExtendedTy);
 
     NewIdxs[i - 1] = ConstantExpr::getAdd(PrevIdx, Div);
   }