[FileCheck] Implement * and / operators for ExpressionValue.

Subscribers: arichardson, hiraditya, thopre, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80915
diff --git a/llvm/docs/CommandGuide/FileCheck.rst b/llvm/docs/CommandGuide/FileCheck.rst
index 0f72b41..1e69c76 100644
--- a/llvm/docs/CommandGuide/FileCheck.rst
+++ b/llvm/docs/CommandGuide/FileCheck.rst
@@ -709,8 +709,10 @@
 * ``name`` is a predefined string literal. Accepted values are:
 
   * add - Returns the sum of its two operands.
+  * div - Returns the quotient of its two operands.
   * max - Returns the largest of its two operands.
   * min - Returns the smallest of its two operands.
+  * mul - Returns the product of its two operands.
   * sub - Returns the difference of its two operands.
 
 * ``<arguments>`` is a comma seperated list of expressions.
diff --git a/llvm/lib/Support/FileCheck.cpp b/llvm/lib/Support/FileCheck.cpp
index fe4d641..d0e79c6 100644
--- a/llvm/lib/Support/FileCheck.cpp
+++ b/llvm/lib/Support/FileCheck.cpp
@@ -230,6 +230,58 @@
   }
 }
 
+Expected<ExpressionValue> llvm::operator*(const ExpressionValue &LeftOperand,
+                                          const ExpressionValue &RightOperand) {
+  // -A * -B == A * B
+  if (LeftOperand.isNegative() && RightOperand.isNegative())
+    return LeftOperand.getAbsolute() * RightOperand.getAbsolute();
+
+  // A * -B == -B * A
+  if (RightOperand.isNegative())
+    return RightOperand * LeftOperand;
+
+  assert(!RightOperand.isNegative() && "Unexpected negative operand!");
+
+  // Result will be negative and can underflow.
+  if (LeftOperand.isNegative()) {
+    auto Result = LeftOperand.getAbsolute() * RightOperand.getAbsolute();
+    if (!Result)
+      return Result;
+
+    return ExpressionValue(0) - *Result;
+  }
+
+  // Result will be positive and can overflow.
+  uint64_t LeftValue = cantFail(LeftOperand.getUnsignedValue());
+  uint64_t RightValue = cantFail(RightOperand.getUnsignedValue());
+  Optional<uint64_t> Result =
+      checkedMulUnsigned<uint64_t>(LeftValue, RightValue);
+  if (!Result)
+    return make_error<OverflowError>();
+
+  return ExpressionValue(*Result);
+}
+
+Expected<ExpressionValue> llvm::operator/(const ExpressionValue &LeftOperand,
+                                          const ExpressionValue &RightOperand) {
+  // -A / -B == A / B
+  if (LeftOperand.isNegative() && RightOperand.isNegative())
+    return LeftOperand.getAbsolute() / RightOperand.getAbsolute();
+
+  // Check for divide by zero.
+  if (RightOperand == ExpressionValue(0))
+    return make_error<OverflowError>();
+
+  // Result will be negative and can underflow.
+  if (LeftOperand.isNegative() || RightOperand.isNegative())
+    return ExpressionValue(0) -
+           cantFail(LeftOperand.getAbsolute() / RightOperand.getAbsolute());
+
+  uint64_t LeftValue = cantFail(LeftOperand.getUnsignedValue());
+  uint64_t RightValue = cantFail(RightOperand.getUnsignedValue());
+  return ExpressionValue(LeftValue / RightValue);
+}
+
 Expected<ExpressionValue> llvm::max(const ExpressionValue &LeftOperand,
                                     const ExpressionValue &RightOperand) {
   if (LeftOperand.isNegative() && RightOperand.isNegative()) {
@@ -592,8 +644,10 @@
 
   auto OptFunc = StringSwitch<Optional<binop_eval_t>>(FuncName)
                      .Case("add", operator+)
+                     .Case("div", operator/)
                      .Case("max", max)
                      .Case("min", min)
+                     .Case("mul", operator*)
                      .Case("sub", operator-)
                      .Default(None);
 
diff --git a/llvm/lib/Support/FileCheckImpl.h b/llvm/lib/Support/FileCheckImpl.h
index c909316..6ca67ec 100644
--- a/llvm/lib/Support/FileCheckImpl.h
+++ b/llvm/lib/Support/FileCheckImpl.h
@@ -152,6 +152,10 @@
                                     const ExpressionValue &Rhs);
 Expected<ExpressionValue> operator-(const ExpressionValue &Lhs,
                                     const ExpressionValue &Rhs);
+Expected<ExpressionValue> operator*(const ExpressionValue &Lhs,
+                                    const ExpressionValue &Rhs);
+Expected<ExpressionValue> operator/(const ExpressionValue &Lhs,
+                                    const ExpressionValue &Rhs);
 Expected<ExpressionValue> max(const ExpressionValue &Lhs,
                               const ExpressionValue &Rhs);
 Expected<ExpressionValue> min(const ExpressionValue &Lhs,
diff --git a/llvm/unittests/Support/FileCheckTest.cpp b/llvm/unittests/Support/FileCheckTest.cpp
index 3763130..9292cec 100644
--- a/llvm/unittests/Support/FileCheckTest.cpp
+++ b/llvm/unittests/Support/FileCheckTest.cpp
@@ -334,6 +334,8 @@
 
 const int64_t MinInt64 = std::numeric_limits<int64_t>::min();
 const int64_t MaxInt64 = std::numeric_limits<int64_t>::max();
+const uint64_t AbsoluteMinInt64 = static_cast<uint64_t>(-(MinInt64 + 1)) + 1;
+const uint64_t AbsoluteMaxInt64 = static_cast<uint64_t>(MaxInt64);
 
 TEST_F(FileCheckTest, ExpressionValueGetUnsigned) {
   // Test positive value.
@@ -478,6 +480,71 @@
   expectOperationValueResult(operator-, 10, 11, -1);
 }
 
+TEST_F(FileCheckTest, ExpressionValueMultiplication) {
+  // Test mixed signed values.
+  expectOperationValueResult(operator*, -3, 10, -30);
+  expectOperationValueResult(operator*, 2, -17, -34);
+  expectOperationValueResult(operator*, 0, MinInt64, 0);
+  expectOperationValueResult(operator*, MinInt64, 1, MinInt64);
+  expectOperationValueResult(operator*, 1, MinInt64, MinInt64);
+  expectOperationValueResult(operator*, MaxInt64, -1, -MaxInt64);
+  expectOperationValueResult(operator*, -1, MaxInt64, -MaxInt64);
+
+  // Test both negative values.
+  expectOperationValueResult(operator*, -3, -10, 30);
+  expectOperationValueResult(operator*, -2, -17, 34);
+  expectOperationValueResult(operator*, MinInt64, -1, AbsoluteMinInt64);
+
+  // Test both positive values.
+  expectOperationValueResult(operator*, 3, 10, 30);
+  expectOperationValueResult(operator*, 2, 17, 34);
+  expectOperationValueResult(operator*, 0, MaxUint64, 0);
+
+  // Test negative results that underflow.
+  expectOperationValueResult(operator*, -10, MaxInt64);
+  expectOperationValueResult(operator*, MaxInt64, -10);
+  expectOperationValueResult(operator*, 10, MinInt64);
+  expectOperationValueResult(operator*, MinInt64, 10);
+  expectOperationValueResult(operator*, -1, MaxUint64);
+  expectOperationValueResult(operator*, MaxUint64, -1);
+  expectOperationValueResult(operator*, -1, AbsoluteMaxInt64 + 2);
+  expectOperationValueResult(operator*, AbsoluteMaxInt64 + 2, -1);
+
+  // Test positive results that overflow.
+  expectOperationValueResult(operator*, 10, MaxUint64);
+  expectOperationValueResult(operator*, MaxUint64, 10);
+  expectOperationValueResult(operator*, MinInt64, -10);
+  expectOperationValueResult(operator*, -10, MinInt64);
+}
+
+TEST_F(FileCheckTest, ExpressionValueDivision) {
+  // Test mixed signed values.
+  expectOperationValueResult(operator/, -30, 10, -3);
+  expectOperationValueResult(operator/, 34, -17, -2);
+  expectOperationValueResult(operator/, 0, -10, 0);
+  expectOperationValueResult(operator/, MinInt64, 1, MinInt64);
+  expectOperationValueResult(operator/, MaxInt64, -1, -MaxInt64);
+  expectOperationValueResult(operator/, -MaxInt64, 1, -MaxInt64);
+
+  // Test both negative values.
+  expectOperationValueResult(operator/, -30, -10, 3);
+  expectOperationValueResult(operator/, -34, -17, 2);
+
+  // Test both positive values.
+  expectOperationValueResult(operator/, 30, 10, 3);
+  expectOperationValueResult(operator/, 34, 17, 2);
+  expectOperationValueResult(operator/, 0, 10, 0);
+
+  // Test divide by zero.
+  expectOperationValueResult(operator/, -10, 0);
+  expectOperationValueResult(operator/, 10, 0);
+  expectOperationValueResult(operator/, 0, 0);
+
+  // Test negative result that underflows.
+  expectOperationValueResult(operator/, MaxUint64, -1);
+  expectOperationValueResult(operator/, AbsoluteMaxInt64 + 2, -1);
+}
+
 TEST_F(FileCheckTest, ExpressionValueEquality) {
   // Test negative and positive value.
   EXPECT_FALSE(ExpressionValue(5) == ExpressionValue(-3));
@@ -1308,8 +1375,8 @@
   ASSERT_FALSE(Tester.parsePattern("[[#add(NUMVAR,13)]]"));
   EXPECT_THAT_EXPECTED(Tester.match("31"), Succeeded());
   Tester.initNextPattern();
-  ASSERT_FALSE(Tester.parsePattern("[[#sub(NUMVAR,7)]]"));
-  EXPECT_THAT_EXPECTED(Tester.match("11"), Succeeded());
+  ASSERT_FALSE(Tester.parsePattern("[[#div(NUMVAR,3)]]"));
+  EXPECT_THAT_EXPECTED(Tester.match("6"), Succeeded());
   Tester.initNextPattern();
   ASSERT_FALSE(Tester.parsePattern("[[#max(NUMVAR,5)]]"));
   EXPECT_THAT_EXPECTED(Tester.match("18"), Succeeded());
@@ -1322,6 +1389,12 @@
   Tester.initNextPattern();
   ASSERT_FALSE(Tester.parsePattern("[[#min(NUMVAR,99)]]"));
   EXPECT_THAT_EXPECTED(Tester.match("18"), Succeeded());
+  Tester.initNextPattern();
+  ASSERT_FALSE(Tester.parsePattern("[[#mul(NUMVAR,3)]]"));
+  EXPECT_THAT_EXPECTED(Tester.match("54"), Succeeded());
+  Tester.initNextPattern();
+  ASSERT_FALSE(Tester.parsePattern("[[#sub(NUMVAR,7)]]"));
+  EXPECT_THAT_EXPECTED(Tester.match("11"), Succeeded());
 
   // Check nested function calls.
   Tester.initNextPattern();