[InstCombine] fold urem with sext bool divisor

Similar to other patches in this series:
https://reviews.llvm.org/rL335512
https://reviews.llvm.org/rL335527
https://reviews.llvm.org/rL335597
https://reviews.llvm.org/rL335616

...this is filling a gap in analysis that is exposed by an unrelated select-of-constants transform.
I didn't see a way to unify the sext cases because each div/rem opcode results in a different fold.

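Note that in this case, the backend might want to convert the select into math.
Alive-style proof (if %x is false the divisor is 0, which is undefined behavior for urem, so only the %e == -1 case matters):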
Name: sext urem
%e = sext i1 %x to i32
%r = urem i32 %y, %e
=>
%c = icmp eq i32 %y, -1
%z = zext i1 %c to i32
%r = add i32 %z, %y

llvm-svn: 335622
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 7a8f2c2..9833bbc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1295,8 +1295,9 @@
 
   // X urem Y -> X and Y-1, where Y is a power of 2,
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Type *Ty = I.getType();
   if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
-    Constant *N1 = Constant::getAllOnesValue(I.getType());
+    Constant *N1 = Constant::getAllOnesValue(Ty);
     Value *Add = Builder.CreateAdd(Op1, N1);
     return BinaryOperator::CreateAnd(Op0, Add);
   }
@@ -1304,7 +1305,7 @@
   // 1 urem X -> zext(X != 1)
   if (match(Op0, m_One())) {
     Value *Cmp = Builder.CreateICmpNE(Op1, Op0);
-    Value *Ext = Builder.CreateZExt(Cmp, I.getType());
+    Value *Ext = Builder.CreateZExt(Cmp, Ty);
     return replaceInstUsesWith(I, Ext);
   }
 
@@ -1315,6 +1316,16 @@
     return SelectInst::Create(Cmp, Op0, Sub);
   }
 
+  // If the divisor is a sext of a boolean, it is either 0 or -1 (max unsigned
+  // value). Division by 0 is undefined, so we can assume -1. The remainder is
+  // then Op0, unless Op0 is also -1, in which case the remainder is 0:
+  // urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0
+  Value *X;
+  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+    Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
+    return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
+  }
+
   return nullptr;
 }