[ARM][ARMCGP] Remove unnecessary zexts and truncs

r345840 slightly changed the way promotion happens, which could
result in zexts and truncs having the same source and destination
types. This fixes that issue.

We can now also remove the zext and trunc in the following case:
(zext (trunc (promoted op)), i32)

This means that we can no longer treat a value that is only used by
a sink as safe to promote.

I've also added some extra asserts and replaced a cast with a
dyn_cast.

Differential Revision: https://reviews.llvm.org/D54032

llvm-svn: 346125
diff --git a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
index 0a6ea9d..8a7555b 100644
--- a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
+++ b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
@@ -114,8 +114,8 @@
   SmallPtrSet<Value*, 8> Promoted;
   Module *M = nullptr;
   LLVMContext &Ctx;
-  Type *ExtTy = nullptr;
-  Type *OrigTy = nullptr;
+  IntegerType *ExtTy = nullptr;
+  IntegerType *OrigTy = nullptr;
 
   void PrepareConstants(SmallPtrSetImpl<Value*> &Visited,
                          SmallPtrSetImpl<Instruction*> &SafeToPromote);
@@ -126,20 +126,12 @@
                    SmallPtrSetImpl<Instruction*> &SafeToPromote);
   void TruncateSinks(SmallPtrSetImpl<Value*> &Sources,
                      SmallPtrSetImpl<Instruction*> &Sinks);
+  void Cleanup(SmallPtrSetImpl<Instruction*> &Sinks);
 
 public:
   IRPromoter(Module *M) : M(M), Ctx(M->getContext()),
                           ExtTy(Type::getInt32Ty(Ctx)) { }
 
-  void Cleanup() {
-    for (auto *I : InstsToRemove) {
-      LLVM_DEBUG(dbgs() << "ARM CGP: Removing " << *I << "\n");
-      I->dropAllReferences();
-      I->eraseFromParent();
-    }
-    InstsToRemove.clear();
-    NewInsts.clear();
-  }
 
   void Mutate(Type *OrigTy,
               SmallPtrSetImpl<Value*> &Visited,
@@ -401,17 +393,7 @@
   if (generateSignBits(V))
     return false;
 
-  // If I is only being used by something that will require its value to be
-  // truncated, then we don't care about the promoted result.
-  auto *I = cast<Instruction>(V);
-  if (I->hasOneUse() && isSink(*I->use_begin())) {
-    LLVM_DEBUG(dbgs() << "ARM CGP: Only use is a sink: " << *V << "\n");
-    return true;
-  }
-
-  if (isa<OverflowingBinaryOperator>(I))
-    return false;
-  return true;
+  return !isa<OverflowingBinaryOperator>(V);
 }
 
 /// Return the intrinsic for the instruction that can perform the same
@@ -514,21 +496,24 @@
   IRBuilder<> Builder{Ctx};
 
   auto InsertZExt = [&](Value *V, Instruction *InsertPt) {
+    assert(V->getType() != ExtTy && "zext already extends to i32");
     LLVM_DEBUG(dbgs() << "ARM CGP: Inserting ZExt for " << *V << "\n");
     Builder.SetInsertPoint(InsertPt);
     if (auto *I = dyn_cast<Instruction>(V))
       Builder.SetCurrentDebugLocation(I->getDebugLoc());
-    auto *ZExt = cast<Instruction>(Builder.CreateZExt(V, ExtTy));
-    if (isa<Argument>(V))
-      ZExt->moveBefore(InsertPt);
-    else
-      ZExt->moveAfter(InsertPt);
+
+    Value *ZExt = Builder.CreateZExt(V, ExtTy);
+    if (auto *I = dyn_cast<Instruction>(ZExt)) {
+      if (isa<Argument>(V))
+        I->moveBefore(InsertPt);
+      else
+        I->moveAfter(InsertPt);
+      NewInsts.insert(I);
+    }
     ReplaceAllUsersOfWith(V, ZExt);
-    NewInsts.insert(ZExt);
     TruncTysMap[ZExt] = TruncTysMap[V];
   };
 
-
   // Now, insert extending instructions between the sources and their users.
   LLVM_DEBUG(dbgs() << "ARM CGP: Promoting sources:\n");
   for (auto V : Sources) {
@@ -664,6 +649,49 @@
       }
     }
   }
+
+}
+
+void IRPromoter::Cleanup(SmallPtrSetImpl<Instruction*> &Sinks) {
+  // Some zext sinks will now have become redundant, along with their trunc
+  // operands, so remove them.
+  for (auto I : Sinks) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(I)) {
+      if (ZExt->getDestTy() != ExtTy)
+        continue;
+
+      Value *Src = ZExt->getOperand(0);
+      if (ZExt->getSrcTy() == ZExt->getDestTy()) {
+        LLVM_DEBUG(dbgs() << "ARM CGP: Removing unnecessary zext\n");
+        ReplaceAllUsersOfWith(ZExt, Src);
+        InstsToRemove.push_back(ZExt);
+        continue;
+      }
+
+      // For any truncs that we insert to handle zexts, we can replace the
+      // result of the zext with the input to the trunc.
+      if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
+        auto *Trunc = cast<TruncInst>(Src);
+        assert(Trunc->getOperand(0)->getType() == ExtTy &&
+               "expected inserted trunc to be operating on i32");
+        LLVM_DEBUG(dbgs() << "ARM CGP: Replacing zext with trunc operand: "
+                   << *Trunc->getOperand(0));
+        ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
+        InstsToRemove.push_back(ZExt);
+      }
+    }
+  }
+
+  for (auto *I : InstsToRemove) {
+    LLVM_DEBUG(dbgs() << "ARM CGP: Removing " << *I << "\n");
+    I->dropAllReferences();
+    I->eraseFromParent();
+  }
+
+  InstsToRemove.clear();
+  NewInsts.clear();
+  TruncTysMap.clear();
+  Promoted.clear();
 }
 
 void IRPromoter::Mutate(Type *OrigTy,
@@ -673,7 +701,11 @@
                         SmallPtrSetImpl<Instruction*> &SafeToPromote) {
   LLVM_DEBUG(dbgs() << "ARM CGP: Promoting use-def chains to from "
              << ARMCodeGenPrepare::TypeSize << " to 32-bits\n");
-  this->OrigTy = OrigTy;
+
+  assert(isa<IntegerType>(OrigTy) && "expected integer type");
+  this->OrigTy = cast<IntegerType>(OrigTy);
+  assert(OrigTy->getPrimitiveSizeInBits() < ExtTy->getPrimitiveSizeInBits() &&
+         "original type not smaller than extended type");
 
   // Cache original types.
   for (auto *V : Visited)
@@ -691,9 +723,13 @@
   // promote.
   PromoteTree(Visited, Sources, Sinks, SafeToPromote);
 
-  // Finally, insert trunc instructions for use by calls, stores etc...
+  // Insert trunc instructions for use by calls, stores etc...
   TruncateSinks(Sources, Sinks);
 
+  // Finally, remove unnecessary zexts and truncs, delete old instructions and
+  // clear the data structures.
+  Cleanup(Sinks);
+
   LLVM_DEBUG(dbgs() << "ARM CGP: Mutation complete:\n");
   LLVM_DEBUG(dbgs();
              for (auto *V : Sources)
@@ -943,9 +979,8 @@
         }
       }
     }
-    Promoter->Cleanup();
     LLVM_DEBUG(if (verifyFunction(F, &dbgs())) {
-                dbgs();
+                dbgs() << F;
                 report_fatal_error("Broken function after type promotion");
                });
   }