[ARM][ARMCGP] Remove unnecessary zexts and truncs

r345840 slightly changed the way promotion happens, which could
result in zexts and truncs having the same source and destination
types. This fixes that issue.

We can now also remove the zext and trunc in the following case:
(zext (trunc (promoted op)), i32)

This means that we can no longer treat a value that is only used by
a sink as safe to promote.

I've also added some extra asserts and replaced a cast with a
dyn_cast.

Differential Revision: https://reviews.llvm.org/D54032

llvm-svn: 346125
diff --git a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
index 0a6ea9d..8a7555b 100644
--- a/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
+++ b/llvm/lib/Target/ARM/ARMCodeGenPrepare.cpp
@@ -114,8 +114,8 @@
   SmallPtrSet<Value*, 8> Promoted;
   Module *M = nullptr;
   LLVMContext &Ctx;
-  Type *ExtTy = nullptr;
-  Type *OrigTy = nullptr;
+  IntegerType *ExtTy = nullptr;
+  IntegerType *OrigTy = nullptr;
 
   void PrepareConstants(SmallPtrSetImpl<Value*> &Visited,
                          SmallPtrSetImpl<Instruction*> &SafeToPromote);
@@ -126,20 +126,12 @@
                    SmallPtrSetImpl<Instruction*> &SafeToPromote);
   void TruncateSinks(SmallPtrSetImpl<Value*> &Sources,
                      SmallPtrSetImpl<Instruction*> &Sinks);
+  void Cleanup(SmallPtrSetImpl<Instruction*> &Sinks);
 
 public:
   IRPromoter(Module *M) : M(M), Ctx(M->getContext()),
                           ExtTy(Type::getInt32Ty(Ctx)) { }
 
-  void Cleanup() {
-    for (auto *I : InstsToRemove) {
-      LLVM_DEBUG(dbgs() << "ARM CGP: Removing " << *I << "\n");
-      I->dropAllReferences();
-      I->eraseFromParent();
-    }
-    InstsToRemove.clear();
-    NewInsts.clear();
-  }
 
   void Mutate(Type *OrigTy,
               SmallPtrSetImpl<Value*> &Visited,
@@ -401,17 +393,7 @@
   if (generateSignBits(V))
     return false;
 
-  // If I is only being used by something that will require its value to be
-  // truncated, then we don't care about the promoted result.
-  auto *I = cast<Instruction>(V);
-  if (I->hasOneUse() && isSink(*I->use_begin())) {
-    LLVM_DEBUG(dbgs() << "ARM CGP: Only use is a sink: " << *V << "\n");
-    return true;
-  }
-
-  if (isa<OverflowingBinaryOperator>(I))
-    return false;
-  return true;
+  return !isa<OverflowingBinaryOperator>(V);
 }
 
 /// Return the intrinsic for the instruction that can perform the same
@@ -514,21 +496,24 @@
   IRBuilder<> Builder{Ctx};
 
   auto InsertZExt = [&](Value *V, Instruction *InsertPt) {
+    assert(V->getType() != ExtTy && "zext already extends to i32");
     LLVM_DEBUG(dbgs() << "ARM CGP: Inserting ZExt for " << *V << "\n");
     Builder.SetInsertPoint(InsertPt);
     if (auto *I = dyn_cast<Instruction>(V))
       Builder.SetCurrentDebugLocation(I->getDebugLoc());
-    auto *ZExt = cast<Instruction>(Builder.CreateZExt(V, ExtTy));
-    if (isa<Argument>(V))
-      ZExt->moveBefore(InsertPt);
-    else
-      ZExt->moveAfter(InsertPt);
+
+    Value *ZExt = Builder.CreateZExt(V, ExtTy);
+    if (auto *I = dyn_cast<Instruction>(ZExt)) {
+      if (isa<Argument>(V))
+        I->moveBefore(InsertPt);
+      else
+        I->moveAfter(InsertPt);
+      NewInsts.insert(I);
+    }
     ReplaceAllUsersOfWith(V, ZExt);
-    NewInsts.insert(ZExt);
     TruncTysMap[ZExt] = TruncTysMap[V];
   };
 
-
   // Now, insert extending instructions between the sources and their users.
   LLVM_DEBUG(dbgs() << "ARM CGP: Promoting sources:\n");
   for (auto V : Sources) {
@@ -664,6 +649,49 @@
       }
     }
   }
