TableGen: Introduce an abstract variable resolver interface

Summary:
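resolveReferences() now takes an abstract Resolver instead of a
(Record &, const RecordVal *) pair; Record::resolveReferencesTo() chooses a
RecordResolver or a RecordValResolver depending on whether a RecordVal is
given. The intention is to make it easier to restructure how resolution is
done, e.g. resolving multiple variables simultaneously, or using the
resolver mechanism to implement !foreach.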

Change-Id: I4b976b54a32e240ad4f562f7eb86a4d663a20ea8

Reviewers: arsenm, craig.topper, tra, MartinO

Subscribers: wdng, llvm-commits

Differential Revision: https://reviews.llvm.org/D43564

llvm-svn: 326704
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index d487970..c08ae97 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -312,15 +312,15 @@
 // Fix bit initializer to preserve the behavior that bit reference from a unset
 // bits initializer will resolve into VarBitInit to keep the field name and bit
 // number used in targets with fixed insn length.
-static Init *fixBitInit(const RecordVal *RV, Init *Before, Init *After) {
-  if (RV || !isa<UnsetInit>(After))
+static Init *fixBitInit(const Resolver &R, Init *Before, Init *After) {
+  if (!isa<UnsetInit>(After) || !R.keepUnsetBits())
     return After;
   return Before;
 }
 
 // resolveReferences - If there are any field references that refer to fields
 // that have been filled in, we can propagate the values now.
-Init *BitsInit::resolveReferences(Record &R, const RecordVal *RV) const {
+Init *BitsInit::resolveReferences(Resolver &R) const {
   bool Changed = false;
   SmallVector<Init *, 16> NewBits(getNumBits());
 
@@ -337,7 +337,7 @@
     if (CurBitVar == CachedBitVar) {
       if (CachedBitVarChanged) {
         Init *Bit = CachedInit->getBit(CurBit->getBitNum());
-        NewBits[i] = fixBitInit(RV, CurBit, Bit);
+        NewBits[i] = fixBitInit(R, CurBit, Bit);
       }
       continue;
     }
@@ -347,7 +347,7 @@
     Init *B;
     do {
       B = CurBitVar;
-      CurBitVar = CurBitVar->resolveReferences(R, RV);
+      CurBitVar = CurBitVar->resolveReferences(R);
       CachedBitVarChanged |= B != CurBitVar;
       Changed |= B != CurBitVar;
     } while (B != CurBitVar);
@@ -355,7 +355,7 @@
 
     if (CachedBitVarChanged) {
       Init *Bit = CurBitVar->getBit(CurBit->getBitNum());
-      NewBits[i] = fixBitInit(RV, CurBit, Bit);
+      NewBits[i] = fixBitInit(R, CurBit, Bit);
     }
   }
 
@@ -543,7 +543,7 @@
   return DI->getDef();
 }
 
-Init *ListInit::resolveReferences(Record &R, const RecordVal *RV) const {
+Init *ListInit::resolveReferences(Resolver &R) const {
   SmallVector<Init*, 8> Resolved;
   Resolved.reserve(size());
   bool Changed = false;
@@ -553,7 +553,7 @@
 
     do {
       E = CurElt;
-      CurElt = CurElt->resolveReferences(R, RV);
+      CurElt = CurElt->resolveReferences(R);
       Changed |= E != CurElt;
     } while (E != CurElt);
     Resolved.push_back(E);
@@ -706,12 +706,13 @@
   return const_cast<UnOpInit *>(this);
 }
 
-Init *UnOpInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  Init *lhs = LHS->resolveReferences(R, RV);
+Init *UnOpInit::resolveReferences(Resolver &R) const {
+  Init *lhs = LHS->resolveReferences(R);
 
   if (LHS != lhs)
-    return (UnOpInit::get(getOpcode(), lhs, getType()))->Fold(&R, nullptr);
-  return Fold(&R, nullptr);
+    return (UnOpInit::get(getOpcode(), lhs, getType()))
+        ->Fold(R.getCurrentRecord(), nullptr);
+  return Fold(R.getCurrentRecord(), nullptr);
 }
 
 std::string UnOpInit::getAsString() const {
@@ -854,13 +855,14 @@
   return const_cast<BinOpInit *>(this);
 }
 
-Init *BinOpInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  Init *lhs = LHS->resolveReferences(R, RV);
-  Init *rhs = RHS->resolveReferences(R, RV);
+Init *BinOpInit::resolveReferences(Resolver &R) const {
+  Init *lhs = LHS->resolveReferences(R);
+  Init *rhs = RHS->resolveReferences(R);
 
   if (LHS != lhs || RHS != rhs)
-    return (BinOpInit::get(getOpcode(), lhs, rhs, getType()))->Fold(&R,nullptr);
-  return Fold(&R, nullptr);
+    return (BinOpInit::get(getOpcode(), lhs, rhs, getType()))
+        ->Fold(R.getCurrentRecord(), nullptr);
+  return Fold(R.getCurrentRecord(), nullptr);
 }
 
 std::string BinOpInit::getAsString() const {
@@ -1058,9 +1060,8 @@
   return const_cast<TernOpInit *>(this);
 }
 
-Init *TernOpInit::resolveReferences(Record &R,
-                                    const RecordVal *RV) const {
-  Init *lhs = LHS->resolveReferences(R, RV);
+Init *TernOpInit::resolveReferences(Resolver &R) const {
+  Init *lhs = LHS->resolveReferences(R);
 
   if (getOpcode() == IF && lhs != LHS) {
     IntInit *Value = dyn_cast<IntInit>(lhs);
@@ -1069,23 +1070,23 @@
     if (Value) {
       // Short-circuit
       if (Value->getValue()) {
-        Init *mhs = MHS->resolveReferences(R, RV);
-        return (TernOpInit::get(getOpcode(), lhs, mhs,
-                                RHS, getType()))->Fold(&R, nullptr);
+        Init *mhs = MHS->resolveReferences(R);
+        return (TernOpInit::get(getOpcode(), lhs, mhs, RHS, getType()))
+            ->Fold(R.getCurrentRecord(), nullptr);
       }
-      Init *rhs = RHS->resolveReferences(R, RV);
-      return (TernOpInit::get(getOpcode(), lhs, MHS,
-                              rhs, getType()))->Fold(&R, nullptr);
+      Init *rhs = RHS->resolveReferences(R);
+      return (TernOpInit::get(getOpcode(), lhs, MHS, rhs, getType()))
+          ->Fold(R.getCurrentRecord(), nullptr);
     }
   }
 
-  Init *mhs = MHS->resolveReferences(R, RV);
-  Init *rhs = RHS->resolveReferences(R, RV);
+  Init *mhs = MHS->resolveReferences(R);
+  Init *rhs = RHS->resolveReferences(R);
 
   if (LHS != lhs || MHS != mhs || RHS != rhs)
-    return (TernOpInit::get(getOpcode(), lhs, mhs, rhs,
-                            getType()))->Fold(&R, nullptr);
-  return Fold(&R, nullptr);
+    return (TernOpInit::get(getOpcode(), lhs, mhs, rhs, getType()))
+        ->Fold(R.getCurrentRecord(), nullptr);
+  return Fold(R.getCurrentRecord(), nullptr);
 }
 
 std::string TernOpInit::getAsString() const {
@@ -1248,10 +1249,9 @@
   return VarBitInit::get(const_cast<VarInit*>(this), Bit);
 }
 
-Init *VarInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  if (RecordVal *Val = R.getValue(VarName))
-    if (RV == Val || (!RV && !isa<UnsetInit>(Val->getValue())))
-      return Val->getValue();
+Init *VarInit::resolveReferences(Resolver &R) const {
+  if (Init *Val = R.resolve(VarName))
+    return Val;
   return const_cast<VarInit *>(this);
 }
 
@@ -1278,8 +1278,8 @@
   return TI->getAsString() + "{" + utostr(Bit) + "}";
 }
 
-Init *VarBitInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  Init *I = TI->resolveReferences(R, RV);
+Init *VarBitInit::resolveReferences(Resolver &R) const {
+  Init *I = TI->resolveReferences(R);
   if (TI != I)
     return I->getBit(getBitNum());
 
@@ -1302,9 +1302,8 @@
   return TI->getAsString() + "[" + utostr(Element) + "]";
 }
 
-Init *
-VarListElementInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  Init *NewTI = TI->resolveReferences(R, RV);
+Init *VarListElementInit::resolveReferences(Resolver &R) const {
+  Init *NewTI = TI->resolveReferences(R);
   if (ListInit *List = dyn_cast<ListInit>(NewTI)) {
     // Leave out-of-bounds array references as-is. This can happen without
     // being an error, e.g. in the untaken "branch" of an !if expression.
@@ -1360,12 +1359,12 @@
   return VarBitInit::get(const_cast<FieldInit*>(this), Bit);
 }
 
-Init *FieldInit::resolveReferences(Record &R, const RecordVal *RV) const {
-  Init *NewRec = Rec->resolveReferences(R, RV);
+Init *FieldInit::resolveReferences(Resolver &R) const {
+  Init *NewRec = Rec->resolveReferences(R);
 
   if (DefInit *DI = dyn_cast<DefInit>(NewRec)) {
     Init *FieldVal = DI->getDef()->getValue(FieldName)->getValue();
-    Init *BVR = FieldVal->resolveReferences(R, RV);
+    Init *BVR = FieldVal->resolveReferences(R);
     if (BVR->isComplete())
       return BVR;
   }
@@ -1438,17 +1437,17 @@
   return nullptr;
 }
 
-Init *DagInit::resolveReferences(Record &R, const RecordVal *RV) const {
+Init *DagInit::resolveReferences(Resolver &R) const {
   SmallVector<Init*, 8> NewArgs;
   NewArgs.reserve(arg_size());
   bool ArgsChanged = false;
   for (const Init *Arg : getArgs()) {
-    Init *NewArg = Arg->resolveReferences(R, RV);
+    Init *NewArg = Arg->resolveReferences(R);
     NewArgs.push_back(NewArg);
     ArgsChanged |= NewArg != Arg;
   }
 
-  Init *Op = Val->resolveReferences(R, RV);
+  Init *Op = Val->resolveReferences(R);
   if (Op != Val || ArgsChanged)
     return DagInit::get(Op, ValName, NewArgs, getArgNames());
 
@@ -1538,11 +1537,19 @@
 }
 
 void Record::resolveReferencesTo(const RecordVal *RV) {
+  RecordResolver RecResolver(*this);
+  RecordValResolver RecValResolver(*this, RV);
+  Resolver *R;
+  if (RV)
+    R = &RecValResolver;
+  else
+    R = &RecResolver;
+
   for (RecordVal &Value : Values) {
     if (RV == &Value) // Skip resolve the same field as the given one
       continue;
     if (Init *V = Value.getValue())
-      if (Value.setValue(V->resolveReferences(*this, RV)))
+      if (Value.setValue(V->resolveReferences(*R)))
         PrintFatalError(getLoc(), "Invalid value is found when setting '" +
                         Value.getNameInitAsString() +
                         "' after resolving references" +
@@ -1552,7 +1559,7 @@
                             : "") + "\n");
   }
   Init *OldName = getNameInit();
-  Init *NewName = Name->resolveReferences(*this, RV);
+  Init *NewName = Name->resolveReferences(*R);
   if (NewName != OldName) {
     // Re-register with RecordKeeper.
     setName(NewName);
@@ -1813,3 +1820,26 @@
     NewName = BinOp->Fold(&CurRec, CurMultiClass);
   return NewName;
 }
+
+// Resolve VarName to the recursively-resolved value of the current record's
+// field of that name; nullptr if absent, unset, or already being resolved.
+Init *RecordResolver::resolve(Init *VarName) {
+  Init *Val = Cache.lookup(VarName);
+  if (Val)
+    return Val;
+
+  if (llvm::is_contained(Stack, VarName))
+    return nullptr; // prevent infinite recursion
+
+  if (RecordVal *RV = getCurrentRecord()->getValue(VarName)) {
+    if (!isa<UnsetInit>(RV->getValue())) {
+      Val = RV->getValue();
+      Stack.push_back(VarName);
+      Val = Val->resolveReferences(*this);
+      Stack.pop_back();
+    }
+  }
+
+  Cache[VarName] = Val; // a cached nullptr is not treated as a hit above
+  return Val;
+}
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index baf97ba..3ef6593 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -320,7 +320,8 @@
 
     // Process each value.
     for (unsigned i = 0; i < List->size(); ++i) {
-      Init *ItemVal = List->getElement(i)->resolveReferences(*CurRec, nullptr);
+      RecordResolver R(*CurRec);
+      Init *ItemVal = List->getElement(i)->resolveReferences(R);
       IterVals.push_back(IterRecord(CurLoop.IterVar, ItemVal));
       if (ProcessForeachDefs(CurRec, Loc, IterVals))
         return true;