Allow constant folding of some non-constant expressions

This requires removing, from various places in the code, the assumption
that an expression which has been constant folded is also a constant
expression in the language-specification sense.

This particularly benefits ternary operators, which can now be simplified
when only the condition is a compile-time constant.

In the future, the groundwork laid here could be used, for example, to
implement more aggressive constant folding of user-defined functions.

TEST=angle_unittests
BUG=angleproject:851

Change-Id: I0eede806570d56746c3dad1e01aa89a91d66013d
Reviewed-on: https://chromium-review.googlesource.com/310750
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Zhenyao Mo <zmo@chromium.org>
diff --git a/src/compiler/translator/ParseContext.cpp b/src/compiler/translator/ParseContext.cpp
index 8d05d1e..fff2b8e 100644
--- a/src/compiler/translator/ParseContext.cpp
+++ b/src/compiler/translator/ParseContext.cpp
@@ -164,6 +164,23 @@
     mDiagnostics.writeInfo(pp::Diagnostics::PP_WARNING, srcLoc, reason, token, extraInfo);
 }
 
+void TParseContext::outOfRangeError(bool isError,
+                                    const TSourceLoc &loc,
+                                    const char *reason,
+                                    const char *token,
+                                    const char *extraInfo)
+{
+    // Hard error (followed by recover()) when isError is set; otherwise only a warning.
+    if (!isError)
+    {
+        warning(loc, reason, token, extraInfo);
+        return;
+    }
+
+    error(loc, reason, token, extraInfo);
+    recover();
+}
+
 //
 // Same error message for all places assignments don't work.
 //
