[MIR] Add support for printing and parsing target MMO flags

Summary: Add target hooks for printing and parsing target MMO flags.
Targets may override getSerializableMachineMemOperandTargetFlags() to
return the (flag value, name) pairs for the target MMO flags that
should be printed in, and parsed from, MIR.

Add an implementation of this hook for the AArch64 SuppressPair MMO flag.

Reviewers: bogner, hfinkel, qcolombet, MatzeB

Subscribers: mcrosier, javed.absar, llvm-commits

Differential Revision: https://reviews.llvm.org/D34962

llvm-svn: 307877
diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.h b/llvm/lib/CodeGen/MIRParser/MILexer.h
index ed41e07..08b82e5 100644
--- a/llvm/lib/CodeGen/MIRParser/MILexer.h
+++ b/llvm/lib/CodeGen/MIRParser/MILexer.h
@@ -169,7 +169,8 @@
 
   bool isMemoryOperandFlag() const {
     return Kind == kw_volatile || Kind == kw_non_temporal ||
-           Kind == kw_dereferenceable || Kind == kw_invariant;
+           Kind == kw_dereferenceable || Kind == kw_invariant ||
+           Kind == StringConstant; // Quoted names denote target MMO flags.
   }
 
   bool is(TokenKind K) const { return Kind == K; }
diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
index 70dca27..c68d87b 100644
--- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
@@ -141,6 +141,8 @@
   StringMap<unsigned> Names2DirectTargetFlags;
   /// Maps from direct target flag names to the bitmask target flag values.
   StringMap<unsigned> Names2BitmaskTargetFlags;
+  /// Maps from MMO target flag names to MMO target flag values.
+  StringMap<MachineMemOperand::Flags> Names2MMOTargetFlags;
 
 public:
   MIParser(PerFunctionMIParsingState &PFS, SMDiagnostic &Error,
@@ -320,6 +322,14 @@
   /// Return true if the name isn't a name of a bitmask target flag.
   bool getBitmaskTargetFlag(StringRef Name, unsigned &Flag);
 
+  void initNames2MMOTargetFlags();
+
+  /// Try to convert a name of a MachineMemOperand target flag to the
+  /// corresponding target flag.
+  ///
+  /// Return true if the name isn't a name of a target MMO flag.
+  bool getMMOTargetFlag(StringRef Name, MachineMemOperand::Flags &Flag);
+
   /// parseStringConstant
   ///   ::= StringConstant
   bool parseStringConstant(std::string &Result);
@@ -2039,7 +2049,14 @@
   case MIToken::kw_invariant:
     Flags |= MachineMemOperand::MOInvariant;
     break;
-  // TODO: parse the target specific memory operand flags.
+  case MIToken::StringConstant: {
+    MachineMemOperand::Flags TF;
+    if (getMMOTargetFlag(Token.stringValue(), TF))
+      return error("use of undefined target MMO flag '" + Token.stringValue() +
+                   "'");
+    Flags |= TF;
+    break;
+  }
   default:
     llvm_unreachable("The current token should be a memory operand flag");
   }
@@ -2480,6 +2497,27 @@
   return false;
 }
 
+/// Populate Names2MMOTargetFlags from the target's serializable MMO flags;
+/// does nothing once the map holds at least one entry.
+void MIParser::initNames2MMOTargetFlags() {
+  if (!Names2MMOTargetFlags.empty())
+    return;
+  const auto *TII = MF.getSubtarget().getInstrInfo();
+  assert(TII && "Expected target instruction info");
+  for (const auto &Entry : TII->getSerializableMachineMemOperandTargetFlags())
+    Names2MMOTargetFlags.insert({StringRef(Entry.second), Entry.first});
+}
+
+bool MIParser::getMMOTargetFlag(StringRef Name,
+                                MachineMemOperand::Flags &Flag) {
+  initNames2MMOTargetFlags();
+  auto Found = Names2MMOTargetFlags.find(Name);
+  const bool Missing = Found == Names2MMOTargetFlags.end();
+  if (!Missing)
+    Flag = Found->second;
+  return Missing; // True reports an unknown name; Flag is then untouched.
+}
+
 bool MIParser::parseStringConstant(std::string &Result) {
   if (Token.isNot(MIToken::StringConstant))
     return error("expected string constant");
diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp
index 4cc6142..ddeacf1 100644
--- a/llvm/lib/CodeGen/MIRPrinter.cpp
+++ b/llvm/lib/CodeGen/MIRPrinter.cpp
@@ -165,7 +165,8 @@
   void print(const MachineOperand &Op, const TargetRegisterInfo *TRI,
              unsigned I, bool ShouldPrintRegisterTies,
              LLT TypeToPrint, bool IsDef = false);
-  void print(const LLVMContext &Context, const MachineMemOperand &Op);
+  void print(const LLVMContext &Context, const TargetInstrInfo &TII,
+             const MachineMemOperand &Op);
   void printSyncScope(const LLVMContext &Context, SyncScope::ID SSID);
 
   void print(const MCCFIInstruction &CFI, const TargetRegisterInfo *TRI);
@@ -740,7 +741,7 @@
     for (const auto *Op : MI.memoperands()) {
       if (NeedComma)
         OS << ", ";
-      print(Context, *Op);
+      print(Context, *TII, *Op);
       NeedComma = true;
     }
   }
@@ -1036,9 +1037,20 @@
   }
 }
 
-void MIPrinter::print(const LLVMContext &Context, const MachineMemOperand &Op) {
+/// Find the name under which the target serializes \p TMMOFlag.
+///
+/// Yields nullptr if the target's serializable flag list lacks the flag.
+static const char *getTargetMMOFlagName(const TargetInstrInfo &TII,
+                                        unsigned TMMOFlag) {
+  for (const auto &Entry : TII.getSerializableMachineMemOperandTargetFlags())
+    if (Entry.first == TMMOFlag)
+      return Entry.second;
+  return nullptr;
+}
+
+void MIPrinter::print(const LLVMContext &Context, const TargetInstrInfo &TII,
+                      const MachineMemOperand &Op) {
   OS << '(';
-  // TODO: Print operand's target specific flags.
   if (Op.isVolatile())
     OS << "volatile ";
   if (Op.isNonTemporal())
@@ -1047,6 +1059,15 @@
     OS << "dereferenceable ";
   if (Op.isInvariant())
     OS << "invariant ";
+  if (Op.getFlags() & MachineMemOperand::MOTargetFlag1)
+    OS << '"' << getTargetMMOFlagName(TII, MachineMemOperand::MOTargetFlag1)
+       << "\" ";
+  if (Op.getFlags() & MachineMemOperand::MOTargetFlag2)
+    OS << '"' << getTargetMMOFlagName(TII, MachineMemOperand::MOTargetFlag2)
+       << "\" ";
+  if (Op.getFlags() & MachineMemOperand::MOTargetFlag3)
+    OS << '"' << getTargetMMOFlagName(TII, MachineMemOperand::MOTargetFlag3)
+       << "\" ";
   if (Op.isLoad())
     OS << "load ";
   else {
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index ca39db4..afea557 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -752,6 +752,12 @@
     OS << "(dereferenceable)";
   if (isInvariant())
     OS << "(invariant)";
+  if (getFlags() & MOTargetFlag1)
+    OS << "(flag1)";
+  if (getFlags() & MOTargetFlag2)
+    OS << "(flag2)";
+  if (getFlags() & MOTargetFlag3)
+    OS << "(flag3)";
 }
 
 //===----------------------------------------------------------------------===//