[llvm-exegesis][NFC] Rewrite of the YAML serialization.

Summary: This is an NFC change in preparation for exporting the initial registers as part of the YAML dump.

Reviewers: courbet

Reviewed By: courbet

Subscribers: mgorny, tschuett, llvm-commits

Differential Revision: https://reviews.llvm.org/D52427

llvm-svn: 342967
diff --git a/llvm/tools/llvm-exegesis/lib/BenchmarkResult.cpp b/llvm/tools/llvm-exegesis/lib/BenchmarkResult.cpp
index 77adb2f..0475228 100644
--- a/llvm/tools/llvm-exegesis/lib/BenchmarkResult.cpp
+++ b/llvm/tools/llvm-exegesis/lib/BenchmarkResult.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "BenchmarkResult.h"
+#include "BenchmarkRunner.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ObjectYAML/YAML.h"
@@ -18,75 +19,106 @@
 
 static constexpr const char kIntegerFormat[] = "i_0x%" PRId64 "x";
 static constexpr const char kDoubleFormat[] = "f_%la";
+static constexpr const char kInvalidOperand[] = "INVALID";
 
-static void serialize(const exegesis::BenchmarkResultContext &Context,
-                      const llvm::MCOperand &MCOperand, llvm::raw_ostream &OS) {
-  if (MCOperand.isReg()) {
-    OS << Context.getRegName(MCOperand.getReg());
-  } else if (MCOperand.isImm()) {
-    OS << llvm::format(kIntegerFormat, MCOperand.getImm());
-  } else if (MCOperand.isFPImm()) {
-    OS << llvm::format(kDoubleFormat, MCOperand.getFPImm());
-  } else {
-    OS << "INVALID";
-  }
-}
+// A mutable struct holding an LLVMState that can be passed through the
+// serialization process to encode/decode registers and instructions.
+struct YamlContext {
+  YamlContext(const exegesis::LLVMState &State)
+      : State(&State), ErrorStream(LastError) {}
 
-static void serialize(const exegesis::BenchmarkResultContext &Context,
-                      const llvm::MCInst &MCInst, llvm::raw_ostream &OS) {
-  OS << Context.getInstrName(MCInst.getOpcode());
-  for (const auto &Op : MCInst) {
-    OS << ' ';
-    serialize(Context, Op, OS);
-  }
-}
-
-static llvm::MCOperand
-deserialize(const exegesis::BenchmarkResultContext &Context,
-            llvm::StringRef String) {
-  assert(!String.empty());
-  int64_t IntValue = 0;
-  double DoubleValue = 0;
-  if (sscanf(String.data(), kIntegerFormat, &IntValue) == 1)
-    return llvm::MCOperand::createImm(IntValue);
-  if (sscanf(String.data(), kDoubleFormat, &DoubleValue) == 1)
-    return llvm::MCOperand::createFPImm(DoubleValue);
-  if (unsigned RegNo = Context.getRegNo(String)) // Returns 0 if invalid.
-    return llvm::MCOperand::createReg(RegNo);
-  return {};
-}
-
-static llvm::StringRef
-deserialize(const exegesis::BenchmarkResultContext &Context,
-            llvm::StringRef String, llvm::MCInst &Value) {
-  llvm::SmallVector<llvm::StringRef, 8> Pieces;
-  String.split(Pieces, " ", /* MaxSplit */ -1, /* KeepEmpty */ false);
-  if (Pieces.empty())
-    return "Invalid Instruction";
-  bool ProcessOpcode = true;
-  for (llvm::StringRef Piece : Pieces) {
-    if (ProcessOpcode) {
-      ProcessOpcode = false;
-      Value.setOpcode(Context.getInstrOpcode(Piece));
-      if (Value.getOpcode() == 0)
-        return "Unknown Opcode Name";
-    } else {
-      Value.addOperand(deserialize(Context, Piece));
+  void serializeMCInst(const llvm::MCInst &MCInst, llvm::raw_ostream &OS) {
+    OS << getInstrName(MCInst.getOpcode());
+    for (const auto &Op : MCInst) {
+      OS << ' ';
+      serializeMCOperand(Op, OS);
     }
   }
-  return {};
-}
 
-// YAML IO requires a mutable pointer to Context but we guarantee to not
-// modify it.
-static void *getUntypedContext(const exegesis::BenchmarkResultContext &Ctx) {
-  return const_cast<exegesis::BenchmarkResultContext *>(&Ctx);
-}
+  void deserializeMCInst(llvm::StringRef String, llvm::MCInst &Value) {
+    llvm::SmallVector<llvm::StringRef, 8> Pieces;
+    String.split(Pieces, " ", /* MaxSplit */ -1, /* KeepEmpty */ false);
+    if (Pieces.empty()) {
+      ErrorStream << "Unknown Instruction: '" << String << "'";
+      return;
+    }
+    bool ProcessOpcode = true;
+    for (llvm::StringRef Piece : Pieces) {
+      if (ProcessOpcode)
+        Value.setOpcode(getInstrOpcode(Piece));
+      else
+        Value.addOperand(deserializeMCOperand(Piece));
+      ProcessOpcode = false;
+    }
+  }
 
-static const exegesis::BenchmarkResultContext &getTypedContext(void *Ctx) {
-  assert(Ctx);
-  return *static_cast<const exegesis::BenchmarkResultContext *>(Ctx);
-}
+  std::string &getLastError() { return ErrorStream.str(); }
+
+private:
+  void serializeMCOperand(const llvm::MCOperand &MCOperand,
+                          llvm::raw_ostream &OS) {
+    if (MCOperand.isReg()) {
+      OS << getRegName(MCOperand.getReg());
+    } else if (MCOperand.isImm()) {
+      OS << llvm::format(kIntegerFormat, MCOperand.getImm());
+    } else if (MCOperand.isFPImm()) {
+      OS << llvm::format(kDoubleFormat, MCOperand.getFPImm());
+    } else {
+      OS << kInvalidOperand;
+    }
+  }
+
+  llvm::MCOperand deserializeMCOperand(llvm::StringRef String) {
+    assert(!String.empty());
+    int64_t IntValue = 0;
+    double DoubleValue = 0;
+    if (sscanf(String.data(), kIntegerFormat, &IntValue) == 1)
+      return llvm::MCOperand::createImm(IntValue);
+    if (sscanf(String.data(), kDoubleFormat, &DoubleValue) == 1)
+      return llvm::MCOperand::createFPImm(DoubleValue);
+    if (unsigned RegNo = getRegNo(String))
+      return llvm::MCOperand::createReg(RegNo);
+    if (String != kInvalidOperand)
+      ErrorStream << "Unknown Operand: '" << String << "'";
+    return {};
+  }
+
+  llvm::StringRef getRegName(unsigned RegNo) {
+    const llvm::StringRef RegName = State->getRegInfo().getName(RegNo);
+    if (RegName.empty())
+      ErrorStream << "No register with enum value" << RegNo;
+    return RegName;
+  }
+
+  llvm::StringRef getInstrName(unsigned InstrNo) {
+    const llvm::StringRef InstrName = State->getInstrInfo().getName(InstrNo);
+    if (InstrName.empty())
+      ErrorStream << "No opcode with enum value" << InstrNo;
+    return InstrName;
+  }
+
+  unsigned getRegNo(llvm::StringRef RegName) {
+    const llvm::MCRegisterInfo &RegInfo = State->getRegInfo();
+    for (unsigned E = RegInfo.getNumRegs(), I = 0; I < E; ++I)
+      if (RegInfo.getName(I) == RegName)
+        return I;
+    ErrorStream << "No register with name " << RegName;
+    return 0;
+  }
+
+  unsigned getInstrOpcode(llvm::StringRef InstrName) {
+    const llvm::MCInstrInfo &InstrInfo = State->getInstrInfo();
+    for (unsigned E = InstrInfo.getNumOpcodes(), I = 0; I < E; ++I)
+      if (InstrInfo.getName(I) == InstrName)
+        return I;
+    ErrorStream << "No opcode with name " << InstrName;
+    return 0;
+  }
+
+  const exegesis::LLVMState *State;
+  std::string LastError;
+  llvm::raw_string_ostream ErrorStream;
+};
 
 // Defining YAML traits for IO.
 namespace llvm {
@@ -101,11 +133,13 @@
 
   static void output(const llvm::MCInst &Value, void *Ctx,
                      llvm::raw_ostream &Out) {
-    serialize(getTypedContext(Ctx), Value, Out);
+    reinterpret_cast<YamlContext *>(Ctx)->serializeMCInst(Value, Out);
   }
 
   static StringRef input(StringRef Scalar, void *Ctx, llvm::MCInst &Value) {
-    return deserialize(getTypedContext(Ctx), Scalar, Value);
+    YamlContext &Context = *reinterpret_cast<YamlContext *>(Ctx);
+    Context.deserializeMCInst(Scalar, Value);
+    return Context.getLastError();
   }
 
   static QuotingType mustQuote(StringRef) { return QuotingType::Single; }
@@ -139,14 +173,18 @@
   }
 };
 
-template <> struct MappingTraits<exegesis::InstructionBenchmarkKey> {
-  static void mapping(IO &Io, exegesis::InstructionBenchmarkKey &Obj) {
+template <>
+struct MappingContextTraits<exegesis::InstructionBenchmarkKey, YamlContext> {
+  static void mapping(IO &Io, exegesis::InstructionBenchmarkKey &Obj,
+                      YamlContext &Context) {
+    Io.setContext(&Context);
     Io.mapRequired("instructions", Obj.Instructions);
     Io.mapOptional("config", Obj.Config);
   }
 };
 
-template <> struct MappingTraits<exegesis::InstructionBenchmark> {
+template <>
+struct MappingContextTraits<exegesis::InstructionBenchmark, YamlContext> {
   class NormalizedBinary {
   public:
     NormalizedBinary(IO &io) {}
@@ -164,9 +202,10 @@
     BinaryRef Binary;
   };
 
-  static void mapping(IO &Io, exegesis::InstructionBenchmark &Obj) {
+  static void mapping(IO &Io, exegesis::InstructionBenchmark &Obj,
+                      YamlContext &Context) {
     Io.mapRequired("mode", Obj.Mode);
-    Io.mapRequired("key", Obj.Key);
+    Io.mapRequired("key", Obj.Key, Context);
     Io.mapRequired("cpu_name", Obj.CpuName);
     Io.mapRequired("llvm_triple", Obj.LLVMTriple);
     Io.mapRequired("num_repetitions", Obj.NumRepetitions);
@@ -183,99 +222,68 @@
 } // namespace yaml
 } // namespace llvm
 
-LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(exegesis::InstructionBenchmark)
-
 namespace exegesis {
 
-void BenchmarkResultContext::addRegEntry(unsigned RegNo, llvm::StringRef Name) {
-  assert(RegNoToName.find(RegNo) == RegNoToName.end());
-  assert(RegNameToNo.find(Name) == RegNameToNo.end());
-  RegNoToName[RegNo] = Name;
-  RegNameToNo[Name] = RegNo;
-}
-
-llvm::StringRef BenchmarkResultContext::getRegName(unsigned RegNo) const {
-  const auto Itr = RegNoToName.find(RegNo);
-  if (Itr != RegNoToName.end())
-    return Itr->second;
-  return {};
-}
-
-unsigned BenchmarkResultContext::getRegNo(llvm::StringRef Name) const {
-  const auto Itr = RegNameToNo.find(Name);
-  if (Itr != RegNameToNo.end())
-    return Itr->second;
-  return 0;
-}
-
-void BenchmarkResultContext::addInstrEntry(unsigned Opcode,
-                                           llvm::StringRef Name) {
-  assert(InstrOpcodeToName.find(Opcode) == InstrOpcodeToName.end());
-  assert(InstrNameToOpcode.find(Name) == InstrNameToOpcode.end());
-  InstrOpcodeToName[Opcode] = Name;
-  InstrNameToOpcode[Name] = Opcode;
-}
-
-llvm::StringRef BenchmarkResultContext::getInstrName(unsigned Opcode) const {
-  const auto Itr = InstrOpcodeToName.find(Opcode);
-  if (Itr != InstrOpcodeToName.end())
-    return Itr->second;
-  return {};
-}
-
-unsigned BenchmarkResultContext::getInstrOpcode(llvm::StringRef Name) const {
-  const auto Itr = InstrNameToOpcode.find(Name);
-  if (Itr != InstrNameToOpcode.end())
-    return Itr->second;
-  return 0;
-}
-
-template <typename ObjectOrList>
-static llvm::Expected<ObjectOrList>
-readYamlCommon(const BenchmarkResultContext &Context,
-               llvm::StringRef Filename) {
+llvm::Expected<InstructionBenchmark>
+InstructionBenchmark::readYaml(const LLVMState &State,
+                               llvm::StringRef Filename) {
   if (auto ExpectedMemoryBuffer =
           llvm::errorOrToExpected(llvm::MemoryBuffer::getFile(Filename))) {
-    std::unique_ptr<llvm::MemoryBuffer> MemoryBuffer =
-        std::move(ExpectedMemoryBuffer.get());
-    llvm::yaml::Input Yin(*MemoryBuffer, getUntypedContext(Context));
-    ObjectOrList Benchmark;
-    Yin >> Benchmark;
+    llvm::yaml::Input Yin(*ExpectedMemoryBuffer.get());
+    YamlContext Context(State);
+    InstructionBenchmark Benchmark;
+    if (Yin.setCurrentDocument())
+      llvm::yaml::yamlize(Yin, Benchmark, /*unused*/ true, Context);
+    if (!Context.getLastError().empty())
+      return llvm::make_error<BenchmarkFailure>(Context.getLastError());
     return Benchmark;
   } else {
     return ExpectedMemoryBuffer.takeError();
   }
 }
 
-llvm::Expected<InstructionBenchmark>
-InstructionBenchmark::readYaml(const BenchmarkResultContext &Context,
-                               llvm::StringRef Filename) {
-  return readYamlCommon<InstructionBenchmark>(Context, Filename);
-}
-
 llvm::Expected<std::vector<InstructionBenchmark>>
-InstructionBenchmark::readYamls(const BenchmarkResultContext &Context,
+InstructionBenchmark::readYamls(const LLVMState &State,
                                 llvm::StringRef Filename) {
-  return readYamlCommon<std::vector<InstructionBenchmark>>(Context, Filename);
+  if (auto ExpectedMemoryBuffer =
+          llvm::errorOrToExpected(llvm::MemoryBuffer::getFile(Filename))) {
+    llvm::yaml::Input Yin(*ExpectedMemoryBuffer.get());
+    YamlContext Context(State);
+    std::vector<InstructionBenchmark> Benchmarks;
+    while (Yin.setCurrentDocument()) {
+      Benchmarks.emplace_back();
+      yamlize(Yin, Benchmarks.back(), /*unused*/ true, Context);
+      if (Yin.error())
+        return llvm::errorCodeToError(Yin.error());
+      if (!Context.getLastError().empty())
+        return llvm::make_error<BenchmarkFailure>(Context.getLastError());
+      Yin.nextDocument();
+    }
+    return Benchmarks;
+  } else {
+    return ExpectedMemoryBuffer.takeError();
+  }
 }
 
-void InstructionBenchmark::writeYamlTo(const BenchmarkResultContext &Context,
+void InstructionBenchmark::writeYamlTo(const LLVMState &State,
                                        llvm::raw_ostream &OS) {
-  llvm::yaml::Output Yout(OS, getUntypedContext(Context));
-  Yout << *this;
+  llvm::yaml::Output Yout(OS);
+  YamlContext Context(State);
+  llvm::yaml::yamlize(Yout, *this, /*unused*/ true, Context);
 }
 
-void InstructionBenchmark::readYamlFrom(const BenchmarkResultContext &Context,
+void InstructionBenchmark::readYamlFrom(const LLVMState &State,
                                         llvm::StringRef InputContent) {
-  llvm::yaml::Input Yin(InputContent, getUntypedContext(Context));
-  Yin >> *this;
+  llvm::yaml::Input Yin(InputContent);
+  YamlContext Context(State);
+  if (Yin.setCurrentDocument())
+    llvm::yaml::yamlize(Yin, *this, /*unused*/ true, Context);
 }
 
-llvm::Error
-InstructionBenchmark::writeYaml(const BenchmarkResultContext &Context,
-                                const llvm::StringRef Filename) {
+llvm::Error InstructionBenchmark::writeYaml(const LLVMState &State,
+                                            const llvm::StringRef Filename) {
   if (Filename == "-") {
-    writeYamlTo(Context, llvm::outs());
+    writeYamlTo(State, llvm::outs());
   } else {
     int ResultFD = 0;
     if (auto E = llvm::errorCodeToError(
@@ -284,7 +292,7 @@
       return E;
     }
     llvm::raw_fd_ostream Ostr(ResultFD, true /*shouldClose*/);
-    writeYamlTo(Context, Ostr);
+    writeYamlTo(State, Ostr);
   }
   return llvm::Error::success();
 }