@@ -735,7 +752,10 @@
 {
     TIntermConstantUnion *constant = expr->getAsConstantUnion();
 
-    if (constant == nullptr || !constant->isScalarInt())
+    // TODO(oetuaho@nvidia.com): Get rid of the constant == nullptr check here once all constant
+    // expressions can be folded. Right now we don't allow constant expressions that ANGLE can't
+    // fold as array size.
+    if (expr->getQualifier() != EvqConst || constant == nullptr || !constant->isScalarInt())
     {
         error(line, "array size must be a constant integer expression", "");
         size = 1;
@@ -1188,11 +1208,10 @@
 {
     const TVariable *variable = getNamedVariable(location, name, symbol);
 
-    if (variable->getType().getQualifier() == EvqConst && variable->getConstPointer())
+    if (variable->getConstPointer())
     {
         TConstantUnion *constArray = variable->getConstPointer();
-        TType t(variable->getType());
-        return intermediate.addConstantUnion(constArray, t, location);
+        return intermediate.addConstantUnion(constArray, variable->getType(), location);
     }
     else
     {
@@ -1350,24 +1369,15 @@
     return false;
 }
 
-bool TParseContext::areAllChildConst(TIntermAggregate *aggrNode)
+bool TParseContext::areAllChildrenConstantFolded(TIntermAggregate *aggrNode)
 {
-    ASSERT(aggrNode != NULL);
-    if (!aggrNode->isConstructor())
-        return false;
-
-    bool allConstant = true;
-
-    // check if all the child nodes are constants so that they can be inserted into
-    // the parent node
-    TIntermSequence *sequence = aggrNode->getSequence();
-    for (TIntermSequence::iterator p = sequence->begin(); p != sequence->end(); ++p)
+    ASSERT(aggrNode != nullptr);
+    for (TIntermNode *node : *aggrNode->getSequence())
     {
-        if (!(*p)->getAsTyped()->getAsConstantUnion())
+        if (node->getAsConstantUnion() == nullptr)
             return false;
     }
-
-    return allConstant;
+    return true;
 }
 
 TPublicType TParseContext::addFullySpecifiedType(TQualifier qualifier,
@@ -2295,11 +2305,10 @@
 
     // Turn the argument list itself into a constructor
     TIntermAggregate *constructor  = intermediate.setAggregateOperator(aggregateArguments, op, line);
-    TIntermTyped *constConstructor = foldConstConstructor(constructor, *type);
-    if (constConstructor)
-    {
-        return constConstructor;
-    }
+    ASSERT(constructor->isConstructor());
+
+    // Need to set type before setPrecisionFromChildren() because bool doesn't have precision.
+    constructor->setType(*type);
 
     // Structs should not be precision qualified, the individual members may be.
     // Built-in types on the other hand should be precision qualified.
@@ -2309,14 +2318,19 @@
         type->setPrecision(constructor->getPrecision());
     }
 
+    TIntermTyped *constConstructor = foldConstConstructor(constructor, *type);
+    if (constConstructor)
+    {
+        return constConstructor;
+    }
+
     return constructor;
 }
 
 TIntermTyped *TParseContext::foldConstConstructor(TIntermAggregate *aggrNode, const TType &type)
 {
     // TODO: Add support for folding array constructors
-    bool canBeFolded = areAllChildConst(aggrNode) && !type.isArray();
-    aggrNode->setType(type);
+    bool canBeFolded = areAllChildrenConstantFolded(aggrNode) && !type.isArray();
     if (canBeFolded)
     {
         bool returnVal             = false;
@@ -2345,36 +2359,16 @@
 // vector.
 // If only one component of vector is accessed (v.x or v[0] where v is a contant vector), then a
 // contant node is returned, else an aggregate node is returned (for v.xy). The input to this
-// function could either
-// be the symbol node or it could be the intermediate tree representation of accessing fields in a
-// constant
-// structure or column of a constant matrix.
+// function is the constant-folded node of the vector; that node may have been produced from a
+// symbol or from accessing fields in a constant structure or a column of a constant matrix.
 //
 TIntermTyped *TParseContext::addConstVectorNode(TVectorFields &fields,
-                                                TIntermTyped *node,
-                                                const TSourceLoc &line)
+                                                TIntermConstantUnion *node,
+                                                const TSourceLoc &line,
+                                                bool outOfRangeIndexIsError)
 {
-    TIntermTyped *typedNode;
-    TIntermConstantUnion *tempConstantNode = node->getAsConstantUnion();
-
-    const TConstantUnion *unionArray;
-    if (tempConstantNode)
-    {
-        unionArray = tempConstantNode->getUnionArrayPointer();
-
-        if (!unionArray)
-        {
-            return node;
-        }
-    }
-    else
-    {  // The node has to be either a symbol node or an aggregate node or a tempConstant node, else,
-        // its an error
-        error(line, "Cannot offset into the vector", "Error");
-        recover();
-
-        return 0;
-    }
+    const TConstantUnion *unionArray = node->getUnionArrayPointer();
+    ASSERT(unionArray);
 
     TConstantUnion *constArray = new TConstantUnion[fields.num];
 
@@ -2385,59 +2379,39 @@
             std::stringstream extraInfoStream;
             extraInfoStream << "vector field selection out of range '" << fields.offsets[i] << "'";
             std::string extraInfo = extraInfoStream.str();
-            error(line, "", "[", extraInfo.c_str());
-            recover();
-            fields.offsets[i] = 0;
+            outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
+            fields.offsets[i] = node->getType().getNominalSize() - 1;
         }
 
         constArray[i] = unionArray[fields.offsets[i]];
     }
-    typedNode = intermediate.addConstantUnion(constArray, node->getType(), line);
-    return typedNode;
+    return intermediate.addConstantUnion(constArray, node->getType(), line);
 }
 
 //
 // This function returns the column being accessed from a constant matrix. The values are retrieved
 // from the symbol table and parse-tree is built for a vector (each column of a matrix is a vector).
-// The
-// input to the function could either be a symbol node (m[0] where m is a constant matrix)that
-// represents
-// a constant matrix or it could be the tree representation of the constant matrix (s.m1[0] where s
-// is a constant structure)
+// The input to the function is the constant-folded node of the matrix being indexed, e.g. the
+// folded value of m in m[0] where m is a constant matrix, or of s.m1 in s.m1[0] where s is a
+// constant structure.
 //
 TIntermTyped *TParseContext::addConstMatrixNode(int index,
-                                                TIntermTyped *node,
-                                                const TSourceLoc &line)
+                                                TIntermConstantUnion *node,
+                                                const TSourceLoc &line,
+                                                bool outOfRangeIndexIsError)
 {
-    TIntermTyped *typedNode;
-    TIntermConstantUnion *tempConstantNode = node->getAsConstantUnion();
-
     if (index >= node->getType().getCols())
     {
         std::stringstream extraInfoStream;
         extraInfoStream << "matrix field selection out of range '" << index << "'";
         std::string extraInfo = extraInfoStream.str();
-        error(line, "", "[", extraInfo.c_str());
-        recover();
-        index = 0;
+        outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
+        index = node->getType().getCols() - 1;
     }
 
-    if (tempConstantNode)
-    {
-        TConstantUnion *unionArray = tempConstantNode->getUnionArrayPointer();
-        int size                   = tempConstantNode->getType().getCols();
-        typedNode = intermediate.addConstantUnion(&unionArray[size * index],
-                                                  tempConstantNode->getType(), line);
-    }
-    else
-    {
-        error(line, "Cannot offset into the matrix", "Error");
-        recover();
-
-        return 0;
-    }
-
-    return typedNode;
+    TConstantUnion *unionArray = node->getUnionArrayPointer();
+    int columnSize = node->getType().getRows();  // Each column holds getRows() values.
+    return intermediate.addConstantUnion(&unionArray[columnSize * index], node->getType(), line);
 }
 
 //
