Implement ML function arguments: add a representation for the argument list in MLFunction using the TrailingObjects template, and implement argument iterators, parsing, and printing.
Unrelated minor change: remove OperationStmt::dropReferences(). Since MLFunction has no cyclic operand references (it is an AST), it can be destroyed safely without a special pass to drop references.
PiperOrigin-RevId: 207583024
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6875120..cb7560e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -200,6 +200,7 @@
ModuleState &state;
void printFunctionSignature(const Function *fn);
+ void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(const AffineMap *affineMap);
@@ -527,14 +528,7 @@
// Function printing
//===----------------------------------------------------------------------===//
-void ModulePrinter::printFunctionSignature(const Function *fn) {
- auto type = fn->getType();
-
- os << "@" << fn->getName() << '(';
- interleaveComma(type->getInputs(),
- [&](Type *eltType) { printType(eltType); });
- os << ')';
-
+void ModulePrinter::printFunctionResultType(const FunctionType *type) {
switch (type->getResults().size()) {
case 0:
break;
@@ -551,6 +545,17 @@
}
}
+/// Prints a function signature using only its type: `@name(` followed by the
+/// comma-separated input types and `)`, then the result type(s) via
+/// printFunctionResultType(). Argument names are not printed here; ML
+/// functions use MLFunctionPrinter::printFunctionSignature() instead.
+void ModulePrinter::printFunctionSignature(const Function *fn) {
+ auto type = fn->getType();
+
+ // Input types only (no argument names), separated by ", ".
+ os << "@" << fn->getName() << '(';
+ interleaveComma(type->getInputs(),
+ [&](Type *eltType) { printType(eltType); });
+ os << ')';
+
+ // Result printing is shared with the ML function signature printer.
+ printFunctionResultType(type);
+}
+
void ModulePrinter::print(const ExtFunction *fn) {
os << "extfunc ";
printFunctionSignature(fn);
@@ -627,7 +632,7 @@
// done with it.
valueIDs[value] = nextValueID++;
return;
- case SSAValueKind::FnArgument:
+ case SSAValueKind::MLFuncArgument:
specialName << "arg" << nextArgumentID++;
break;
case SSAValueKind::ForStmt:
@@ -1018,6 +1023,9 @@
// Prints ML function
void print();
+ // Prints ML function signature
+ void printFunctionSignature();
+
// Methods to print ML function statements
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
@@ -1039,12 +1047,18 @@
+/// Creates a printer for \p function, initializing the base FunctionPrinter
+/// from \p other (presumably to reuse its module-level state; verify).
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
const ModulePrinter &other)
: FunctionPrinter(other), function(function), numSpaces(0) {
+ assert(function && "Cannot print nullptr function");
+ // Assign value IDs (function arguments first, then statement results) up
+ // front so later printing can refer to every SSA value by its ID.
numberValues();
}
/// Number all of the SSA values in this ML function.
void MLFunctionPrinter::numberValues() {
- // Visits all operation statements and numbers the first result.
+ // Numbers ML function arguments
+ for (auto *arg : function->getArguments())
+ numberValueID(arg);
+
+ // Walks ML function statements and numbers for statements and
+ // the first result of the operation statements.
struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
void visitOperationStmt(OperationStmt *stmt) {
@@ -1062,14 +1076,30 @@
+/// Prints the whole ML function: the `mlfunc` keyword, the signature with
+/// named arguments, the body (via print(function)), a `return` line, and the
+/// closing brace.
void MLFunctionPrinter::print() {
os << "mlfunc ";
- // FIXME: should print argument names rather than just signature
- printFunctionSignature(function);
+ printFunctionSignature();
os << " {\n";
print(function);
+ // NOTE(review): the terminator is always emitted as a bare "return"; no
+ // return operands are printed here.
os << " return\n";
os << "}\n\n";
}
+/// Prints the ML function signature with named arguments, in the form
+/// `@name(<arg> : <type>, ...)`, followed by the result type(s) via
+/// printFunctionResultType(). Each argument is printed with printOperand(),
+/// which relies on the ID assigned to it in numberValues().
+void MLFunctionPrinter::printFunctionSignature() {
+ auto type = function->getType();
+
+ os << "@" << function->getName() << '(';
+
+ // Arguments are printed in index order, separated by ", ".
+ for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+ if (i > 0)
+ os << ", ";
+ auto *arg = function->getArgument(i);
+ printOperand(arg);
+ os << " : ";
+ printType(arg->getType());
+ }
+ os << ")";
+ printFunctionResultType(type);
+}
+
void MLFunctionPrinter::print(const StmtBlock *block) {
numSpaces += indentWidth;
for (auto &stmt : block->getStatements()) {
@@ -1169,7 +1199,7 @@
return;
case SSAValueKind::InstResult:
return getDefiningInst()->print(os);
- case SSAValueKind::FnArgument:
+ case SSAValueKind::MLFuncArgument:
// TODO: Improve this.
os << "<function argument>\n";
return;