[AMDGPU] Switched HSA metadata to use MsgPackDocument

Summary:
MsgPackDocument is the lighter-weight replacement for MsgPackTypes. This
commit switches AMDGPU HSA metadata processing to use MsgPackDocument
instead of MsgPackTypes.

Differential Revision: https://reviews.llvm.org/D57024

Change-Id: I0751668013abe8c87db01db1170831a76079b3a6
llvm-svn: 356081
diff --git a/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp b/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp
index 48f996e..2c865ec 100644
--- a/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp
+++ b/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp
@@ -20,98 +20,104 @@
 namespace V3 {
 
+// Check that Node is a scalar of kind SKind and, if verifyValue is set, that
+// verifyValue accepts it. If the kinds differ and the verifier is not Strict,
+// a String node is re-parsed in place with DocNode::fromString and accepted
+// when the result has kind SKind; the node stays re-parsed even on failure.
 bool MetadataVerifier::verifyScalar(
-    msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
-    function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
-  auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node);
-  if (!ScalarPtr)
+    msgpack::DocNode &Node, msgpack::Type SKind,
+    function_ref<bool(msgpack::DocNode &)> verifyValue) {
+  if (!Node.isScalar())
     return false;
-  auto &Scalar = *ScalarPtr;
-  // Do not output extraneous tags for types we know from the spec.
-  Scalar.IgnoreTag = true;
-  if (Scalar.getScalarKind() != SKind) {
+  if (Node.getKind() != SKind) {
     if (Strict)
       return false;
     // If we are not strict, we interpret string values as "implicitly typed"
     // and attempt to coerce them to the expected type here.
-    if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String)
+    if (Node.getKind() != msgpack::Type::String)
       return false;
-    std::string StringValue = Scalar.getString();
-    Scalar.setScalarKind(SKind);
-    if (Scalar.inputYAML(StringValue) != StringRef())
+    StringRef StringValue = Node.getString();
+    Node.fromString(StringValue);
+    if (Node.getKind() != SKind)
       return false;
   }
   if (verifyValue)
-    return verifyValue(Scalar);
+    return verifyValue(Node);
   return true;
 }
 
+// Accept an integer scalar, checking for UInt first and then Int.
-bool MetadataVerifier::verifyInteger(msgpack::Node &Node) {
-  if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt))
-    if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int))
+bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
+  if (!verifyScalar(Node, msgpack::Type::UInt))
+    if (!verifyScalar(Node, msgpack::Type::Int))
       return false;
   return true;
 }
 
+// Check that Node is an array, with exactly *Size elements if Size is set,
+// and that verifyNode accepts every element.
 bool MetadataVerifier::verifyArray(
-    msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode,
+    msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
     Optional<size_t> Size) {
-  auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node);
-  if (!ArrayPtr)
+  if (!Node.isArray())
     return false;
-  auto &Array = *ArrayPtr;
+  auto &Array = Node.getArray();
   if (Size && Array.size() != *Size)
     return false;
   for (auto &Item : Array)
-    if (!verifyNode(*Item.get()))
+    if (!verifyNode(Item))
       return false;
 
   return true;
 }
 
+// Look up Key in MapNode. A missing entry passes only if it is not Required;
+// a present entry's value must satisfy verifyNode.
 bool MetadataVerifier::verifyEntry(
-    msgpack::MapNode &MapNode, StringRef Key, bool Required,
-    function_ref<bool(msgpack::Node &)> verifyNode) {
+    msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
+    function_ref<bool(msgpack::DocNode &)> verifyNode) {
   auto Entry = MapNode.find(Key);
   if (Entry == MapNode.end())
     return !Required;
-  return verifyNode(*Entry->second.get());
+  return verifyNode(Entry->second);
 }
 
+// verifyEntry whose value must be a scalar of kind SKind (see verifyScalar).
 bool MetadataVerifier::verifyScalarEntry(
-    msgpack::MapNode &MapNode, StringRef Key, bool Required,
-    msgpack::ScalarNode::ScalarKind SKind,
-    function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
-  return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) {
+    msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
+    msgpack::Type SKind,
+    function_ref<bool(msgpack::DocNode &)> verifyValue) {
+  return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
     return verifyScalar(Node, SKind, verifyValue);
   });
 }
 
+// verifyEntry whose value must be an integer scalar (see verifyInteger).
-bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode,
+bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
                                           StringRef Key, bool Required) {
-  return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) {
+  return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
     return verifyInteger(Node);
   });
 }
 
+// Verify one element of a kernel's ".args" array.
-bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
-  auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node);
-  if (!ArgsMapPtr)
+bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
+  if (!Node.isMap())
     return false;
-  auto &ArgsMap = *ArgsMapPtr;
+  auto &ArgsMap = Node.getMap();
 
   if (!verifyScalarEntry(ArgsMap, ".name", false,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyIntegerEntry(ArgsMap, ".size", true))
     return false;
   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("by_value", true)
                                .Case("global_buffer", true)
@@ -131,8 +137,8 @@
                          }))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("struct", true)
                                .Case("i8", true)