@@ -2448,11 +2422,10 @@
 // constant structure)
 //
 TIntermTyped *TParseContext::addConstArrayNode(int index,
-                                               TIntermTyped *node,
-                                               const TSourceLoc &line)
+                                               TIntermConstantUnion *node,
+                                               const TSourceLoc &line,
+                                               bool outOfRangeIndexIsError)
 {
-    TIntermTyped *typedNode;
-    TIntermConstantUnion *tempConstantNode = node->getAsConstantUnion();
     TType arrayElementType = node->getType();
     arrayElementType.clearArrayness();
 
@@ -2461,27 +2434,13 @@
         std::stringstream extraInfoStream;
         extraInfoStream << "array field selection out of range '" << index << "'";
         std::string extraInfo = extraInfoStream.str();
-        error(line, "", "[", extraInfo.c_str());
-        recover();
-        index = 0;
+        outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
+        index = node->getType().getArraySize() - 1;
     }
-
-    if (tempConstantNode)
-    {
-        size_t arrayElementSize    = arrayElementType.getObjectSize();
-        TConstantUnion *unionArray = tempConstantNode->getUnionArrayPointer();
-        typedNode = intermediate.addConstantUnion(&unionArray[arrayElementSize * index],
-                                                  tempConstantNode->getType(), line);
-    }
-    else
-    {
-        error(line, "Cannot offset into the array", "Error");
-        recover();
-
-        return 0;
-    }
-
-    return typedNode;
+    size_t arrayElementSize    = arrayElementType.getObjectSize();
+    TConstantUnion *unionArray = node->getUnionArrayPointer();
+    return intermediate.addConstantUnion(&unionArray[arrayElementSize * index], node->getType(),
+                                         line);
 }
 
 //
@@ -2775,25 +2734,31 @@
 
     TIntermConstantUnion *indexConstantUnion = indexExpression->getAsConstantUnion();
 
-    if (indexExpression->getQualifier() == EvqConst && indexConstantUnion)
+    if (indexConstantUnion)
     {
+        // If the index is not qualified as constant, the behavior in the spec is undefined. This
+        // applies even if ANGLE has been able to constant fold it (ANGLE may constant fold
+        // expressions that are not constant expressions). The most compatible way to handle this
+        // case is to report a warning instead of an error and force the index to be in the
+        // correct range.
+        bool outOfRangeIndexIsError = indexExpression->getQualifier() == EvqConst;
         int index = indexConstantUnion->getIConst(0);
         if (index < 0)
         {
             std::stringstream infoStream;
             infoStream << index;
             std::string info = infoStream.str();
-            error(location, "negative index", info.c_str());
-            recover();
+            outOfRangeError(outOfRangeIndexIsError, location, "negative index", info.c_str());
             index = 0;
         }
-        if (baseExpression->getType().getQualifier() == EvqConst &&
-            baseExpression->getAsConstantUnion())
+        TIntermConstantUnion *baseConstantUnion = baseExpression->getAsConstantUnion();
+        if (baseConstantUnion)
         {
             if (baseExpression->isArray())
             {
                 // constant folding for array indexing
-                indexedExpression = addConstArrayNode(index, baseExpression, location);
+                indexedExpression =
+                    addConstArrayNode(index, baseConstantUnion, location, outOfRangeIndexIsError);
             }
             else if (baseExpression->isVector())
             {
@@ -2802,12 +2767,14 @@
                 fields.num = 1;
                 fields.offsets[0] =
                     index;  // need to do it this way because v.xy sends fields integer array
-                indexedExpression = addConstVectorNode(fields, baseExpression, location);
+                indexedExpression =
+                    addConstVectorNode(fields, baseConstantUnion, location, outOfRangeIndexIsError);
             }
             else if (baseExpression->isMatrix())
             {
                 // constant folding for matrix indexing
-                indexedExpression = addConstMatrixNode(index, baseExpression, location);
+                indexedExpression =
+                    addConstMatrixNode(index, baseConstantUnion, location, outOfRangeIndexIsError);
             }
         }
         else
@@ -2821,17 +2788,16 @@
                     std::stringstream extraInfoStream;
                     extraInfoStream << "array index out of range '" << index << "'";
                     std::string extraInfo = extraInfoStream.str();
-                    error(location, "", "[", extraInfo.c_str());
-                    recover();
+                    outOfRangeError(outOfRangeIndexIsError, location, "", "[", extraInfo.c_str());
                     safeIndex = baseExpression->getType().getArraySize() - 1;
                 }
                 else if (baseExpression->getQualifier() == EvqFragData && index > 0 &&
                          !isExtensionEnabled("GL_EXT_draw_buffers"))
                 {
-                    error(location, "", "[",
-                          "array indexes for gl_FragData must be zero when GL_EXT_draw_buffers is "
-                          "disabled");
-                    recover();
+                    outOfRangeError(
+                        outOfRangeIndexIsError, location, "", "[",
+                        "array indexes for gl_FragData must be zero when GL_EXT_draw_buffers is "
+                        "disabled");
                     safeIndex = 0;
                 }
             }
