Add robust math to constant folding.

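Previously, constant folding of addition, subtraction and multiplication
could overflow (signed integer overflow is undefined behavior in C++),
which can lead to security bugs. These operations now use
base::CheckedNumeric and report an error at the left operand's source
location when the result is out of range. Matrix products are not yet
checked; TODOs mark those paths.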

BUG=chromium:637050

Change-Id: Icee22a87909e205b71bda1c5bc1627fcf5e26e90
Reviewed-on: https://chromium-review.googlesource.com/382678
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ConstantUnion.cpp b/src/compiler/translator/ConstantUnion.cpp
index 042f610..74bf570 100644
--- a/src/compiler/translator/ConstantUnion.cpp
+++ b/src/compiler/translator/ConstantUnion.cpp
@@ -7,8 +7,70 @@
 
 #include "compiler/translator/ConstantUnion.h"
 
+#include "base/numerics/safe_math.h"
 #include "compiler/translator/Diagnostics.h"
 
+namespace
+{
+
+// Returns the value held by |result|. If the checked computation is not valid
+// (e.g. it went out of the range of T), reports |reason| at |line| with |token|
+// as the offending token and returns 0 as the folded value instead.
+template <typename T>
+T ValueOrReportError(const base::CheckedNumeric<T> &result,
+                     const char *reason,
+                     const char *token,
+                     TDiagnostics *diag,
+                     const TSourceLoc &line)
+{
+    if (!result.IsValid())
+    {
+        diag->error(line, reason, token, "");
+        return 0;
+    }
+    return result.ValueOrDefault(0);
+}
+
+// Checked constant-folding helpers for the binary +, - and * operators. Each
+// returns the exact result, or reports "<Operation> out of range" at |line|
+// (with the operator as the token) and returns 0.
+// NOTE(review): the operands are deliberately not ASSERTed to be valid. Float
+// operands come from raw fConst values that nothing here guarantees to be
+// finite, and CheckedNumeric<float> presumably treats inf/NaN as invalid; such
+// an operand should simply make the result invalid and get reported like an
+// overflow - confirm against base/numerics/safe_math.h.
+// TODO(review): ESSL 3.00 specifies that integer +, - and * overflow wraps, so
+// reporting an error here may reject shaders the spec accepts.
+
+template <typename T>
+T CheckedSum(base::CheckedNumeric<T> lhs,
+             base::CheckedNumeric<T> rhs,
+             TDiagnostics *diag,
+             const TSourceLoc &line)
+{
+    return ValueOrReportError<T>(lhs + rhs, "Addition out of range", "+", diag, line);
+}
+
+template <typename T>
+T CheckedDiff(base::CheckedNumeric<T> lhs,
+              base::CheckedNumeric<T> rhs,
+              TDiagnostics *diag,
+              const TSourceLoc &line)
+{
+    return ValueOrReportError<T>(lhs - rhs, "Difference out of range", "-", diag, line);
+}
+
+template <typename T>
+T CheckedMul(base::CheckedNumeric<T> lhs,
+             base::CheckedNumeric<T> rhs,
+             TDiagnostics *diag,
+             const TSourceLoc &line)
+{
+    return ValueOrReportError<T>(lhs * rhs, "Multiplication out of range", "*", diag, line);
+}
+
+}  // anonymous namespace
+
 TConstantUnion::TConstantUnion()
 {
     iConst = 0;
@@ -221,20 +275,21 @@
 // static
 TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
                                    const TConstantUnion &rhs,
-                                   TDiagnostics *diag)
+                                   TDiagnostics *diag,
+                                   const TSourceLoc &line)
 {
     TConstantUnion returnValue;
     ASSERT(lhs.type == rhs.type);
     switch (lhs.type)
     {
         case EbtInt:
-            returnValue.setIConst(lhs.iConst + rhs.iConst);
+            returnValue.setIConst(CheckedSum<int>(lhs.iConst, rhs.iConst, diag, line));
             break;
         case EbtUInt:
-            returnValue.setUConst(lhs.uConst + rhs.uConst);
+            returnValue.setUConst(CheckedSum<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
             break;
         case EbtFloat:
-            returnValue.setFConst(lhs.fConst + rhs.fConst);
+            returnValue.setFConst(CheckedSum<float>(lhs.fConst, rhs.fConst, diag, line));
             break;
         default:
             UNREACHABLE();
@@ -246,20 +301,21 @@
 // static
 TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
                                    const TConstantUnion &rhs,
-                                   TDiagnostics *diag)
+                                   TDiagnostics *diag,
+                                   const TSourceLoc &line)
 {
     TConstantUnion returnValue;
     ASSERT(lhs.type == rhs.type);
     switch (lhs.type)
     {
         case EbtInt:
-            returnValue.setIConst(lhs.iConst - rhs.iConst);
+            returnValue.setIConst(CheckedDiff<int>(lhs.iConst, rhs.iConst, diag, line));
             break;
         case EbtUInt:
-            returnValue.setUConst(lhs.uConst - rhs.uConst);
+            returnValue.setUConst(CheckedDiff<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
             break;
         case EbtFloat:
-            returnValue.setFConst(lhs.fConst - rhs.fConst);
+            returnValue.setFConst(CheckedDiff<float>(lhs.fConst, rhs.fConst, diag, line));
             break;
         default:
             UNREACHABLE();
@@ -271,20 +327,21 @@
 // static
 TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
                                    const TConstantUnion &rhs,
-                                   TDiagnostics *diag)
+                                   TDiagnostics *diag,
+                                   const TSourceLoc &line)
 {
     TConstantUnion returnValue;
     ASSERT(lhs.type == rhs.type);
     switch (lhs.type)
     {
         case EbtInt:
-            returnValue.setIConst(lhs.iConst * rhs.iConst);
+            returnValue.setIConst(CheckedMul<int>(lhs.iConst, rhs.iConst, diag, line));
             break;
         case EbtUInt:
-            returnValue.setUConst(lhs.uConst * rhs.uConst);
+            returnValue.setUConst(CheckedMul<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
             break;
         case EbtFloat:
-            returnValue.setFConst(lhs.fConst * rhs.fConst);
+            returnValue.setFConst(CheckedMul<float>(lhs.fConst, rhs.fConst, diag, line));
             break;
         default:
             UNREACHABLE();
diff --git a/src/compiler/translator/ConstantUnion.h b/src/compiler/translator/ConstantUnion.h
index 917c637..d148d2c 100644
--- a/src/compiler/translator/ConstantUnion.h
+++ b/src/compiler/translator/ConstantUnion.h
@@ -46,13 +46,16 @@
     bool operator<(const TConstantUnion &constant) const;
     static TConstantUnion add(const TConstantUnion &lhs,
                               const TConstantUnion &rhs,
-                              TDiagnostics *diag);
+                              TDiagnostics *diag,
+                              const TSourceLoc &line);
     static TConstantUnion sub(const TConstantUnion &lhs,
                               const TConstantUnion &rhs,
-                              TDiagnostics *diag);
+                              TDiagnostics *diag,
+                              const TSourceLoc &line);
     static TConstantUnion mul(const TConstantUnion &lhs,
                               const TConstantUnion &rhs,
-                              TDiagnostics *diag);
+                              TDiagnostics *diag,
+                              const TSourceLoc &line);
     TConstantUnion operator%(const TConstantUnion &constant) const;
     TConstantUnion operator>>(const TConstantUnion &constant) const;
     TConstantUnion operator<<(const TConstantUnion &constant) const;
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 055d529..97a6076 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -1026,7 +1026,8 @@
             {
                 return nullptr;
             }
-            TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, diagnostics);
+            TConstantUnion *constArray =
+                leftConstant->foldBinary(mOp, rightConstant, diagnostics, mLeft->getLine());
 
             // Nodes may be constant folded without being qualified as constant.
             return CreateFoldedNode(constArray, this, mType.getQualifier());
@@ -1097,7 +1098,8 @@
 //
 TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
                                                  TIntermConstantUnion *rightNode,
-                                                 TDiagnostics *diagnostics)
+                                                 TDiagnostics *diagnostics,
+                                                 const TSourceLoc &line)
 {
     const TConstantUnion *leftArray  = getUnionArrayPointer();
     const TConstantUnion *rightArray = rightNode->getUnionArrayPointer();
@@ -1125,12 +1127,12 @@
       case EOpAdd:
         resultArray = new TConstantUnion[objectSize];
         for (size_t i = 0; i < objectSize; i++)
-            resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics);
+            resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics, line);
         break;
       case EOpSub:
         resultArray = new TConstantUnion[objectSize];
         for (size_t i = 0; i < objectSize; i++)
-            resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics);
+            resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics, line);
         break;
 
       case EOpMul:
@@ -1138,11 +1140,12 @@
       case EOpMatrixTimesScalar:
         resultArray = new TConstantUnion[objectSize];
         for (size_t i = 0; i < objectSize; i++)
-            resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics);
+            resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics, line);
         break;
 
       case EOpMatrixTimesMatrix:
         {
+            // TODO(jmadill): This code should check for overflows.
             ASSERT(getType().getBasicType() == EbtFloat && rightNode->getBasicType() == EbtFloat);
 
             const int leftCols = getCols();
@@ -1244,6 +1247,7 @@
 
       case EOpMatrixTimesVector:
         {
+            // TODO(jmadill): This code should check for overflows.
             ASSERT(rightNode->getBasicType() == EbtFloat);
 
             const int matrixCols = getCols();
@@ -1266,6 +1270,7 @@
 
       case EOpVectorTimesMatrix:
         {
+            // TODO(jmadill): This code should check for overflows.
             ASSERT(getType().getBasicType() == EbtFloat);
 
             const int matrixCols = rightNode->getType().getCols();
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 0fc7ab0..cf127f5 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -360,7 +360,8 @@
 
     TConstantUnion *foldBinary(TOperator op,
                                TIntermConstantUnion *rightNode,
-                               TDiagnostics *diagnostics);
+                               TDiagnostics *diagnostics,
+                               const TSourceLoc &line);
     const TConstantUnion *foldIndexing(int index);
     TConstantUnion *foldUnaryNonComponentWise(TOperator op);
     TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);