@@ -152,8 +158,8 @@
   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("private", true)
                                .Case("global", true)
@@ -165,8 +171,8 @@
                          }))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".access", false,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("read_only", true)
                                .Case("write_only", true)
@@ -175,8 +181,8 @@
                          }))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("read_only", true)
                                .Case("write_only", true)
@@ -185,36 +191,36 @@
                          }))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
-                         msgpack::ScalarNode::SK_Boolean))
+                         msgpack::Type::Boolean))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
-                         msgpack::ScalarNode::SK_Boolean))
+                         msgpack::Type::Boolean))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
-                         msgpack::ScalarNode::SK_Boolean))
+                         msgpack::Type::Boolean))
     return false;
   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
-                         msgpack::ScalarNode::SK_Boolean))
+                         msgpack::Type::Boolean))
     return false;
 
   return true;
 }
 
+// Verify one element of the "amdhsa.kernels" array.
-bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
-  auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node);
-  if (!KernelMapPtr)
+bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
+  if (!Node.isMap())
     return false;
-  auto &KernelMap = *KernelMapPtr;
+  auto &KernelMap = Node.getMap();
 
   if (!verifyScalarEntry(KernelMap, ".name", true,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyScalarEntry(KernelMap, ".symbol", true,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyScalarEntry(KernelMap, ".language", false,
-                         msgpack::ScalarNode::SK_String,
-                         [](msgpack::ScalarNode &SNode) {
+                         msgpack::Type::String,
+                         [](msgpack::DocNode &SNode) {
                            return StringSwitch<bool>(SNode.getString())
                                .Case("OpenCL C", true)
                                .Case("OpenCL C++", true)
@@ -226,41 +232,41 @@
                          }))
     return false;
   if (!verifyEntry(
-          KernelMap, ".language_version", false, [this](msgpack::Node &Node) {
+          KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
             return verifyArray(
                 Node,
-                [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
+                [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
           }))
     return false;
-  if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) {
-        return verifyArray(Node, [this](msgpack::Node &Node) {
+  if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
+        return verifyArray(Node, [this](msgpack::DocNode &Node) {
           return verifyKernelArgs(Node);
         });
       }))
     return false;
   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
-                   [this](msgpack::Node &Node) {
+                   [this](msgpack::DocNode &Node) {
                      return verifyArray(Node,
-                                        [this](msgpack::Node &Node) {
+                                        [this](msgpack::DocNode &Node) {
                                           return verifyInteger(Node);
                                         },
                                         3);
                    }))
     return false;
   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
-                   [this](msgpack::Node &Node) {
+                   [this](msgpack::DocNode &Node) {
                      return verifyArray(Node,
-                                        [this](msgpack::Node &Node) {
+                                        [this](msgpack::DocNode &Node) {
                                           return verifyInteger(Node);
                                         },
                                         3);
                    }))
     return false;
   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
-                         msgpack::ScalarNode::SK_String))
+                         msgpack::Type::String))
     return false;
   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
     return false;
@@ -286,29 +292,30 @@
   return true;
 }
 
+// Verify an HSA metadata root node, which must be a map; "amdhsa.version"
+// and "amdhsa.kernels" are required entries, "amdhsa.printf" is optional.
-bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) {
-  auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot);
-  if (!RootMapPtr)
+bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
+  if (!HSAMetadataRoot.isMap())
     return false;
