move the [Can]EvaluateInDifferentType functions out to InstCombineCasts.cpp


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92469 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index c5ad10f..b9ecaea 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -17,7 +17,215 @@
 using namespace llvm;
 using namespace PatternMatch;
 
-// FIXME: InstCombiner::EvaluateInDifferentType!
+/// CanEvaluateInDifferentType - Return true if we can take the specified value
+/// and return it as type Ty without inserting any new casts and without
+/// changing the computed value.  This is used by code that tries to decide
+/// whether promoting or shrinking integer operations to wider or smaller types
+/// will allow us to eliminate a truncate or extend.
+///
+/// This is a truncation operation if Ty is smaller than V->getType(), or an
+/// extension operation if Ty is larger.
+///
+/// If CastOpc is a truncation, then Ty will be a type smaller than V.  We
+/// should return true if trunc(V) can be computed by computing V in the smaller
+/// type.  If V is an instruction, then trunc(inst(x,y)) can be computed as
+/// inst(trunc(x),trunc(y)), which only makes sense if x and y can be
+/// efficiently truncated.
+///
+/// If CastOpc is a sext or zext, we are asking if the low bits of the value can
+/// be computed in a larger type, which is then and'd or sext_in_reg'd to get
+/// the final result.
+bool InstCombiner::CanEvaluateInDifferentType(Value *V, const Type *Ty,
+                                              unsigned CastOpc,
+                                              int &NumCastsRemoved){
+  // We can always evaluate constants in another type.
+  if (isa<Constant>(V))
+    return true;
+  
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return false;  // Neither a constant nor an instruction.
+  
+  const Type *OrigTy = V->getType();
+  
+  // If this is an extension or truncate, we can often eliminate it.
+  if (isa<TruncInst>(I) || isa<ZExtInst>(I) || isa<SExtInst>(I)) {
+    // If this is a cast from the destination type, we can trivially eliminate
+    // it, and this will remove a cast overall.
+    if (I->getOperand(0)->getType() == Ty) {
+      // Only count this cast as removed if it has one use and its operand
+      // is not itself a cast.  If the operand is a cast, we would prefer to
+      // eliminate those two casts first.
+      if (!isa<CastInst>(I->getOperand(0)) && I->hasOneUse())
+        ++NumCastsRemoved;
+      return true;
+    }
+  }
+
+  // We can't extend or shrink something that has multiple uses: doing so would
+  // require duplicating the instruction in general, which isn't profitable.
+  if (!I->hasOneUse()) return false;
+
+  unsigned Opc = I->getOpcode();
+  switch (Opc) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    // These operators can all arbitrarily be extended or truncated.
+    return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
+                                      NumCastsRemoved) &&
+           CanEvaluateInDifferentType(I->getOperand(1), Ty, CastOpc,
+                                      NumCastsRemoved);
+
+  case Instruction::UDiv:
+  case Instruction::URem: {
+    // UDiv and URem can be truncated if all the truncated bits are zero.
+    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    if (BitWidth < OrigBitWidth) {
+      APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth);
+      if (MaskedValueIsZero(I->getOperand(0), Mask) &&
+          MaskedValueIsZero(I->getOperand(1), Mask)) {
+        return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
+                                          NumCastsRemoved) &&
+               CanEvaluateInDifferentType(I->getOperand(1), Ty, CastOpc,
+                                          NumCastsRemoved);
+      }
+    }
+    break;
+  }
+  case Instruction::Shl:
+    // If we are truncating the result of this SHL, and it shifts by a constant
+    // amount less than the new width, we can perform the SHL in a smaller type.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      uint32_t BitWidth = Ty->getScalarSizeInBits();
+      if (BitWidth < OrigTy->getScalarSizeInBits() &&
+          CI->getLimitedValue(BitWidth) < BitWidth)
+        return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
+                                          NumCastsRemoved);
+    }
+    break;
+  case Instruction::LShr:
+    // If this is a truncate of a logical shr, we can truncate it to a smaller
+    // lshr iff we know that the bits we would otherwise be shifting in are
+    // already zeros.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+      uint32_t BitWidth = Ty->getScalarSizeInBits();
+      if (BitWidth < OrigBitWidth &&
+          MaskedValueIsZero(I->getOperand(0),
+            APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) &&
+          CI->getLimitedValue(BitWidth) < BitWidth) {
+        return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
+                                          NumCastsRemoved);
+      }
+    }
+    break;
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::Trunc:
+    // If this is the same kind of cast as our original (e.g. zext+zext), we
+    // can safely replace it.  Note that replacing it does not reduce the number
+    // of casts in the input.
+    if (Opc == CastOpc)
+      return true;
+
+    // sext (zext ty1), ty2 -> zext ty2
+    if (CastOpc == Instruction::SExt && Opc == Instruction::ZExt)
+      return true;
+    break;
+  case Instruction::Select: {  // Both arms must be convertible.
+    SelectInst *SI = cast<SelectInst>(I);
+    return CanEvaluateInDifferentType(SI->getTrueValue(), Ty, CastOpc,
+                                      NumCastsRemoved) &&
+           CanEvaluateInDifferentType(SI->getFalseValue(), Ty, CastOpc,
+                                      NumCastsRemoved);
+  }
+  case Instruction::PHI: {
+    // We can change a phi if we can change all operands.
+    PHINode *PN = cast<PHINode>(I);
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+      if (!CanEvaluateInDifferentType(PN->getIncomingValue(i), Ty, CastOpc,
+                                      NumCastsRemoved))
+        return false;
+    return true;
+  }
+  default:
+    // TODO: Can handle more cases here.
+    break;
+  }
+  
+  return false;
+}
+
+/// EvaluateInDifferentType - Given an expression that
+/// CanEvaluateInDifferentType returns true for, insert the code to evaluate it
+/// in type Ty.  isSigned picks sext vs. zext when casting constants.
+Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, 
+                                             bool isSigned) {
+  if (Constant *C = dyn_cast<Constant>(V))
+    return ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/);
+
+  // Otherwise, it must be an instruction.
+  Instruction *I = cast<Instruction>(V);
+  Instruction *Res = 0;
+  unsigned Opc = I->getOpcode();
+  switch (Opc) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::AShr:  // NOTE(review): the Can* check has no AShr case.
+  case Instruction::LShr:
+  case Instruction::Shl:
+  case Instruction::UDiv:
+  case Instruction::URem: {
+    Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
+    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+    Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
+    break;
+  }    
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+    // If the source type of the cast is the type we're trying for then we can
+    // just return the source.  There's no need to insert it because it is not
+    // new.
+    if (I->getOperand(0)->getType() == Ty)
+      return I->getOperand(0);
+    
+    // Otherwise, must be the same type of cast, so just reinsert a new one.
+    Res = CastInst::Create(cast<CastInst>(I)->getOpcode(), I->getOperand(0),Ty);
+    break;
+  case Instruction::Select: {  // Convert both arms; keep the condition.
+    Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+    Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
+    Res = SelectInst::Create(I->getOperand(0), True, False);
+    break;
+  }
+  case Instruction::PHI: {
+    PHINode *OPN = cast<PHINode>(I);
+    PHINode *NPN = PHINode::Create(Ty);
+    for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
+      Value *V =EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
+      NPN->addIncoming(V, OPN->getIncomingBlock(i));
+    }
+    Res = NPN;
+    break;
+  }
+  default: 
+    // TODO: Can handle more cases here.
+    llvm_unreachable("Unreachable!");
+    break;
+  }
+  
+  Res->takeName(I);  // The replacement takes over I's name.
+  return InsertNewInstBefore(Res, *I);
+}
 
 
 /// This function is a wrapper around CastInst::isEliminableCastPair. It