@@ -2841,8 +2807,7 @@
                 std::stringstream extraInfoStream;
                 extraInfoStream << "field selection out of range '" << index << "'";
                 std::string extraInfo = extraInfoStream.str();
-                error(location, "", "[", extraInfo.c_str());
-                recover();
+                outOfRangeError(outOfRangeIndexIsError, location, "", "[", extraInfo.c_str());
                 safeIndex = baseExpression->getType().getNominalSize() - 1;
             }
 
@@ -2945,32 +2910,29 @@
             recover();
         }
 
-        if (baseExpression->getType().getQualifier() == EvqConst &&
-            baseExpression->getAsConstantUnion())
+        if (baseExpression->getAsConstantUnion())
         {
             // constant folding for vector fields
-            indexedExpression = addConstVectorNode(fields, baseExpression, fieldLocation);
-            if (indexedExpression == 0)
-            {
-                recover();
-                indexedExpression = baseExpression;
-            }
-            else
-            {
-                indexedExpression->setType(TType(baseExpression->getBasicType(),
-                                                 baseExpression->getPrecision(), EvqConst,
-                                                 (unsigned char)(fieldString).size()));
-            }
+            indexedExpression = addConstVectorNode(fields, baseExpression->getAsConstantUnion(),
+                                                   fieldLocation, true);
         }
         else
         {
-            TString vectorString = fieldString;
             TIntermTyped *index = intermediate.addSwizzle(fields, fieldLocation);
             indexedExpression =
                 intermediate.addIndex(EOpVectorSwizzle, baseExpression, index, dotLocation);
+        }
+        if (indexedExpression == nullptr)
+        {
+            recover();
+            indexedExpression = baseExpression;
+        }
+        else
+        {
+            // Note that the qualifier set here will be corrected later.
             indexedExpression->setType(TType(baseExpression->getBasicType(),
                                              baseExpression->getPrecision(), EvqTemporary,
-                                             (unsigned char)vectorString.size()));
+                                             (unsigned char)(fieldString).size()));
         }
     }
     else if (baseExpression->getBasicType() == EbtStruct)
@@ -2996,8 +2958,7 @@
             }
             if (fieldFound)
             {
-                if (baseExpression->getType().getQualifier() == EvqConst &&
-                    baseExpression->getAsConstantUnion())
+                if (baseExpression->getAsConstantUnion())
                 {
                     indexedExpression = addConstStruct(fieldString, baseExpression, dotLocation);
                     if (indexedExpression == 0)
@@ -3008,9 +2969,6 @@
                     else
                     {
                         indexedExpression->setType(*fields[i]->type());
-                        // change the qualifier of the return type, not of the structure field
-                        // as the structure definition is shared between various structures.
-                        indexedExpression->getTypePointer()->setQualifier(EvqConst);
                     }
                 }
                 else
@@ -3093,6 +3051,10 @@
     {
         indexedExpression->getTypePointer()->setQualifier(EvqConst);
     }
+    else
+    {
+        indexedExpression->getTypePointer()->setQualifier(EvqTemporary);
+    }
 
     return indexedExpression;
 }
@@ -3406,7 +3368,10 @@
         recover();
     }
     TIntermConstantUnion *conditionConst = condition->getAsConstantUnion();
-    if (conditionConst == nullptr)
+    // TODO(oetuaho@nvidia.com): Get rid of the conditionConst == nullptr check once all constant
+    // expressions can be folded. Right now we don't allow constant expressions that ANGLE can't
+    // fold in case labels.
+    if (condition->getQualifier() != EvqConst || conditionConst == nullptr)
     {
         error(condition->getLine(), "case label must be constant", "case");
         recover();
@@ -3938,7 +3903,8 @@
                     // Some built-in functions have out parameters too.
                     functionCallLValueErrorCheck(fnCandidate, aggregate);
 
-                    // See if we can constant fold a built-in.
+                    // See if we can constant fold a built-in. Note that this may be possible even
+                    // if it is not const-qualified.
                     TIntermTyped *foldedNode = intermediate.foldAggregateBuiltIn(aggregate);
                     if (foldedNode)
                     {