-  auto &RootMap = *RootMapPtr;
+  auto &RootMap = HSAMetadataRoot.getMap();
 
   if (!verifyEntry(
-          RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) {
+          RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
             return verifyArray(
                 Node,
-                [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
+                [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
           }))
     return false;
   if (!verifyEntry(
-          RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) {
-            return verifyArray(Node, [this](msgpack::Node &Node) {
-              return verifyScalar(Node, msgpack::ScalarNode::SK_String);
+          RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
+            return verifyArray(Node, [this](msgpack::DocNode &Node) {
+              return verifyScalar(Node, msgpack::Type::String);
             });
           }))
     return false;
   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
-                   [this](msgpack::Node &Node) {
-                     return verifyArray(Node, [this](msgpack::Node &Node) {
+                   [this](msgpack::DocNode &Node) {
+                     return verifyArray(Node, [this](msgpack::DocNode &Node) {
                        return verifyKernel(Node);
                      });
                    }))
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
index f93ccf6..b4bed4e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
@@ -489,20 +489,17 @@
 void MetadataStreamerV3::verify(StringRef HSAMetadataString) const {
   errs() << "AMDGPU HSA Metadata Parser Test: ";
 
-  std::shared_ptr<msgpack::Node> FromHSAMetadataString =
-      std::make_shared<msgpack::MapNode>();
+  msgpack::Document FromHSAMetadataString;
 
-  yaml::Input YIn(HSAMetadataString);
-  YIn >> FromHSAMetadataString;
-  if (YIn.error()) {
+  // Round-trip check: parse the YAML text, re-emit it, and compare.
+  if (!FromHSAMetadataString.fromYAML(HSAMetadataString)) {
     errs() << "FAIL\n";
     return;
   }
 
   std::string ToHSAMetadataString;
   raw_string_ostream StrOS(ToHSAMetadataString);
-  yaml::Output YOut(StrOS);
-  YOut << FromHSAMetadataString;
+  FromHSAMetadataString.toYAML(StrOS);
 
   errs() << (HSAMetadataString == StrOS.str() ? "PASS" : "FAIL") << '\n';
   if (HSAMetadataString != ToHSAMetadataString) {
@@ -636,23 +633,26 @@
   }
 }
 
+// Return the three integer operands of Node as an array node, or an empty
+// array node if Node does not have exactly three operands.
-std::shared_ptr<msgpack::ArrayNode>
+msgpack::ArrayDocNode
 MetadataStreamerV3::getWorkGroupDimensions(MDNode *Node) const {
-  auto Dims = std::make_shared<msgpack::ArrayNode>();
+  auto Dims = HSAMetadataDoc->getArrayNode();
   if (Node->getNumOperands() != 3)
     return Dims;
 
   for (auto &Op : Node->operands())
-    Dims->push_back(std::make_shared<msgpack::ScalarNode>(
-        mdconst::extract<ConstantInt>(Op)->getZExtValue()));
+    Dims.push_back(Dims.getDocument()->getNode(
+        uint64_t(mdconst::extract<ConstantInt>(Op)->getZExtValue())));
   return Dims;
 }
 
 void MetadataStreamerV3::emitVersion() {
-  auto Version = std::make_shared<msgpack::ArrayNode>();
-  Version->push_back(std::make_shared<msgpack::ScalarNode>(V3::VersionMajor));
-  Version->push_back(std::make_shared<msgpack::ScalarNode>(V3::VersionMinor));
-  getRootMetadata("amdhsa.version") = std::move(Version);
+  // "amdhsa.version" is the code object V3 metadata version [major, minor].
+  auto Version = HSAMetadataDoc->getArrayNode();
+  Version.push_back(Version.getDocument()->getNode(V3::VersionMajor));
+  Version.push_back(Version.getDocument()->getNode(V3::VersionMinor));
+  getRootMetadata("amdhsa.version") = Version;
 }
 
 void MetadataStreamerV3::emitPrintf(const Module &Mod) {
@@ -660,16 +660,16 @@
   if (!Node)
     return;
 
-  auto Printf = std::make_shared<msgpack::ArrayNode>();
+  auto Printf = HSAMetadataDoc->getArrayNode();
   for (auto Op : Node->operands())
     if (Op->getNumOperands())
-      Printf->push_back(std::make_shared<msgpack::ScalarNode>(
-          cast<MDString>(Op->getOperand(0))->getString()));
-  getRootMetadata("amdhsa.printf") = std::move(Printf);
+      Printf.push_back(Printf.getDocument()->getNode(
+          cast<MDString>(Op->getOperand(0))->getString(), /*Copy=*/true));
+  getRootMetadata("amdhsa.printf") = Printf;
 }
 
 void MetadataStreamerV3::emitKernelLanguage(const Function &Func,
-                                            msgpack::MapNode &Kern) {
+                                            msgpack::MapDocNode Kern) {
   // TODO: What about other languages?
   auto Node = Func.getParent()->getNamedMetadata("opencl.ocl.version");
   if (!Node || !Node->getNumOperands())
@@ -678,50 +678,53 @@
   if (Op0->getNumOperands() <= 1)
     return;
 
-  Kern[".language"] = std::make_shared<msgpack::ScalarNode>("OpenCL C");
-  auto LanguageVersion = std::make_shared<msgpack::ArrayNode>();
-  LanguageVersion->push_back(std::make_shared<msgpack::ScalarNode>(
+  Kern[".language"] = Kern.getDocument()->getNode("OpenCL C");
+  auto LanguageVersion = Kern.getDocument()->getArrayNode();
+  LanguageVersion.push_back(Kern.getDocument()->getNode(
       mdconst::extract<ConstantInt>(Op0->getOperand(0))->getZExtValue()));
-  LanguageVersion->push_back(std::make_shared<msgpack::ScalarNode>(
+  LanguageVersion.push_back(Kern.getDocument()->getNode(
       mdconst::extract<ConstantInt>(Op0->getOperand(1))->getZExtValue()));
-  Kern[".language_version"] = std::move(LanguageVersion);
+  Kern[".language_version"] = LanguageVersion;
 }
 
 void MetadataStreamerV3::emitKernelAttrs(const Function &Func,
-                                         msgpack::MapNode &Kern) {
+                                         msgpack::MapDocNode Kern) {
 
   if (auto Node = Func.getMetadata("reqd_work_group_size"))
     Kern[".reqd_workgroup_size"] = getWorkGroupDimensions(Node);
   if (auto Node = Func.getMetadata("work_group_size_hint"))
     Kern[".workgroup_size_hint"] = getWorkGroupDimensions(Node);
   if (auto Node = Func.getMetadata("vec_type_hint")) {
-    Kern[".vec_type_hint"] = std::make_shared<msgpack::ScalarNode>(getTypeName(
-        cast<ValueAsMetadata>(Node->getOperand(0))->getType(),
-        mdconst::extract<ConstantInt>(Node->getOperand(1))->getZExtValue()));
+    Kern[".vec_type_hint"] = Kern.getDocument()->getNode(
+        getTypeName(
+            cast<ValueAsMetadata>(Node->getOperand(0))->getType(),
+            mdconst::extract<ConstantInt>(Node->getOperand(1))->getZExtValue()),
+        /*Copy=*/true);
   }
   if (Func.hasFnAttribute("runtime-handle")) {
-    Kern[".device_enqueue_symbol"] = std::make_shared<msgpack::ScalarNode>(
-        Func.getFnAttribute("runtime-handle").getValueAsString().str());
+    Kern[".device_enqueue_symbol"] = Kern.getDocument()->getNode(
+        Func.getFnAttribute("runtime-handle").getValueAsString(),
+        /*Copy=*/true);
   }
 }
 
 void MetadataStreamerV3::emitKernelArgs(const Function &Func,
-                                        msgpack::MapNode &Kern) {
+                                        msgpack::MapDocNode Kern) {
   unsigned Offset = 0;
-  auto Args = std::make_shared<msgpack::ArrayNode>();
+  auto Args = HSAMetadataDoc->getArrayNode();
   for (auto &Arg : Func.args())
-    emitKernelArg(Arg, Offset, *Args);
+    emitKernelArg(Arg, Offset, Args);
 
-  emitHiddenKernelArgs(Func, Offset, *Args);
+  emitHiddenKernelArgs(Func, Offset, Args);
 
   // TODO: What about other languages?
   if (Func.getParent()->getNamedMetadata("opencl.ocl.version")) {
     auto &DL = Func.getParent()->getDataLayout();
     auto Int64Ty = Type::getInt64Ty(Func.getContext());
 
-    emitKernelArg(DL, Int64Ty, "hidden_global_offset_x", Offset, *Args);
-    emitKernelArg(DL, Int64Ty, "hidden_global_offset_y", Offset, *Args);
-    emitKernelArg(DL, Int64Ty, "hidden_global_offset_z", Offset, *Args);
+    emitKernelArg(DL, Int64Ty, "hidden_global_offset_x", Offset, Args);
+    emitKernelArg(DL, Int64Ty, "hidden_global_offset_y", Offset, Args);
+    emitKernelArg(DL, Int64Ty, "hidden_global_offset_z", Offset, Args);
 
     auto Int8PtrTy =
         Type::getInt8PtrTy(Func.getContext(), AMDGPUAS::GLOBAL_ADDRESS);
@@ -729,26 +732,26 @@
     // Emit "printf buffer" argument if printf is used, otherwise emit dummy
     // "none" argument.
     if (Func.getParent()->getNamedMetadata("llvm.printf.fmts"))
-      emitKernelArg(DL, Int8PtrTy, "hidden_printf_buffer", Offset, *Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_printf_buffer", Offset, Args);
     else
-      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, *Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, Args);
 
     // Emit "default queue" and "completion action" arguments if enqueue kernel
     // is used, otherwise emit dummy "none" arguments.
     if (Func.hasFnAttribute("calls-enqueue-kernel")) {
-      emitKernelArg(DL, Int8PtrTy, "hidden_default_queue", Offset, *Args);
-      emitKernelArg(DL, Int8PtrTy, "hidden_completion_action", Offset, *Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_default_queue", Offset, Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_completion_action", Offset, Args);
     } else {
-      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, *Args);
-      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, *Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, Args);
+      emitKernelArg(DL, Int8PtrTy, "hidden_none", Offset, Args);
     }
   }
 
-  Kern[".args"] = std::move(Args);
+  Kern[".args"] = Args;
 }
 
 void MetadataStreamerV3::emitKernelArg(const Argument &Arg, unsigned &Offset,
-                                       msgpack::ArrayNode &Args) {
+                                       msgpack::ArrayDocNode Args) {
   auto Func = Arg.getParent();
   auto ArgNo = Arg.getArgNo();
   const MDNode *Node;
@@ -805,36 +808,35 @@
 
 void MetadataStreamerV3::emitKernelArg(const DataLayout &DL, Type *Ty,
                                        StringRef ValueKind, unsigned &Offset,
-                                       msgpack::ArrayNode &Args,
+                                       msgpack::ArrayDocNode Args,
                                        unsigned PointeeAlign, StringRef Name,
                                        StringRef TypeName,
                                        StringRef BaseTypeName,
                                        StringRef AccQual, StringRef TypeQual) {
-  auto ArgPtr = std::make_shared<msgpack::MapNode>();
-  auto &Arg = *ArgPtr;
+  auto Arg = Args.getDocument()->getMapNode();
 
   if (!Name.empty())
-    Arg[".name"] = std::make_shared<msgpack::ScalarNode>(Name);
+    Arg[".name"] = Arg.getDocument()->getNode(Name, /*Copy=*/true);
   if (!TypeName.empty())
-    Arg[".type_name"] = std::make_shared<msgpack::ScalarNode>(TypeName);
+    Arg[".type_name"] = Arg.getDocument()->getNode(TypeName, /*Copy=*/true);
   auto Size = DL.getTypeAllocSize(Ty);
   auto Align = DL.getABITypeAlignment(Ty);
-  Arg[".size"] = std::make_shared<msgpack::ScalarNode>(Size);
+  Arg[".size"] = Arg.getDocument()->getNode(Size);
   Offset = alignTo(Offset, Align);
-  Arg[".offset"] = std::make_shared<msgpack::ScalarNode>(Offset);
+  Arg[".offset"] = Arg.getDocument()->getNode(Offset);
   Offset += Size;
-  Arg[".value_kind"] = std::make_shared<msgpack::ScalarNode>(ValueKind);
+  Arg[".value_kind"] = Arg.getDocument()->getNode(ValueKind, /*Copy=*/true);
   Arg[".value_type"] =
-      std::make_shared<msgpack::ScalarNode>(getValueType(Ty, BaseTypeName));
+      Arg.getDocument()->getNode(getValueType(Ty, BaseTypeName), /*Copy=*/true);
   if (PointeeAlign)
-    Arg[".pointee_align"] = std::make_shared<msgpack::ScalarNode>(PointeeAlign);
+    Arg[".pointee_align"] = Arg.getDocument()->getNode(PointeeAlign);
 
   if (auto PtrTy = dyn_cast<PointerType>(Ty))
     if (auto Qualifier = getAddressSpaceQualifier(PtrTy->getAddressSpace()))
-      Arg[".address_space"] = std::make_shared<msgpack::ScalarNode>(*Qualifier);
+      Arg[".address_space"] = Arg.getDocument()->getNode(*Qualifier, /*Copy=*/true);
 
   if (auto AQ = getAccessQualifier(AccQual))
-    Arg[".access"] = std::make_shared<msgpack::ScalarNode>(*AQ);
+    Arg[".access"] = Arg.getDocument()->getNode(*AQ, /*Copy=*/true);
 
   // TODO: Emit Arg[".actual_access"].
 
@@ -842,21 +844,21 @@
   TypeQual.split(SplitTypeQuals, " ", -1, false);
   for (StringRef Key : SplitTypeQuals) {
     if (Key == "const")
-      Arg[".is_const"] = std::make_shared<msgpack::ScalarNode>(true);
+      Arg[".is_const"] = Arg.getDocument()->getNode(true);
     else if (Key == "restrict")
-      Arg[".is_restrict"] = std::make_shared<msgpack::ScalarNode>(true);
+      Arg[".is_restrict"] = Arg.getDocument()->getNode(true);
     else if (Key == "volatile")
-      Arg[".is_volatile"] = std::make_shared<msgpack::ScalarNode>(true);
+      Arg[".is_volatile"] = Arg.getDocument()->getNode(true);
     else if (Key == "pipe")
-      Arg[".is_pipe"] = std::make_shared<msgpack::ScalarNode>(true);
+      Arg[".is_pipe"] = Arg.getDocument()->getNode(true);
   }
 
-  Args.push_back(std::move(ArgPtr));
+  Args.push_back(Arg);
 }
 
 void MetadataStreamerV3::emitHiddenKernelArgs(const Function &Func,
                                               unsigned &Offset,
-                                              msgpack::ArrayNode &Args) {
+                                              msgpack::ArrayDocNode Args) {
   int HiddenArgNumBytes =
       getIntegerAttribute(Func, "amdgpu-implicitarg-num-bytes", 0);
 
@@ -898,54 +900,52 @@
   }
 }
 
-std::shared_ptr<msgpack::MapNode>
+msgpack::MapDocNode
 MetadataStreamerV3::getHSAKernelProps(const MachineFunction &MF,
                                       const SIProgramInfo &ProgramInfo) const {
   const GCNSubtarget &STM = MF.getSubtarget<GCNSubtarget>();
   const SIMachineFunctionInfo &MFI = *MF.getInfo<SIMachineFunctionInfo>();
   const Function &F = MF.getFunction();
 
-  auto HSAKernelProps = std::make_shared<msgpack::MapNode>();
-  auto &Kern = *HSAKernelProps;
+  auto Kern = HSAMetadataDoc->getMapNode();
 
   unsigned MaxKernArgAlign;
-  Kern[".kernarg_segment_size"] = std::make_shared<msgpack::ScalarNode>(
+  Kern[".kernarg_segment_size"] = Kern.getDocument()->getNode(
       STM.getKernArgSegmentSize(F, MaxKernArgAlign));
   Kern[".group_segment_fixed_size"] =
-      std::make_shared<msgpack::ScalarNode>(ProgramInfo.LDSSize);
+      Kern.getDocument()->getNode(ProgramInfo.LDSSize);
   Kern[".private_segment_fixed_size"] =
-      std::make_shared<msgpack::ScalarNode>(ProgramInfo.ScratchSize);
+      Kern.getDocument()->getNode(ProgramInfo.ScratchSize);
   Kern[".kernarg_segment_align"] =
-      std::make_shared<msgpack::ScalarNode>(std::max(uint32_t(4), MaxKernArgAlign));
+      Kern.getDocument()->getNode(std::max(uint32_t(4), MaxKernArgAlign));
   Kern[".wavefront_size"] =
-      std::make_shared<msgpack::ScalarNode>(STM.getWavefrontSize());
-  Kern[".sgpr_count"] = std::make_shared<msgpack::ScalarNode>(ProgramInfo.NumSGPR);
-  Kern[".vgpr_count"] = std::make_shared<msgpack::ScalarNode>(ProgramInfo.NumVGPR);
+      Kern.getDocument()->getNode(STM.getWavefrontSize());
+  Kern[".sgpr_count"] = Kern.getDocument()->getNode(ProgramInfo.NumSGPR);
+  Kern[".vgpr_count"] = Kern.getDocument()->getNode(ProgramInfo.NumVGPR);
   Kern[".max_flat_workgroup_size"] =
-      std::make_shared<msgpack::ScalarNode>(MFI.getMaxFlatWorkGroupSize());
+      Kern.getDocument()->getNode(MFI.getMaxFlatWorkGroupSize());
   Kern[".sgpr_spill_count"] =
-      std::make_shared<msgpack::ScalarNode>(MFI.getNumSpilledSGPRs());
+      Kern.getDocument()->getNode(MFI.getNumSpilledSGPRs());
   Kern[".vgpr_spill_count"] =
-      std::make_shared<msgpack::ScalarNode>(MFI.getNumSpilledVGPRs());
+      Kern.getDocument()->getNode(MFI.getNumSpilledVGPRs());
 
-  return HSAKernelProps;
+  return Kern;
 }
 
 bool MetadataStreamerV3::emitTo(AMDGPUTargetStreamer &TargetStreamer) {
-  return TargetStreamer.EmitHSAMetadata(getHSAMetadataRoot(), true);
+  return TargetStreamer.EmitHSAMetadata(*HSAMetadataDoc, true);
 }
 
 void MetadataStreamerV3::begin(const Module &Mod) {
   emitVersion();
   emitPrintf(Mod);
-  getRootMetadata("amdhsa.kernels").reset(new msgpack::ArrayNode());
+  getRootMetadata("amdhsa.kernels") = HSAMetadataDoc->getArrayNode();
 }
 
 void MetadataStreamerV3::end() {
   std::string HSAMetadataString;
   raw_string_ostream StrOS(HSAMetadataString);
-  yaml::Output YOut(StrOS);
-  YOut << HSAMetadataRoot;
+  HSAMetadataDoc->toYAML(StrOS);
 
   if (DumpHSAMetadata)
     dump(StrOS.str());
@@ -956,25 +956,25 @@
 void MetadataStreamerV3::emitKernel(const MachineFunction &MF,
                                     const SIProgramInfo &ProgramInfo) {
   auto &Func = MF.getFunction();
-  auto KernelProps = getHSAKernelProps(MF, ProgramInfo);
+  auto Kern = getHSAKernelProps(MF, ProgramInfo);
 
   assert(Func.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
          Func.getCallingConv() == CallingConv::SPIR_KERNEL);
 
-  auto &KernelsNode = getRootMetadata("amdhsa.kernels");
-  auto Kernels = cast<msgpack::ArrayNode>(KernelsNode.get());
+  auto Kernels =
+      getRootMetadata("amdhsa.kernels").getArray(/*Convert=*/true);
 
   {
-    auto &Kern = *KernelProps;
-    Kern[".name"] = std::make_shared<msgpack::ScalarNode>(Func.getName());
-    Kern[".symbol"] = std::make_shared<msgpack::ScalarNode>(
-        (Twine(Func.getName()) + Twine(".kd")).str());
+    // Copy the name so the document does not reference Func's name storage.
+    Kern[".name"] = Kern.getDocument()->getNode(Func.getName(), /*Copy=*/true);
+    Kern[".symbol"] = Kern.getDocument()->getNode(
+        (Twine(Func.getName()) + Twine(".kd")).str(), /*Copy=*/true);
     emitKernelLanguage(Func, Kern);
     emitKernelAttrs(Func, Kern);
     emitKernelArgs(Func, Kern);
   }
 
-  Kernels->push_back(std::move(KernelProps));
+  Kernels.push_back(Kern);
 }
 
 } // end namespace HSAMD
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
index 5835ed7..2eecddb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
@@ -18,7 +18,7 @@
 #include "AMDGPU.h"
 #include "AMDKernelCodeT.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/BinaryFormat/MsgPackTypes.h"
+#include "llvm/BinaryFormat/MsgPackDocument.h"
 #include "llvm/Support/AMDGPUMetadata.h"
 
 namespace llvm {
@@ -51,8 +51,9 @@
 
 class MetadataStreamerV3 final : public MetadataStreamer {
 private:
-  std::shared_ptr<msgpack::Node> HSAMetadataRoot =
-      std::make_shared<msgpack::MapNode>();
+  // Document holding the HSA metadata tree built by the emit* methods.
+  std::unique_ptr<msgpack::Document> HSAMetadataDoc =
+      llvm::make_unique<msgpack::Document>();
 
   void dump(StringRef HSAMetadataString) const;
 
@@ -69,41 +69,43 @@
 
   std::string getTypeName(Type *Ty, bool Signed) const;
 
-  std::shared_ptr<msgpack::ArrayNode>
-  getWorkGroupDimensions(MDNode *Node) const;
+  msgpack::ArrayDocNode getWorkGroupDimensions(MDNode *Node) const;
 
-  std::shared_ptr<msgpack::MapNode>
-  getHSAKernelProps(const MachineFunction &MF,
-                    const SIProgramInfo &ProgramInfo) const;
+  msgpack::MapDocNode getHSAKernelProps(const MachineFunction &MF,
+                                        const SIProgramInfo &ProgramInfo) const;
 
   void emitVersion();
 
   void emitPrintf(const Module &Mod);
 
-  void emitKernelLanguage(const Function &Func, msgpack::MapNode &Kern);
+  void emitKernelLanguage(const Function &Func, msgpack::MapDocNode Kern);
 
-  void emitKernelAttrs(const Function &Func, msgpack::MapNode &Kern);
+  void emitKernelAttrs(const Function &Func, msgpack::MapDocNode Kern);
 
-  void emitKernelArgs(const Function &Func, msgpack::MapNode &Kern);
+  void emitKernelArgs(const Function &Func, msgpack::MapDocNode Kern);
 
   void emitKernelArg(const Argument &Arg, unsigned &Offset,
-                     msgpack::ArrayNode &Args);
+                     msgpack::ArrayDocNode Args);
 
   void emitKernelArg(const DataLayout &DL, Type *Ty, StringRef ValueKind,
-                     unsigned &Offset, msgpack::ArrayNode &Args,
+                     unsigned &Offset, msgpack::ArrayDocNode Args,
                      unsigned PointeeAlign = 0, StringRef Name = "",
                      StringRef TypeName = "", StringRef BaseTypeName = "",
                      StringRef AccQual = "", StringRef TypeQual = "");
 
   void emitHiddenKernelArgs(const Function &Func, unsigned &Offset,
-                            msgpack::ArrayNode &Args);
+                            msgpack::ArrayDocNode Args);
 
-  std::shared_ptr<msgpack::Node> &getRootMetadata(StringRef Key) {
-    return (*cast<msgpack::MapNode>(HSAMetadataRoot.get()))[Key];
+  /// \returns the entry for \p Key in the root map of the metadata document,
+  /// converting the root to a map first if necessary. \p Key is not copied,
+  /// so its characters must outlive the document (string literals are fine).
+  msgpack::DocNode &getRootMetadata(StringRef Key) {
+    return HSAMetadataDoc->getRoot().getMap(/*Convert=*/true)[Key];
   }
 
-  std::shared_ptr<msgpack::Node> &getHSAMetadataRoot() {
-    return HSAMetadataRoot;
+  /// \returns the root node of the metadata document.
+  msgpack::DocNode &getHSAMetadataRoot() {
+    return HSAMetadataDoc->getRoot();
   }
 
 public:
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp
index c1c00f4..6373756 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.cpp
@@ -18,7 +18,6 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
 #include "llvm/BinaryFormat/ELF.h"
-#include "llvm/BinaryFormat/MsgPackTypes.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Metadata.h"
@@ -51,12 +50,12 @@
 }
 
 bool AMDGPUTargetStreamer::EmitHSAMetadataV3(StringRef HSAMetadataString) {
-  std::shared_ptr<msgpack::Node> HSAMetadataRoot;
-  yaml::Input YIn(HSAMetadataString);
-  YIn >> HSAMetadataRoot;
-  if (YIn.error())
+  msgpack::Document HSAMetadataDoc;
+  if (!HSAMetadataDoc.fromYAML(HSAMetadataString))
     return false;
-  return EmitHSAMetadata(HSAMetadataRoot, false);
+  // Verify non-strictly: textual input may hold implicitly-typed string
+  // scalars, which the verifier then coerces to the expected types.
+  return EmitHSAMetadata(HSAMetadataDoc, /*Strict=*/false);
 }
 
 StringRef AMDGPUTargetStreamer::getArchNameFromElfMach(unsigned ElfMach) {
@@ -213,15 +210,14 @@
 }
 
 bool AMDGPUTargetAsmStreamer::EmitHSAMetadata(
-    std::shared_ptr<msgpack::Node> &HSAMetadataRoot, bool Strict) {
+    msgpack::Document &HSAMetadataDoc, bool Strict) {
   V3::MetadataVerifier Verifier(Strict);
-  if (!Verifier.verify(*HSAMetadataRoot))
+  if (!Verifier.verify(HSAMetadataDoc.getRoot()))
     return false;
 
   std::string HSAMetadataString;
   raw_string_ostream StrOS(HSAMetadataString);
-  yaml::Output YOut(StrOS);
-  YOut << HSAMetadataRoot;
+  HSAMetadataDoc.toYAML(StrOS);
 
   OS << '\t' << V3::AssemblerDirectiveBegin << '\n';
   OS << StrOS.str() << '\n';
@@ -481,16 +477,14 @@
   return true;
 }
 
-bool AMDGPUTargetELFStreamer::EmitHSAMetadata(
-    std::shared_ptr<msgpack::Node> &HSAMetadataRoot, bool Strict) {
+bool AMDGPUTargetELFStreamer::EmitHSAMetadata(msgpack::Document &HSAMetadataDoc,
+                                              bool Strict) {
   V3::MetadataVerifier Verifier(Strict);
-  if (!Verifier.verify(*HSAMetadataRoot))
+  if (!Verifier.verify(HSAMetadataDoc.getRoot()))
     return false;
 
   std::string HSAMetadataString;
-  raw_string_ostream StrOS(HSAMetadataString);
-  msgpack::Writer MPWriter(StrOS);
-  HSAMetadataRoot->write(MPWriter);
+  HSAMetadataDoc.writeToBlob(HSAMetadataString);
 
   // Create two labels to mark the beginning and end of the desc field
   // and a MCExpr to calculate the size of the desc field.
@@ -504,7 +498,7 @@
   EmitNote(ElfNote::NoteNameV3, DescSZ, ELF::NT_AMDGPU_METADATA,
            [&](MCELFStreamer &OS) {
              OS.EmitLabel(DescBegin);
-             OS.EmitBytes(StrOS.str());
+             OS.EmitBytes(HSAMetadataString);
              OS.EmitLabel(DescEnd);
            });
   return true;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.h
index c7a5839..5d9cb7c 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUTargetStreamer.h
@@ -10,7 +10,7 @@
 #define LLVM_LIB_TARGET_AMDGPU_MCTARGETDESC_AMDGPUTARGETSTREAMER_H
 
 #include "AMDKernelCodeT.h"
-#include "llvm/BinaryFormat/MsgPackTypes.h"
+#include "llvm/BinaryFormat/MsgPackDocument.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSubtargetInfo.h"
 #include "llvm/Support/AMDGPUMetadata.h"
@@ -64,8 +64,7 @@
   /// the \p HSAMetadata structure is updated with the correct types.
   ///
   /// \returns True on success, false on failure.
-  virtual bool EmitHSAMetadata(std::shared_ptr<msgpack::Node> &HSAMetadata,
-                               bool Strict) = 0;
+  virtual bool EmitHSAMetadata(msgpack::Document &HSAMetadata, bool Strict) = 0;
 
   /// \returns True on success, false on failure.
   virtual bool EmitHSAMetadata(const AMDGPU::HSAMD::Metadata &HSAMetadata) = 0;
@@ -105,8 +104,7 @@
   bool EmitISAVersion(StringRef IsaVersionString) override;
 
   /// \returns True on success, false on failure.
-  bool EmitHSAMetadata(std::shared_ptr<msgpack::Node> &HSAMetadata,
-                       bool Strict) override;
+  bool EmitHSAMetadata(msgpack::Document &HSAMetadata, bool Strict) override;
 
   /// \returns True on success, false on failure.
   bool EmitHSAMetadata(const AMDGPU::HSAMD::Metadata &HSAMetadata) override;
@@ -149,8 +147,7 @@
   bool EmitISAVersion(StringRef IsaVersionString) override;
 
   /// \returns True on success, false on failure.
-  bool EmitHSAMetadata(std::shared_ptr<msgpack::Node> &HSAMetadata,
-                       bool Strict) override;
+  bool EmitHSAMetadata(msgpack::Document &HSAMetadata, bool Strict) override;
 
   /// \returns True on success, false on failure.
   bool EmitHSAMetadata(const AMDGPU::HSAMD::Metadata &HSAMetadata) override;