[ARM] Enable mixed types in ARM CGP

Previously, during the search, all values had to have the same
'TypeSize', which is equal to the number of bits of the integer type
of the icmp operand. All values in the tree had to match this size,
meaning that, if we searched from i16, we wouldn't accept i8s. A
change in type size requires zexts and truncs to perform the casts,
so, to allow mixed narrow types, the handling of these instructions
is now slightly different:

- we allow casts if their result or operand is <= TypeSize.
- zexts are sinks if their result > TypeSize.
- truncs are still sinks if their operand == TypeSize.
- truncs are still sources if their result == TypeSize.

The transformation bails on finding an icmp that operates on data
smaller than the current TypeSize.

Differential Revision: https://reviews.llvm.org/D54108

llvm-svn: 346480
diff --git a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
index 0bd1f9c..06949be 100644
--- a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
+++ b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
@@ -126,7 +126,7 @@
                    SmallPtrSetImpl<Instruction*> &SafeToPromote);
   void TruncateSinks(SmallPtrSetImpl<Value*> &Sources,
                      SmallPtrSetImpl<Instruction*> &Sinks);
-  void Cleanup(SmallPtrSetImpl<Instruction*> &Sinks);
+  void Cleanup(SmallPtrSetImpl<Value*> &Visited);
 
 public:
   IRPromoter(Module *M) : M(M), Ctx(M->getContext()),
@@ -180,6 +180,18 @@
          Opc == Instruction::SRem;
 }
 
+static bool EqualTypeSize(Value *V) {
+  return V->getType()->getScalarSizeInBits() == ARMCodeGenPrepare::TypeSize;
+}
+
+static bool LessOrEqualTypeSize(Value *V) {
+  return V->getType()->getScalarSizeInBits() <= ARMCodeGenPrepare::TypeSize;
+}
+
+static bool GreaterThanTypeSize(Value *V) {
+  return V->getType()->getScalarSizeInBits() > ARMCodeGenPrepare::TypeSize;
+}
+
 /// Some instructions can use 8- and 16-bit operands, and we don't need to
 /// promote anything larger. We disallow booleans to make life easier when
 /// dealing with icmps but allow any other integer that is <= 16 bits. Void
@@ -194,11 +206,10 @@
   if (auto *Ld = dyn_cast<LoadInst>(V))
     Ty = cast<PointerType>(Ld->getPointerOperandType())->getElementType();
 
-  const IntegerType *IntTy = dyn_cast<IntegerType>(Ty);
-  if (!IntTy)
+  if (!isa<IntegerType>(Ty))
     return false;
 
-  return IntTy->getBitWidth() == ARMCodeGenPrepare::TypeSize;
+  return LessOrEqualTypeSize(V);
 }
 
 /// Return true if the given value is a source in the use-def chain, producing
@@ -221,7 +232,7 @@
   else if (auto *Call = dyn_cast<CallInst>(V))
     return Call->hasRetAttr(Attribute::AttrKind::ZExt);
   else if (auto *Trunc = dyn_cast<TruncInst>(V))
-    return isSupportedType(Trunc);
+    return EqualTypeSize(Trunc);
   return false;
 }
 
@@ -232,18 +243,15 @@
   // TODO The truncate also isn't actually necessary because we would already
   // proved that the data value is kept within the range of the original data
   // type.
-  auto UsesNarrowValue = [](Value *V) {
-    return V->getType()->getScalarSizeInBits() == ARMCodeGenPrepare::TypeSize;
-  };
 
   if (auto *Store = dyn_cast<StoreInst>(V))
-    return UsesNarrowValue(Store->getValueOperand());
+    return LessOrEqualTypeSize(Store->getValueOperand());
   if (auto *Return = dyn_cast<ReturnInst>(V))
-    return UsesNarrowValue(Return->getReturnValue());
+    return LessOrEqualTypeSize(Return->getReturnValue());
   if (auto *Trunc = dyn_cast<TruncInst>(V))
-    return UsesNarrowValue(Trunc->getOperand(0));
+    return EqualTypeSize(Trunc->getOperand(0));
   if (auto *ZExt = dyn_cast<ZExtInst>(V))
-    return UsesNarrowValue(ZExt->getOperand(0));
+    return GreaterThanTypeSize(ZExt);
   if (auto *ICmp = dyn_cast<ICmpInst>(V))
     return ICmp->isSigned();
 
@@ -649,36 +657,37 @@
       }
     }
   }
-
 }
 
