TableGen: Type-check BinOps

Additionally, allow !con, !add, !and, and !or to take more than two operands
(as shorthand for nesting them), just as !listconcat and !strconcat already do.

Change-Id: I9659411f554201b90cd8ed7c7e004d381a66fa93

Differential revision: https://reviews.llvm.org/D44112

llvm-svn: 327494
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 0dd4da4..d839c2a 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -981,28 +981,59 @@
     Lex.Lex();  // eat the operation
 
     BinOpInit::BinaryOp Code;
-    RecTy *Type = nullptr;
-
     switch (OpTok) {
     default: llvm_unreachable("Unhandled code!");
-    case tgtok::XConcat: Code = BinOpInit::CONCAT;Type = DagRecTy::get(); break;
-    case tgtok::XADD:    Code = BinOpInit::ADD;   Type = IntRecTy::get(); break;
-    case tgtok::XAND:    Code = BinOpInit::AND;   Type = IntRecTy::get(); break;
-    case tgtok::XOR:     Code = BinOpInit::OR;    Type = IntRecTy::get(); break;
-    case tgtok::XSRA:    Code = BinOpInit::SRA;   Type = IntRecTy::get(); break;
-    case tgtok::XSRL:    Code = BinOpInit::SRL;   Type = IntRecTy::get(); break;
-    case tgtok::XSHL:    Code = BinOpInit::SHL;   Type = IntRecTy::get(); break;
-    case tgtok::XEq:     Code = BinOpInit::EQ;    Type = BitRecTy::get(); break;
+    case tgtok::XConcat: Code = BinOpInit::CONCAT; break;
+    case tgtok::XADD:    Code = BinOpInit::ADD; break;
+    case tgtok::XAND:    Code = BinOpInit::AND; break;
+    case tgtok::XOR:     Code = BinOpInit::OR; break;
+    case tgtok::XSRA:    Code = BinOpInit::SRA; break;
+    case tgtok::XSRL:    Code = BinOpInit::SRL; break;
+    case tgtok::XSHL:    Code = BinOpInit::SHL; break;
+    case tgtok::XEq:     Code = BinOpInit::EQ; break;
+    case tgtok::XListConcat: Code = BinOpInit::LISTCONCAT; break;
+    case tgtok::XStrConcat: Code = BinOpInit::STRCONCAT; break;
+    }
+
+    RecTy *Type = nullptr;
+    RecTy *ArgType = nullptr;
+    switch (OpTok) {
+    default:
+      llvm_unreachable("Unhandled code!");
+    case tgtok::XConcat:
+      Type = DagRecTy::get();
+      ArgType = DagRecTy::get();
+      break;
+    case tgtok::XAND:
+    case tgtok::XOR:
+    case tgtok::XSRA:
+    case tgtok::XSRL:
+    case tgtok::XSHL:
+    case tgtok::XADD:
+      Type = IntRecTy::get();
+      ArgType = IntRecTy::get();
+      break;
+    case tgtok::XEq:
+      Type = BitRecTy::get();
+      // ArgType for Eq is not known at this point
+      break;
     case tgtok::XListConcat:
-      Code = BinOpInit::LISTCONCAT;
       // We don't know the list type until we parse the first argument
+      ArgType = ItemType;
       break;
     case tgtok::XStrConcat:
-      Code = BinOpInit::STRCONCAT;
       Type = StringRecTy::get();
+      ArgType = StringRecTy::get();
       break;
     }
 
