Add overflow-checked math to constant folding.
Previously, constant folding of addition, subtraction and multiplication
could overflow, which can lead to security bugs.
BUG=chromium:637050
Change-Id: Icee22a87909e205b71bda1c5bc1627fcf5e26e90
Reviewed-on: https://chromium-review.googlesource.com/382678
Commit-Queue: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ConstantUnion.cpp b/src/compiler/translator/ConstantUnion.cpp
index 042f610..74bf570 100644
--- a/src/compiler/translator/ConstantUnion.cpp
+++ b/src/compiler/translator/ConstantUnion.cpp
@@ -7,8 +7,62 @@
#include "compiler/translator/ConstantUnion.h"
+#include "base/numerics/safe_math.h"
#include "compiler/translator/Diagnostics.h"
+namespace
+{
+
+// Overflow-checked helpers used when constant folding a single binary operation.
+// Each helper evaluates the operation through base::CheckedNumeric<T>. If the
+// result is not valid (not representable in T), an error is reported at |line|
+// and 0 is returned in place of the folded value.
+//
+// The operands are deliberately not ASSERTed to be valid: for T = float,
+// CheckedNumeric treats a non-finite value (inf/NaN) as invalid, so a non-finite
+// float operand reaching these helpers would trip a debug assertion. An invalid
+// operand already propagates to an invalid result, which takes the same
+// error-reporting path as an overflow.
+//
+// NOTE(review): the GLSL ES 3.00 spec describes integer add/sub/mul as wrapping on
+// overflow, so reporting an error for int/uint operands may reject spec-valid
+// shaders (e.g. "0u - 1u"). Confirm this policy is intended.
+
+// Returns lhs + rhs, or reports "Addition out of range" and returns 0 on overflow.
+template <typename T>
+T CheckedSum(base::CheckedNumeric<T> lhs,
+             base::CheckedNumeric<T> rhs,
+             TDiagnostics *diag,
+             const TSourceLoc &line)
+{
+    auto result = lhs + rhs;
+    if (!result.IsValid())
+    {
+        diag->error(line, "Addition out of range", "+", "");
+        return 0;
+    }
+    return result.ValueOrDefault(0);
+}
+
+// Returns lhs - rhs, or reports "Difference out of range" and returns 0 on overflow.
+template <typename T>
+T CheckedDiff(base::CheckedNumeric<T> lhs,
+              base::CheckedNumeric<T> rhs,
+              TDiagnostics *diag,
+              const TSourceLoc &line)
+{
+    auto result = lhs - rhs;
+    if (!result.IsValid())
+    {
+        diag->error(line, "Difference out of range", "-", "");
+        return 0;
+    }
+    return result.ValueOrDefault(0);
+}
+
+// Returns lhs * rhs, or reports "Multiplication out of range" and returns 0 on overflow.
+template <typename T>
+T CheckedMul(base::CheckedNumeric<T> lhs,
+             base::CheckedNumeric<T> rhs,
+             TDiagnostics *diag,
+             const TSourceLoc &line)
+{
+    auto result = lhs * rhs;
+    if (!result.IsValid())
+    {
+        diag->error(line, "Multiplication out of range", "*", "");
+        return 0;
+    }
+    return result.ValueOrDefault(0);
+}
+
+} // anonymous namespace
+
TConstantUnion::TConstantUnion()
{
iConst = 0;
@@ -221,20 +275,21 @@
// static
TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst + rhs.iConst);
+ returnValue.setIConst(CheckedSum<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst + rhs.uConst);
+ returnValue.setUConst(CheckedSum<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst + rhs.fConst);
+ returnValue.setFConst(CheckedSum<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
@@ -246,20 +301,21 @@
// static
TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst - rhs.iConst);
+ returnValue.setIConst(CheckedDiff<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst - rhs.uConst);
+ returnValue.setUConst(CheckedDiff<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst - rhs.fConst);
+ returnValue.setFConst(CheckedDiff<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
@@ -271,20 +327,21 @@
// static
TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
const TConstantUnion &rhs,
- TDiagnostics *diag)
+ TDiagnostics *diag,
+ const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
- returnValue.setIConst(lhs.iConst * rhs.iConst);
+ returnValue.setIConst(CheckedMul<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
- returnValue.setUConst(lhs.uConst * rhs.uConst);
+ returnValue.setUConst(CheckedMul<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
- returnValue.setFConst(lhs.fConst * rhs.fConst);
+ returnValue.setFConst(CheckedMul<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();