Handle subtraction in the expression classifier


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@548 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/Expressions.cpp b/lib/Analysis/Expressions.cpp
index e41182a..2235cda 100644
--- a/lib/Analysis/Expressions.cpp
+++ b/lib/Analysis/Expressions.cpp
@@ -15,6 +15,37 @@
 using namespace opt;  // Get all the constant handling stuff
 using namespace analysis;
 
+// ExprType(Val) - Integral constant => Constant; other non-null => Linear Val.
+ExprType::ExprType(Value *Val) {
+  if (Val && Val->isConstant() && Val->getType()->isIntegral()) {
+    Offset = (ConstPoolInt*)Val->castConstant();
+    Var = 0;
+    ExprTy = Constant;
+  } else {
+    Var = Val; Offset = 0;
+    ExprTy = Var ? Linear : Constant;  // A null Val is the constant 0
+  }
+  Scale = 0;                           // Null Scale stands for a scale of 1
+}
+// ExprType - Build Scale*Var + Offset (a null Scale means 1, null Offset 0).
+ExprType::ExprType(const ConstPoolInt *scale, Value *var, 
+		   const ConstPoolInt *offset) {
+  Scale = var ? scale : 0;  // No Var (e.g. negate() of a constant): drop Scale
+  Var = var; Offset = offset;
+  ExprTy = Scale ? ScaledLinear : (Var ? Linear : Constant);
+  if (Scale && Scale->equalsInt(0)) {  // Simplify 0*Var + const
+    Scale = 0; Var = 0; ExprTy = Constant;
+  }
+}
+
+// getExprType - Return the type of this expression: that of Offset, else of
+// Scale, else of Var; Default when all three are null.
+const Type *ExprType::getExprType(const Type *Default) const {
+  if (Offset) return Offset->getType();
+  if (Scale) return Scale->getType();
+  return Var ? Var->getType() : Default;
+}
+
 class DefVal {
   const ConstPoolInt * const Val;
   const Type * const Ty;
@@ -37,18 +68,8 @@
 };
 
 
-// getIntegralConstant - Wrapper around the ConstPoolInt member of the same
-// name.  This method first checks to see if the desired constant is already in
-// the constant pool.  If it is, it is quickly recycled, otherwise a new one
-// is allocated and added to the constant pool.
-//
-static ConstPoolInt *getIntegralConstant(unsigned char V, const Type *Ty) {
-  return ConstPoolInt::get(Ty, V);
-}
-
 static ConstPoolInt *getUnsignedConstant(uint64_t V, const Type *Ty) {
   if (Ty->isPointerType()) Ty = Type::ULongTy;
-
   return Ty->isSigned() ? ConstPoolSInt::get(Ty, V) : ConstPoolUInt::get(Ty, V);
 }
 
@@ -77,11 +98,8 @@
 
   // Check to see if the result is one of the special cases that we want to
   // recognize...
-  if (ResultI->equalsInt(DefOne ? 1 : 0)) {
-    // Yes it is, simply delete the constant and return null.
-    delete ResultI;
-    return 0;
-  }
+  if (ResultI->equalsInt(DefOne ? 1 : 0))
+    return 0;  // Yes it is, simply return null.
 
   return ResultI;
 }
@@ -95,11 +113,11 @@
 inline const ConstPoolInt *operator+(const DefOne &L, const DefOne &R) {
   if (L == 0) {
     if (R == 0)
-      return getIntegralConstant(2, L.getType());
+      return getUnsignedConstant(2, L.getType());       // 1 + 1
     else
-      return Add(getIntegralConstant(1, L.getType()), R, true);
+      return Add(getUnsignedConstant(1, L.getType()), R, true);  // 1 + R
   } else if (R == 0) {
-    return Add(L, getIntegralConstant(1, L.getType()), true);
+    return Add(L, getUnsignedConstant(1, L.getType()), true);    // L + 1
   }
   return Add(L, R, true);
 }
@@ -130,11 +148,8 @@
 
   // Check to see if the result is one of the special cases that we want to
   // recognize...
-  if (ResultI->equalsInt(DefOne ? 1 : 0)) {
-    // Yes it is, simply delete the constant and return null.
-    delete ResultI;
-    return 0;
-  }
+  if (ResultI->equalsInt(DefOne ? 1 : 0))
+    return 0; // Yes it is, simply return null.
 
   return ResultI;
 }
@@ -144,7 +159,7 @@
   return Mul(L, R, false);
 }
 inline const ConstPoolInt *operator*(const DefOne &L, const DefZero &R) {
-  if (R == 0) return getIntegralConstant(0, L.getType());
+  if (R == 0) return getUnsignedConstant(0, L.getType());  // L * 0 == 0
   if (L == 0) return R->equalsInt(1) ? 0 : R.getVal();
   return Mul(L, R, false);
 }
@@ -152,6 +167,45 @@
   return R*L;
 }
 
