[X86][AVX512] Add mask/maskz writemask support to constant pool shuffle decode comments

llvm-svn: 284488
diff --git a/llvm/lib/Target/X86/X86MCInstLower.cpp b/llvm/lib/Target/X86/X86MCInstLower.cpp
index df3c24d..4108462 100644
--- a/llvm/lib/Target/X86/X86MCInstLower.cpp
+++ b/llvm/lib/Target/X86/X86MCInstLower.cpp
@@ -1191,9 +1191,9 @@
   return C;
 }
 
-static std::string getShuffleComment(const MachineOperand &DstOp,
-                                     const MachineOperand &SrcOp1,
-                                     const MachineOperand &SrcOp2,
+static std::string getShuffleComment(const MachineInstr *MI,
+                                     unsigned SrcOp1Idx,
+                                     unsigned SrcOp2Idx,
                                      ArrayRef<int> Mask) {
   std::string Comment;
 
@@ -1206,7 +1206,10 @@
     return X86ATTInstPrinter::getRegisterName(RegNum);
   };
 
-  // TODO: Add support for specifying an AVX512 style mask register in the comment.
+  const MachineOperand &DstOp = MI->getOperand(0);
+  const MachineOperand &SrcOp1 = MI->getOperand(SrcOp1Idx);
+  const MachineOperand &SrcOp2 = MI->getOperand(SrcOp2Idx);
+
   StringRef DstName = DstOp.isReg() ? GetRegisterName(DstOp.getReg()) : "mem";
   StringRef Src1Name =
       SrcOp1.isReg() ? GetRegisterName(SrcOp1.getReg()) : "mem";
@@ -1221,7 +1224,26 @@
         ShuffleMask[i] -= e;
 
   raw_string_ostream CS(Comment);
-  CS << DstName << " = ";
+  CS << DstName;
+
+  // Handle AVX512 MASK/MASKZ write mask comments.
+  // MASK: zmmX {%kY}
+  // MASKZ: zmmX {%kY} {z}
+  if (SrcOp1Idx > 1) {
+    assert((SrcOp1Idx == 2 || SrcOp1Idx == 3) && "Unexpected writemask");
+
+    const MachineOperand &WriteMaskOp = MI->getOperand(SrcOp1Idx - 1);
+    if (WriteMaskOp.isReg()) {
+      CS << " {%" << GetRegisterName(WriteMaskOp.getReg()) << "}";
+
+      if (SrcOp1Idx == 2) {
+        CS << " {z}";
+      }
+    }
+  }
+
+  CS << " = ";
+
   for (int i = 0, e = ShuffleMask.size(); i != e; ++i) {
     if (i != 0)
       CS << ",";
@@ -1514,15 +1536,13 @@
 
     assert(MI->getNumOperands() >= 6 &&
            "We should always have at least 6 operands!");
-    const MachineOperand &DstOp = MI->getOperand(0);
-    const MachineOperand &SrcOp = MI->getOperand(SrcIdx);
-    const MachineOperand &MaskOp = MI->getOperand(MaskIdx);
 
+    const MachineOperand &MaskOp = MI->getOperand(MaskIdx);
     if (auto *C = getConstantFromPool(*MI, MaskOp)) {
       SmallVector<int, 16> Mask;
       DecodePSHUFBMask(C, Mask);
       if (!Mask.empty())
-        OutStreamer->AddComment(getShuffleComment(DstOp, SrcOp, SrcOp, Mask));
+        OutStreamer->AddComment(getShuffleComment(MI, SrcIdx, SrcIdx, Mask));
     }
     break;
   }
@@ -1587,15 +1607,13 @@
 
     assert(MI->getNumOperands() >= 6 &&
            "We should always have at least 6 operands!");
-    const MachineOperand &DstOp = MI->getOperand(0);
-    const MachineOperand &SrcOp = MI->getOperand(SrcIdx);
-    const MachineOperand &MaskOp = MI->getOperand(MaskIdx);
 
+    const MachineOperand &MaskOp = MI->getOperand(MaskIdx);
     if (auto *C = getConstantFromPool(*MI, MaskOp)) {
       SmallVector<int, 16> Mask;
       DecodeVPERMILPMask(C, ElSize, Mask);
       if (!Mask.empty())
-        OutStreamer->AddComment(getShuffleComment(DstOp, SrcOp, SrcOp, Mask));
+        OutStreamer->AddComment(getShuffleComment(MI, SrcIdx, SrcIdx, Mask));
     }
     break;
   }
@@ -1608,12 +1626,8 @@
       break;
     assert(MI->getNumOperands() >= 8 &&
            "We should always have at least 8 operands!");
-    const MachineOperand &DstOp = MI->getOperand(0);
-    const MachineOperand &SrcOp1 = MI->getOperand(1);
-    const MachineOperand &SrcOp2 = MI->getOperand(2);
-    const MachineOperand &MaskOp = MI->getOperand(6);
-    const MachineOperand &CtrlOp = MI->getOperand(MI->getNumOperands() - 1);
 
+    const MachineOperand &CtrlOp = MI->getOperand(MI->getNumOperands() - 1);
     if (!CtrlOp.isImm())
       break;
 
@@ -1624,11 +1638,12 @@
     case X86::VPERMIL2PDrm: case X86::VPERMIL2PDrmY: ElSize = 64; break;
     }
 
+    const MachineOperand &MaskOp = MI->getOperand(6);
     if (auto *C = getConstantFromPool(*MI, MaskOp)) {
       SmallVector<int, 16> Mask;
       DecodeVPERMIL2PMask(C, (unsigned)CtrlOp.getImm(), ElSize, Mask);
       if (!Mask.empty())
-        OutStreamer->AddComment(getShuffleComment(DstOp, SrcOp1, SrcOp2, Mask));
+        OutStreamer->AddComment(getShuffleComment(MI, 1, 2, Mask));
     }
     break;
   }
@@ -1638,16 +1653,13 @@
       break;
     assert(MI->getNumOperands() >= 7 &&
            "We should always have at least 7 operands!");
-    const MachineOperand &DstOp = MI->getOperand(0);
-    const MachineOperand &SrcOp1 = MI->getOperand(1);
-    const MachineOperand &SrcOp2 = MI->getOperand(2);
-    const MachineOperand &MaskOp = MI->getOperand(6);
 
+    const MachineOperand &MaskOp = MI->getOperand(6);
     if (auto *C = getConstantFromPool(*MI, MaskOp)) {
       SmallVector<int, 16> Mask;
       DecodeVPPERMMask(C, Mask);
       if (!Mask.empty())
-        OutStreamer->AddComment(getShuffleComment(DstOp, SrcOp1, SrcOp2, Mask));
+        OutStreamer->AddComment(getShuffleComment(MI, 1, 2, Mask));
     }
     break;
   }