+
+}
+
+void IRPromoter::Cleanup(SmallPtrSetImpl<Instruction*> &Sinks) {
+  // Zext sinks to i32 whose source is already i32, or is a trunc we inserted
+  // from i32, are now redundant: bypass them (the truncs are not erased here).
+  for (auto *I : Sinks) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(I)) {
+      if (ZExt->getDestTy() != ExtTy)
+        continue;
+
+      Value *Src = ZExt->getOperand(0);
+      if (ZExt->getSrcTy() == ZExt->getDestTy()) {
+        LLVM_DEBUG(dbgs() << "ARM CGP: Removing unnecessary zext\n");
+        ReplaceAllUsersOfWith(ZExt, Src);
+        InstsToRemove.push_back(ZExt);
+        continue;
+      }
+
+      // For any truncs that we insert to handle zexts, we can replace the
+      // result of the zext with the input to the trunc.
+      if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
+        auto *Trunc = cast<TruncInst>(Src);
+        assert(Trunc->getOperand(0)->getType() == ExtTy &&
+               "expected inserted trunc to be operating on i32");
+        LLVM_DEBUG(dbgs() << "ARM CGP: Replacing zext with trunc operand: "
+                   << *Trunc->getOperand(0) << "\n");
+        ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
+        InstsToRemove.push_back(ZExt);
+      }
+    }
+  }
+
+  for (auto *I : InstsToRemove) {
+    LLVM_DEBUG(dbgs() << "ARM CGP: Removing " << *I << "\n");
+    I->dropAllReferences();
+    I->eraseFromParent();
+  }
+
+  InstsToRemove.clear();
+  NewInsts.clear();
+  TruncTysMap.clear();
+  Promoted.clear();
 }
 
 void IRPromoter::Mutate(Type *OrigTy,
@@ -673,7 +701,11 @@
                         SmallPtrSetImpl<Instruction*> &SafeToPromote) {
   LLVM_DEBUG(dbgs() << "ARM CGP: Promoting use-def chains to from "
              << ARMCodeGenPrepare::TypeSize << " to 32-bits\n");
-  this->OrigTy = OrigTy;
+
+  assert(isa<IntegerType>(OrigTy) && "expected integer type");
+  this->OrigTy = cast<IntegerType>(OrigTy);
+  assert(OrigTy->getPrimitiveSizeInBits() < ExtTy->getPrimitiveSizeInBits() &&
+         "original type not smaller than extended type");
 
   // Cache original types.
   for (auto *V : Visited)
@@ -691,9 +723,13 @@
   // promote.
   PromoteTree(Visited, Sources, Sinks, SafeToPromote);
 
-  // Finally, insert trunc instructions for use by calls, stores etc...
+  // Insert trunc instructions for use by calls, stores etc...
   TruncateSinks(Sources, Sinks);
 
+  // Finally, remove unnecessary zexts and truncs, delete old instructions and
+  // clear the data structures.
+  Cleanup(Sinks);
+
   LLVM_DEBUG(dbgs() << "ARM CGP: Mutation complete:\n");
   LLVM_DEBUG(dbgs();
              for (auto *V : Sources)
@@ -943,9 +979,8 @@
         }
       }
     }
-    Promoter->Cleanup();
     LLVM_DEBUG(if (verifyFunction(F, &dbgs())) {
-                dbgs();
+                dbgs() << F;
                 report_fatal_error("Broken function after type promotion");
                });
   }
diff --git a/llvm/test/CodeGen/ARM/CGP/arm-cgp-calls.ll b/llvm/test/CodeGen/ARM/CGP/arm-cgp-calls.ll
index 5972980..244c6bd 100644
--- a/llvm/test/CodeGen/ARM/CGP/arm-cgp-calls.ll
+++ b/llvm/test/CodeGen/ARM/CGP/arm-cgp-calls.ll
@@ -144,10 +144,9 @@
   br label %for.cond.backedge
 }
 
-; Transform will bail because of the zext
 ; Check that d.sroa.0.0.be is promoted passed directly into the tail call.
 ; CHECK-LABEL: check_zext_phi_call_arg
-; CHECK: uxt
+; CHECK-NOT: uxt
 define i32 @check_zext_phi_call_arg() {
 entry:
   br label %for.cond
diff --git a/llvm/test/CodeGen/ARM/CGP/arm-cgp-casts.ll b/llvm/test/CodeGen/ARM/CGP/arm-cgp-casts.ll
index 23467c9..43184648 100644
--- a/llvm/test/CodeGen/ARM/CGP/arm-cgp-casts.ll
+++ b/llvm/test/CodeGen/ARM/CGP/arm-cgp-casts.ll
@@ -232,9 +232,10 @@
 ; promote %1 for the call - unless we can generate a uadd16.
 ; CHECK-COMMON-LABEL: zext_load_sink_call:
 ; CHECK-COMMON: uxt
-; uadd16
-; cmp
-; CHECK-COMMON: uxt
+; CHECK-DSP-IMM: uadd16
+; CHECK-COMMON: cmp
+; CHECK-NODSP: uxt
+; CHECK-DSP-IMM-NOT: uxt
 define i32 @zext_load_sink_call(i16* %ptr, i16 %exp) {
 entry:
   %0 = load i16, i16* %ptr, align 4
@@ -338,3 +339,27 @@
 @d_uch = hidden local_unnamed_addr global [16 x i8] zeroinitializer, align 1
 @sh1 = hidden local_unnamed_addr global i16 0, align 2
 @d_sh = hidden local_unnamed_addr global [16 x i16] zeroinitializer, align 2
+
+; CHECK-LABEL: two_stage_zext_trunc_mix
+; CHECK-NOT: uxt
+define i8* @two_stage_zext_trunc_mix(i32* %this, i32 %__pos1, i32 %__n1, i32** %__str, i32 %__pos2, i32 %__n2) {
+entry:
+  %__size_.i.i.i.i = bitcast i32** %__str to i8*
+  %0 = load i8, i8* %__size_.i.i.i.i, align 4
+  %1 = and i8 %0, 1
+  %tobool.i.i.i.i = icmp eq i8 %1, 0
+  %__size_.i5.i.i = getelementptr inbounds i32*, i32** %__str, i32 %__n1
+  %cast = bitcast i32** %__size_.i5.i.i to i32*
+  %2 = load i32, i32* %cast, align 4
+  %3 = lshr i8 %0, 1
+  %4 = zext i8 %3 to i32
+  %cond.i.i = select i1 %tobool.i.i.i.i, i32 %4, i32 %2
+  %__size_.i.i.i.i.i = bitcast i32* %this to i8*
+  %5 = load i8, i8* %__size_.i.i.i.i.i, align 4
+  %6 = and i8 %5, 1
+  %tobool.i.i.i.i.i = icmp eq i8 %6, 0
+  %7 = getelementptr inbounds i8, i8* %__size_.i.i.i.i, i32 %__pos1
+  %8 = getelementptr inbounds i8, i8* %__size_.i.i.i.i, i32 %__pos2
+  %res = select i1 %tobool.i.i.i.i.i,  i8* %7, i8* %8
+  ret i8* %res
+}