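Use Constant::getAggregateElement to simplify a bunch of code in GlobalOpt.cpp (replacing its local getAggregateConstantElement helper) and InstCombineSimplifyDemanded.cpp.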



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@148934 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp
index 2926508..6e8b434 100644
--- a/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/lib/Transforms/IPO/GlobalOpt.cpp
@@ -276,38 +276,6 @@
   return false;
 }
 
-static Constant *getAggregateConstantElement(Constant *Agg, Constant *Idx) {
-  ConstantInt *CI = dyn_cast<ConstantInt>(Idx);
-  if (!CI) return 0;
-  unsigned IdxV = CI->getZExtValue();
-
-  if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Agg)) {
-    if (IdxV < CS->getNumOperands()) return CS->getOperand(IdxV);
-  } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Agg)) {
-    if (IdxV < CA->getNumOperands()) return CA->getOperand(IdxV);
-  } else if (ConstantVector *CP = dyn_cast<ConstantVector>(Agg)) {
-    if (IdxV < CP->getNumOperands()) return CP->getOperand(IdxV);
-  } else if (isa<ConstantAggregateZero>(Agg)) {
-    if (StructType *STy = dyn_cast<StructType>(Agg->getType())) {
-      if (IdxV < STy->getNumElements())
-        return Constant::getNullValue(STy->getElementType(IdxV));
-    } else if (SequentialType *STy =
-               dyn_cast<SequentialType>(Agg->getType())) {
-      return Constant::getNullValue(STy->getElementType());
-    }
-  } else if (isa<UndefValue>(Agg)) {
-    if (StructType *STy = dyn_cast<StructType>(Agg->getType())) {
-      if (IdxV < STy->getNumElements())
-        return UndefValue::get(STy->getElementType(IdxV));
-    } else if (SequentialType *STy =
-               dyn_cast<SequentialType>(Agg->getType())) {
-      return UndefValue::get(STy->getElementType());
-    }
-  }
-  return 0;
-}
-
-
 /// CleanupConstantGlobalUsers - We just marked GV constant.  Loop over all
 /// users of the global, cleaning up the obvious ones.  This is largely just a
 /// quick scan over the use list to clean up the easy and obvious cruft.  This
@@ -520,8 +488,7 @@
     NewGlobals.reserve(STy->getNumElements());
     const StructLayout &Layout = *TD.getStructLayout(STy);
     for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
-      Constant *In = getAggregateConstantElement(Init,
-                    ConstantInt::get(Type::getInt32Ty(STy->getContext()), i));
+      Constant *In = Init->getAggregateElement(i);
       assert(In && "Couldn't get element of initializer?");
       GlobalVariable *NGV = new GlobalVariable(STy->getElementType(i), false,
                                                GlobalVariable::InternalLinkage,
@@ -553,8 +520,7 @@
     uint64_t EltSize = TD.getTypeAllocSize(STy->getElementType());
     unsigned EltAlign = TD.getABITypeAlignment(STy->getElementType());
     for (unsigned i = 0, e = NumElements; i != e; ++i) {
-      Constant *In = getAggregateConstantElement(Init,
-                    ConstantInt::get(Type::getInt32Ty(Init->getContext()), i));
+      Constant *In = Init->getAggregateElement(i);
       assert(In && "Couldn't get element of initializer?");
 
       GlobalVariable *NGV = new GlobalVariable(STy->getElementType(), false,
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 7b1617d..6873d15 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -834,59 +834,39 @@
   }
 
   UndefElts = 0;
-  if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
+  
+  // Handle ConstantAggregateZero, ConstantVector, ConstantDataSequential.
+  if (Constant *C = dyn_cast<Constant>(V)) {
+    // Check if this is identity. If so, return 0 since we are not simplifying
+    // anything.
+    if (DemandedElts.isAllOnesValue())
+      return 0;
+
     Type *EltTy = cast<VectorType>(V->getType())->getElementType();
     Constant *Undef = UndefValue::get(EltTy);
-
-    std::vector<Constant*> Elts;
-    for (unsigned i = 0; i != VWidth; ++i)
+    
+    SmallVector<Constant*, 16> Elts;
+    for (unsigned i = 0; i != VWidth; ++i) {
       if (!DemandedElts[i]) {   // If not demanded, set to undef.
         Elts.push_back(Undef);
         UndefElts.setBit(i);
-      } else if (isa<UndefValue>(CV->getOperand(i))) {   // Already undef.
+        continue;
+      }
+      
+      Constant *Elt = C->getAggregateElement(i);
+      if (Elt == 0) return 0;
+      
+      if (isa<UndefValue>(Elt)) {   // Already undef.
         Elts.push_back(Undef);
         UndefElts.setBit(i);
       } else {                               // Otherwise, defined.
-        Elts.push_back(CV->getOperand(i));
+        Elts.push_back(Elt);
       }
-
+    }
+    
     // If we changed the constant, return it.
     Constant *NewCV = ConstantVector::get(Elts);
-    return NewCV != CV ? NewCV : 0;
-  }
-  
-  if (ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
-    // Check if this is identity. If so, return 0 since we are not simplifying
-    // anything.
-    if (DemandedElts.isAllOnesValue())
-      return 0;
-
-    // Simplify to a ConstantVector where the non-demanded elements are undef.
-    Constant *Undef = UndefValue::get(CDV->getElementType());
-    
-    SmallVector<Constant*, 16> Elts;
-    for (unsigned i = 0; i != VWidth; ++i)
-      Elts.push_back(DemandedElts[i] ? CDV->getElementAsConstant(i) : Undef);
-    UndefElts = DemandedElts ^ EltMask;
-    return ConstantVector::get(Elts);
-
-  }
-  
-  if (ConstantAggregateZero *CAZ = dyn_cast<ConstantAggregateZero>(V)) {
-    // Check if this is identity. If so, return 0 since we are not simplifying
-    // anything.
-    if (DemandedElts.isAllOnesValue())
-      return 0;
-    
-    // Simplify the CAZ to a ConstantVector where the non-demanded elements are
-    // set to undef.
-    Constant *Zero = CAZ->getSequentialElement();
-    Constant *Undef = UndefValue::get(Zero->getType());
-    SmallVector<Constant*, 16> Elts;
-    for (unsigned i = 0; i != VWidth; ++i)
-      Elts.push_back(DemandedElts[i] ? Zero : Undef);
-    UndefElts = DemandedElts ^ EltMask;
-    return ConstantVector::get(Elts);
+    return NewCV != C ? NewCV : 0;
   }
   
   // Limit search depth.