+    if (Type && ItemType && !Type->typeIsConvertibleTo(ItemType)) {
+      Error(OpLoc, Twine("expected value of type '") +
+                   ItemType->getAsString() + "', got '" +
+                   Type->getAsString() + "'");
+      return nullptr;
+    }
+
     if (Lex.getCode() != tgtok::l_paren) {
       TokError("expected '(' after binary operator");
       return nullptr;
@@ -1011,14 +1042,51 @@
 
     SmallVector<Init*, 2> InitList;
 
-    InitList.push_back(ParseValue(CurRec));
-    if (!InitList.back()) return nullptr;
-
-    while (Lex.getCode() == tgtok::comma) {
-      Lex.Lex();  // eat the ','
-
-      InitList.push_back(ParseValue(CurRec));
+    for (;;) {
+      SMLoc InitLoc = Lex.getLoc();
+      InitList.push_back(ParseValue(CurRec, ArgType));
       if (!InitList.back()) return nullptr;
+
+      // All BinOps require their arguments to be of compatible types.
+      TypedInit *TI = dyn_cast<TypedInit>(InitList.back());
+      if (!ArgType) {
+        ArgType = TI->getType();
+
+        switch (Code) {
+        case BinOpInit::LISTCONCAT:
+          if (!isa<ListRecTy>(ArgType)) {
+            Error(InitLoc, Twine("expected a list, got value of type '") +
+                           ArgType->getAsString() + "'");
+            return nullptr;
+          }
+          break;
+        case BinOpInit::EQ:
+          if (!ArgType->typeIsConvertibleTo(IntRecTy::get()) &&
+              !ArgType->typeIsConvertibleTo(StringRecTy::get())) {
+            Error(InitLoc, Twine("expected int, bits, or string; got value of "
+                                 "type '") + ArgType->getAsString() + "'");
+            return nullptr;
+          }
+          break;
+        default: llvm_unreachable("other ops have fixed argument types");
+        }
+      } else {
+        RecTy *Resolved = resolveTypes(ArgType, TI->getType());
+        if (!Resolved) {
+          Error(InitLoc, Twine("expected value of type '") +
+                         ArgType->getAsString() + "', got '" +
+                         TI->getType()->getAsString() + "'");
+          return nullptr;
+        }
+        if (Code != BinOpInit::ADD && Code != BinOpInit::AND &&
+            Code != BinOpInit::OR && Code != BinOpInit::SRA &&
+            Code != BinOpInit::SRL && Code != BinOpInit::SHL)
+          ArgType = Resolved;
+      }
+
+      if (Lex.getCode() != tgtok::comma)
+        break;
+      Lex.Lex();  // eat the ','
     }
 
     if (Lex.getCode() != tgtok::r_paren) {
@@ -1027,20 +1095,14 @@
     }
     Lex.Lex();  // eat the ')'
 
-    // If we are doing !listconcat, we should know the type by now
-    if (OpTok == tgtok::XListConcat) {
-      if (TypedInit *Arg0 = dyn_cast<TypedInit>(InitList[0]))
-        Type = Arg0->getType();
-      else {
-        InitList[0]->print(errs());
-        Error(OpLoc, "expected a list");
-        return nullptr;
-      }
-    }
+    if (Code == BinOpInit::LISTCONCAT)
+      Type = ArgType;
 
     // We allow multiple operands to associative operators like !strconcat as
     // shorthand for nesting them.
-    if (Code == BinOpInit::STRCONCAT || Code == BinOpInit::LISTCONCAT) {
+    if (Code == BinOpInit::STRCONCAT || Code == BinOpInit::LISTCONCAT ||
+        Code == BinOpInit::CONCAT || Code == BinOpInit::ADD ||
+        Code == BinOpInit::AND || Code == BinOpInit::OR) {
       while (InitList.size() > 2) {
         Init *RHS = InitList.pop_back_val();
         RHS = (BinOpInit::get(Code, InitList.back(), RHS, Type))
@@ -1896,7 +1958,7 @@
         break;
 
       default:
-        Init *RHSResult = ParseValue(CurRec, ItemType, ParseNameMode);
+        Init *RHSResult = ParseValue(CurRec, nullptr, ParseNameMode);
         RHS = dyn_cast<TypedInit>(RHSResult);
         if (!RHS) {
           Error(PasteLoc, "RHS of paste is not typed!");