Escape and unescape strings in the parser and printer so they can round-trip,
print floating point values in a structured form that we know can round-trip,
and enumerate attributes in the visitor so that affine mapping attributes are
printed symbolically (this accounts for the majority of the testcase updates).

We still have an issue where the hexadecimal floating point syntax is reparsed
as an integer, but that can be addressed in subsequent patches.

PiperOrigin-RevId: 208828876
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 892ee5d..9e4dd65 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
@@ -187,7 +188,8 @@
 }
 
 void ModuleState::visitOperationStmt(const OperationStmt *opStmt) {
-  // TODO: visit any attributes if necessary.
+  for (auto attr : opStmt->getAttrs())
+    visitAttribute(attr.second);
 }
 
 void ModuleState::visitStatement(const Statement *stmt) {
@@ -341,6 +343,50 @@
     print(&fn);
 }
 
+/// Print the double `value` to `os` as a decimal string when reparsing that
+/// string gives back the same double; otherwise as "0x" plus 16 hex digits.
+static void printFloatValue(double value, raw_ostream &os) {
+  APFloat apValue(value);
+
+  // We would like to output the FP constant value in exponential notation,
+  // but only when that loses no precision: the string is reparsed below and
+  // printed only if it compares equal to `value`.  Infinities and NaNs skip
+  // this path and fall through to the hexadecimal form.
+  bool isInf = apValue.isInfinity();
+  bool isNaN = apValue.isNaN();
+  if (!isInf && !isNaN) {
+    SmallString<128> strValue;
+    apValue.toString(strValue, 6, 0, false);
+
+    // Check to make sure that the stringized number is not some string like
+    // "Inf" or "NaN", that atof will accept, but the lexer will not.  Check
+    // that the string starts with a match for the "[-+]?[0-9]" regex.
+    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
+            ((strValue[0] == '-' || strValue[0] == '+') &&
+             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
+           "[-+]?[0-9] regex does not match!");
+    // Reparse the stringized version and emit it only if it round-trips.
+    if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
+      os << strValue;
+      return;
+    }
+  }
+
+  // Otherwise, print the bit pattern in hexadecimal.  The anonymous union
+  // reinterprets the double as a uint64_t so we can use integer math.
+  union {
+    double doubleValue;
+    uint64_t integerValue;
+  };
+  doubleValue = value;
+  os << "0x";
+  // Print 16 nibbles as hex digits, most significant nibble first.
+  for (unsigned i = 0; i != 16; ++i) {
+    os << llvm::hexdigit(integerValue >> 60);
+    integerValue <<= 4;
+  }
+}
+
 void ModulePrinter::printAttribute(const Attribute *attr) {
   switch (attr->getKind()) {
   case Attribute::Kind::Bool:
@@ -350,20 +396,19 @@
     os << cast<IntegerAttr>(attr)->getValue();
     break;
   case Attribute::Kind::Float:
-    // FIXME: this isn't precise, we should print with a hex format.
-    os << cast<FloatAttr>(attr)->getValue();
+    printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
     break;
   case Attribute::Kind::String:
-    // FIXME: should escape the string.
-    os << '"' << cast<StringAttr>(attr)->getValue() << '"';
+    os << '"';
+    printEscapedString(cast<StringAttr>(attr)->getValue(), os);
+    os << '"';
     break;
-  case Attribute::Kind::Array: {
-    auto elts = cast<ArrayAttr>(attr)->getValue();
+  case Attribute::Kind::Array:
     os << '[';
-    interleaveComma(elts, [&](Attribute *attr) { printAttribute(attr); });
+    interleaveComma(cast<ArrayAttr>(attr)->getValue(),
+                    [&](Attribute *attr) { printAttribute(attr); });
     os << ']';
     break;
-  }
   case Attribute::Kind::AffineMap:
     printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
     break;
@@ -911,8 +956,9 @@
 }
 
 void FunctionPrinter::printDefaultOp(const Operation *op) {
-  // TODO: escape name if necessary.
-  os << "\"" << op->getName().str() << "\"(";
+  os << '"';
+  printEscapedString(op->getName().str(), os);
+  os << "\"(";
 
   interleaveComma(op->getOperands(),
                   [&](const SSAValue *value) { printValueID(value); });
@@ -1078,13 +1124,11 @@
 
   if (inst->getNumOperands() != 0) {
     os << '(';
-    // TODO: Use getOperands() when we have it.
-    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
-      printValueID(operand.get());
-    });
+    interleaveComma(inst->getOperands(),
+                    [&](const CFGValue *operand) { printValueID(operand); });
     os << ") : ";
-    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
-      printType(operand.get()->getType());
+    interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
+      printType(operand->getType());
     });
   }
 }
@@ -1198,7 +1242,7 @@
   };
 
   NumberValuesPass pass(this);
-  // TODO: it'd be cleaner to have constant visitor istead of using const_cast.
+  // TODO: it'd be cleaner to have constant visitor instead of using const_cast.
   pass.walk(const_cast<MLFunction *>(function));
 }