-void IRPromoter::Cleanup(SmallPtrSetImpl<Instruction*> &Sinks) {
-  // Some zext sinks will now have become redundant, along with their trunc
-  // operands, so remove them.
-  for (auto I : Sinks) {
-    if (auto *ZExt = dyn_cast<ZExtInst>(I)) {
-      if (ZExt->getDestTy() != ExtTy)
-        continue;
+void IRPromoter::Cleanup(SmallPtrSetImpl<Value*> &Visited) {
+  // Some zexts will now have become redundant, along with their trunc
+  // operands, so remove them
+  for (auto V : Visited) {
+    if (!isa<ZExtInst>(V))
+      continue;
 
-      Value *Src = ZExt->getOperand(0);
-      if (ZExt->getSrcTy() == ZExt->getDestTy()) {
-        LLVM_DEBUG(dbgs() << "ARM CGP: Removing unnecessary zext\n");
-        ReplaceAllUsersOfWith(ZExt, Src);
-        InstsToRemove.push_back(ZExt);
-        continue;
-      }
+    auto ZExt = cast<ZExtInst>(V);
+    if (ZExt->getDestTy() != ExtTy)
+      continue;
 
-      // For any truncs that we insert to handle zexts, we can replace the
-      // result of the zext with the input to the trunc.
-      if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
-        auto *Trunc = cast<TruncInst>(Src);
-        assert(Trunc->getOperand(0)->getType() == ExtTy &&
-               "expected inserted trunc to be operating on i32");
-        LLVM_DEBUG(dbgs() << "ARM CGP: Replacing zext with trunc operand: "
-                   << *Trunc->getOperand(0));
-        ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
-        InstsToRemove.push_back(ZExt);
-      }
+    Value *Src = ZExt->getOperand(0);
+    if (ZExt->getSrcTy() == ZExt->getDestTy()) {
+      LLVM_DEBUG(dbgs() << "ARM CGP: Removing unnecessary cast.\n");
+      ReplaceAllUsersOfWith(ZExt, Src);
+      InstsToRemove.push_back(ZExt);
+      continue;
+    }
+
+    // For any truncs that we insert to handle zexts, we can replace the
+    // result of the zext with the input to the trunc.
+    if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
+      auto *Trunc = cast<TruncInst>(Src);
+      assert(Trunc->getOperand(0)->getType() == ExtTy &&
+             "expected inserted trunc to be operating on i32");
+      LLVM_DEBUG(dbgs() << "ARM CGP: Replacing zext with trunc operand: "
+                 << *Trunc->getOperand(0));
+      ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
+      InstsToRemove.push_back(ZExt);
     }
   }
 
@@ -728,18 +737,9 @@
 
   // Finally, remove unecessary zexts and truncs, delete old instructions and
   // clear the data structures.
-  Cleanup(Sinks);
+  Cleanup(Visited);
 
-  LLVM_DEBUG(dbgs() << "ARM CGP: Mutation complete:\n");
-  LLVM_DEBUG(dbgs();
-             for (auto *V : Sources)
-               V->dump();
-             for (auto *I : NewInsts)
-               I->dump();
-             for (auto *V : Visited) {
-               if (!Sources.count(V))
-                 V->dump();
-              });
+  LLVM_DEBUG(dbgs() << "ARM CGP: Mutation complete\n");
 }
 
 /// We accept most instructions, as well as Arguments and ConstantInsts. We
@@ -747,8 +747,15 @@
 /// return value is zeroext. We don't allow opcodes that can introduce sign
 /// bits.
 bool ARMCodeGenPrepare::isSupportedValue(Value *V) {
-  if (isa<ICmpInst>(V))
-    return true;
+  if (auto *I = dyn_cast<ICmpInst>(V)) {
+    // Now that we allow smaller types than TypeSize, only allow icmp of
+    // TypeSize because they will require a trunc to be legalised.
+    // TODO: Allow icmp of smaller types, and calculate at the end
+    // whether the transform would be beneficial.
+    if (isa<PointerType>(I->getOperand(0)->getType()))
+      return true;
+    return EqualTypeSize(I->getOperand(0));
+  }
 
   // Memory instructions
   if (isa<StoreInst>(V) || isa<GetElementPtrInst>(V))
@@ -766,12 +773,11 @@
       isa<LoadInst>(V))
     return isSupportedType(V);
 
-  // Truncs can be either sources or sinks.
-  if (auto *Trunc = dyn_cast<TruncInst>(V))
-    return isSupportedType(Trunc) || isSupportedType(Trunc->getOperand(0));
+  if (isa<SExtInst>(V))
+    return false;
 
-  if (isa<CastInst>(V) && !isa<SExtInst>(V))
-    return isSupportedType(cast<CastInst>(V)->getOperand(0));
+  if (auto *Cast = dyn_cast<CastInst>(V))
+    return isSupportedType(Cast) || isSupportedType(Cast->getOperand(0));
 
   // Special cases for calls as we need to check for zeroext
   // TODO We should accept calls even if they don't have zeroext, as they can
@@ -901,13 +907,17 @@
     // Calls can be both sources and sinks.
     if (isSink(V))
       Sinks.insert(cast<Instruction>(V));
+
     if (isSource(V))
       Sources.insert(V);
-    else if (auto *I = dyn_cast<Instruction>(V)) {
-      // Visit operands of any instruction visited.
-      for (auto &U : I->operands()) {
-        if (!AddLegalInst(U))
-          return false;
+
+    if (!isSink(V) && !isSource(V)) {
+      if (auto *I = dyn_cast<Instruction>(V)) {
+        // Visit operands of any instruction visited.
+        for (auto &U : I->operands()) {
+          if (!AddLegalInst(U))
+            return false;
+        }
       }
     }
 
@@ -973,6 +983,8 @@
         if (CI.isSigned() || !isa<IntegerType>(CI.getOperand(0)->getType()))
           continue;
 
+        LLVM_DEBUG(dbgs() << "ARM CGP: Searching from: " << CI << "\n");
+
         for (auto &Op : CI.operands()) {
           if (auto *I = dyn_cast<Instruction>(Op))
             MadeChange |= TryToPromote(I);