[InstSimplify] Teach InstSimplify how to simplify extractelement

llvm-svn: 242008
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index eab5887..67f68dc 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -357,3 +357,76 @@
 
   return Stride;
 }
+
+/// \brief Given a vector and an element number, see if the scalar value is
+/// already around as a register, for example if it were inserted then extracted
+/// from the vector.
+///
+/// The search looks through constants, insertelement and shufflevector
+/// instructions, and adds of a constant whose lane \p EltNo is zero.
+///
+/// \param V      Value of vector type to inspect (asserted).
+/// \param EltNo  Index of the lane to look up.
+/// \returns the scalar in lane \p EltNo; an undef of the element type if
+///          \p EltNo is out of range or the lane comes from an undef shuffle
+///          mask element; or nullptr if the scalar cannot be determined.
+llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
+  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
+  VectorType *VTy = cast<VectorType>(V->getType());
+  unsigned Width = VTy->getNumElements();
+  if (EltNo >= Width)  // Out of range access.
+    return UndefValue::get(VTy->getElementType());
+
+  // getAggregateElement yields nullptr for constants it cannot split into
+  // lanes (e.g. constant expressions), which maps onto "don't know".
+  if (Constant *C = dyn_cast<Constant>(V))
+    return C->getAggregateElement(EltNo);
+
+  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
+    // If this is an insert to a variable element, we don't know what it is.
+    ConstantInt *IIIdx = dyn_cast<ConstantInt>(III->getOperand(2));
+    if (!IIIdx)
+      return nullptr;
+
+    // If this is an insert to the element we are looking for, return the
+    // inserted value. Compare as an APInt: the index operand may be wider
+    // than 64 bits (getZExtValue would assert) and truncating it to
+    // 'unsigned' could make an out-of-range index alias a valid lane.
+    if (IIIdx->getValue() == EltNo)
+      return III->getOperand(1);
+
+    // Guard against infinite recursion on invalid (self-referential) IR.
+    if (III->getOperand(0) == III)
+      return nullptr;
+
+    // Otherwise, the insertelement doesn't modify the value, recurse on its
+    // vector input.
+    return findScalarElement(III->getOperand(0), EltNo);
+  }
+
+  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
+    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
+    int InEl = SVI->getMaskValue(EltNo);
+    if (InEl < 0)  // Undef mask element selects an undef lane.
+      return UndefValue::get(VTy->getElementType());
+    if (InEl < (int)LHSWidth)
+      return findScalarElement(SVI->getOperand(0), InEl);
+    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
+  }
+
+  // Extract a value from a vector add operation with a constant zero.
+  Value *Val = nullptr;
+  Constant *Con = nullptr;
+  if (match(V,
+            llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
+                                      llvm::PatternMatch::m_Constant(Con)))) {
+    // getAggregateElement may return nullptr (e.g. when Con is a constant
+    // expression), so check it before asking whether the lane is zero.
+    if (Constant *Elt = Con->getAggregateElement(EltNo))
+      if (Elt->isNullValue())
+        return findScalarElement(Val, EltNo);
+  }
+
+  // Otherwise, we don't know.
+  return nullptr;
+}