Add robust math to constant folding.
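Previously, addition, subtraction and multiplication in constant
folding could silently overflow, which can lead to security bugs.
These operators now use base::CheckedNumeric and report a
diagnostic error when the result is out of range. Matrix
multiplication folding is not yet checked (see TODOs).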
BUG=chromium:637050
Change-Id: Icee22a87909e205b71bda1c5bc1627fcf5e26e90
Reviewed-on: https://chromium-review.googlesource.com/382678
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ConstantUnion.cpp b/src/compiler/translator/ConstantUnion.cpp
index 042f610..74bf570 100644
--- a/src/compiler/translator/ConstantUnion.cpp
+++ b/src/compiler/translator/ConstantUnion.cpp
@@ -7,8 +7,62 @@
#include "compiler/translator/ConstantUnion.h"
+#include "base/numerics/safe_math.h"
#include "compiler/translator/Diagnostics.h"
+namespace
+{
+
+template <typename T>
+T CheckedSum(base::CheckedNumeric<T> lhs,
+ base::CheckedNumeric<T> rhs,
+ TDiagnostics *diag,
+ const TSourceLoc &line)
+{
+ ASSERT(lhs.IsValid() && rhs.IsValid());
+ auto result = lhs + rhs;
+ if (!result.IsValid())
+ {
+ diag->error(line, "Addition out of range", "*", "");
+ return 0;
+ }
+ return result.ValueOrDefault(0);
+}
+
+template <typename T>
+T CheckedDiff(base::CheckedNumeric<T> lhs,
+ base::CheckedNumeric<T> rhs,
+ TDiagnostics *diag,
+ const TSourceLoc &line)
+{
+ ASSERT(lhs.IsValid() && rhs.IsValid());
+ auto result = lhs - rhs;
+ if (!result.IsValid())
+ {
+ diag->error(line, "Difference out of range", "*", "");
+ return 0;
+ }
+ return result.ValueOrDefault(0);
+}
+
+template <typename T>
+T CheckedMul(base::CheckedNumeric<T> lhs,
+ base::CheckedNumeric<T> rhs,
+ TDiagnostics *diag,
+ const TSourceLoc &line)
+{
+ ASSERT(lhs.IsValid() && rhs.IsValid());
+ auto result = lhs * rhs;
+ if (!result.IsValid())
+ {
+ diag->error(line, "Multiplication out of range", "*", "");
+ return 0;
+ }
+ return result.ValueOrDefault(0);
+}
+
+} // anonymous namespace
+
TConstantUnion::TConstantUnion()
{
iConst = 0;
@@ -221,20 +275,21 @@
// static
TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst + rhs.iConst);
+ returnValue.setIConst(CheckedSum<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst + rhs.uConst);
+ returnValue.setUConst(CheckedSum<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst + rhs.fConst);
+ returnValue.setFConst(CheckedSum<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
@@ -246,20 +301,21 @@
// static
TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst - rhs.iConst);
+ returnValue.setIConst(CheckedDiff<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst - rhs.uConst);
+ returnValue.setUConst(CheckedDiff<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst - rhs.fConst);
+ returnValue.setFConst(CheckedDiff<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
@@ -271,20 +327,21 @@
// static
TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst * rhs.iConst);
+ returnValue.setIConst(CheckedMul<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst * rhs.uConst);
+ returnValue.setUConst(CheckedMul<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst * rhs.fConst);
+ returnValue.setFConst(CheckedMul<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
diff --git a/src/compiler/translator/ConstantUnion.h b/src/compiler/translator/ConstantUnion.h
index 917c637..d148d2c 100644
--- a/src/compiler/translator/ConstantUnion.h
+++ b/src/compiler/translator/ConstantUnion.h
@@ -46,13 +46,16 @@
bool operator<(const TConstantUnion &constant) const;
static TConstantUnion add(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag);
+ TDiagnostics *diag,
+ const TSourceLoc &line);
static TConstantUnion sub(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag);
+ TDiagnostics *diag,
+ const TSourceLoc &line);
static TConstantUnion mul(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag);
+ TDiagnostics *diag,
+ const TSourceLoc &line);
TConstantUnion operator%(const TConstantUnion &constant) const;
TConstantUnion operator>>(const TConstantUnion &constant) const;
TConstantUnion operator<<(const TConstantUnion &constant) const;
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 055d529..97a6076 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -1026,7 +1026,8 @@
{
return nullptr;
}
- TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, diagnostics);
+ TConstantUnion *constArray =
+ leftConstant->foldBinary(mOp, rightConstant, diagnostics, mLeft->getLine());
// Nodes may be constant folded without being qualified as constant.
return CreateFoldedNode(constArray, this, mType.getQualifier());
@@ -1097,7 +1098,8 @@
//
TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
- TDiagnostics *diagnostics)
+ TDiagnostics *diagnostics,
+ const TSourceLoc &line)
{
const TConstantUnion *leftArray = getUnionArrayPointer();
const TConstantUnion *rightArray = rightNode->getUnionArrayPointer();
@@ -1125,12 +1127,12 @@
case EOpAdd:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
- resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics);
+ resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpSub:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
- resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics);
+ resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpMul:
@@ -1138,11 +1140,12 @@
case EOpMatrixTimesScalar:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
- resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics);
+ resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpMatrixTimesMatrix:
{
+ // TODO(jmadill): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat && rightNode->getBasicType() == EbtFloat);
const int leftCols = getCols();
@@ -1244,6 +1247,7 @@
case EOpMatrixTimesVector:
{
+ // TODO(jmadill): This code should check for overflows.
ASSERT(rightNode->getBasicType() == EbtFloat);
const int matrixCols = getCols();
@@ -1266,6 +1270,7 @@
case EOpVectorTimesMatrix:
{
+ // TODO(jmadill): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat);
const int matrixCols = rightNode->getType().getCols();
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 0fc7ab0..cf127f5 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -360,7 +360,8 @@
TConstantUnion *foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
- TDiagnostics *diagnostics);
+ TDiagnostics *diagnostics,
+ const TSourceLoc &line);
const TConstantUnion *foldIndexing(int index);
TConstantUnion *foldUnaryNonComponentWise(TOperator op);
TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);