+// handleAddition - Return an expression for Left + Right.  V is the value being
+// classified: it supplies the type used for the constant arithmetic, and
+// ExprType(V) is the result when the two sides use different variables.
+static ExprType handleAddition(ExprType Left, ExprType Right, Value *V) {
+  const Type *Ty = V->getType();
+  if (Left.ExprTy > Right.ExprTy)
+    swap(Left, Right);   // Make left be simpler than right
+
+  switch (Left.ExprTy) {
+  case ExprType::Constant:            // C + (S*X + O)  ==>  S*X + (O + C)
+    return ExprType(Right.Scale, Right.Var,
+		    DefZero(Right.Offset, Ty) + DefZero(Left.Offset, Ty));
+  case ExprType::Linear:              // RHS side must be linear or scaled
+  case ExprType::ScaledLinear:        // RHS must be scaled
+    if (Left.Var != Right.Var)        // Are they the same variables?
+      return ExprType(V);             //   if not, we don't know anything!
+
+    return ExprType(DefOne(Left.Scale  , Ty) + DefOne(Right.Scale , Ty),
+		    Left.Var,                // (S1 + S2)*X + (O1 + O2)
+		    DefZero(Left.Offset, Ty) + DefZero(Right.Offset, Ty));
+  default:
+    assert(0 && "Dont' know how to handle this case!");
+    return ExprType();
+  }
+}
+
+// negate - Return -E (Scale and Offset times -1, built as 0-1 in E's type).
+// NOTE(review): if 0-1 fails, returns ExprType(V); callers can't tell it from -E.
+static inline ExprType negate(const ExprType &E, Value *V) {
+  const Type *Ty = V->getType();
+  const Type *ETy = E.getExprType(Ty);  // E's type, or V's if E has none
+  ConstPoolInt *Zero   = getUnsignedConstant(0, ETy);
+  ConstPoolInt *One    = getUnsignedConstant(1, ETy);
+  ConstPoolInt *NegOne = (ConstPoolInt*)(*Zero - *One);
+  if (NegOne == 0) return V;  // Couldn't subtract values...
+
+  return ExprType(DefOne (E.Scale , Ty) * NegOne, E.Var,   // (-S)*Var
+		  DefZero(E.Offset, Ty) * NegOne);              //   + (-O)
+}
 
 
 // ClassifyExpression: Analyze an expression to determine the complexity of the
@@ -174,7 +228,7 @@
     ConstPoolVal *CPV = Expr->castConstantAsserting();
     if (CPV->getType()->isIntegral()) { // It's an integral constant!
       ConstPoolInt *CPI = (ConstPoolInt*)Expr;
-      return ExprType(CPI->equalsInt(0) ? 0 : (ConstPoolInt*)Expr);
+      return ExprType(CPI->equalsInt(0) ? 0 : CPI);
     }
     return Expr;
   }
@@ -186,24 +240,15 @@
   case Instruction::Add: {
     ExprType Left (ClassifyExpression(I->getOperand(0)));
     ExprType Right(ClassifyExpression(I->getOperand(1)));
-    if (Left.ExprTy > Right.ExprTy)
-      swap(Left, Right);   // Make left be simpler than right
-
-    switch (Left.ExprTy) {
-    case ExprType::Constant:
-      return ExprType(Right.Scale, Right.Var,
-		      DefZero(Right.Offset, Ty) + DefZero(Left.Offset, Ty));
-    case ExprType::Linear:        // RHS side must be linear or scaled
-    case ExprType::ScaledLinear:  // RHS must be scaled
-      if (Left.Var != Right.Var)        // Are they the same variables?
-	return ExprType(I);       //   if not, we don't know anything!
-
-      return ExprType( DefOne(Left.Scale , Ty) +  DefOne(Right.Scale , Ty),
-		      Left.Var,
-	              DefZero(Left.Offset, Ty) + DefZero(Right.Offset, Ty));
-    }
+    return handleAddition(Left, Right, I);
   }  // end case Instruction::Add
 
+  case Instruction::Sub: {
+    ExprType Left (ClassifyExpression(I->getOperand(0)));
+    ExprType Right(ClassifyExpression(I->getOperand(1)));
+    return handleAddition(Left, negate(Right, I), I);
+  }  // end case Instruction::Sub
+
   case Instruction::Shl: { 
     ExprType Right(ClassifyExpression(I->getOperand(1)));
     if (Right.ExprTy != ExprType::Constant) break;
@@ -240,12 +285,13 @@
     const ConstPoolInt *Offs = Src.Offset;
     if (Offs == 0) return ExprType();
 
-    if (I->getType()->isPointerType())
-      return Offs;  // Pointer types do not lose precision
+    const Type *DestTy = I->getType();
+    if (DestTy->isPointerType())
+      DestTy = Type::ULongTy;  // Pointer types are represented as ulong
 
-    assert(I->getType()->isIntegral() && "Can only handle integral types!");
+    assert(DestTy->isIntegral() && "Can only handle integral types!");
 
-    const ConstPoolVal *CPV =ConstRules::get(*Offs)->castTo(Offs, I->getType());
+    const ConstPoolVal *CPV =ConstRules::get(*Offs)->castTo(Offs, DestTy);
     if (!CPV) return I;
     assert(CPV->getType()->isIntegral() && "Must have an integral type!");
     return (ConstPoolInt*)CPV;