Introduce IR and parser support for ML functions.
Representing function arguments is still TODO.
Supporting instructions other than return is also TODO.
PiperOrigin-RevId: 202570934
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index c944937..3655c9c 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -33,7 +33,7 @@
class Function {
public:
enum class Kind {
- ExtFunc, CFGFunc
+ ExtFunc, CFGFunc, MLFunc
};
Kind getKind() const { return kind; }
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
new file mode 100644
index 0000000..8e45407
--- /dev/null
+++ b/include/mlir/IR/MLFunction.h
@@ -0,0 +1,53 @@
+//===- MLFunction.h - MLIR MLFunction Class -------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines MLFunction class
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_MLFUNCTION_H_
+#define MLIR_IR_MLFUNCTION_H_
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLStatements.h"
+#include <vector>
+
+namespace mlir {
+
+// MLFunction is defined as a sequence of statements that may
+// include nested affine for loops, conditionals and instructions.
+class MLFunction : public Function {
+public:
+ MLFunction(StringRef name, FunctionType *type);
+
+ // FIXME: wrong representation and API, leaks memory etc
+ std::vector<MLStatement*> stmtList;
+
+ // TODO: add function arguments and return values once
+ // SSA values are implemented
+
+ // Methods for support type inquiry through isa, cast, and dyn_cast
+ static bool classof(const Function *func) {
+ return func->getKind() == Kind::MLFunc;
+ }
+
+ void print(raw_ostream &os) const;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_MLFUNCTION_H_
diff --git a/include/mlir/IR/MLStatements.h b/include/mlir/IR/MLStatements.h
new file mode 100644
index 0000000..b565890
--- /dev/null
+++ b/include/mlir/IR/MLStatements.h
@@ -0,0 +1,58 @@
+//===- MLStatements.h - MLIR ML Statement Classes ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the classes for MLFunction statements.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_MLSTATEMENTS_H
+#define MLIR_IR_MLSTATEMENTS_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+ class MLFunction;
+
+/// ML function consists of ML statements - for statement, if statement
+/// or operation.
+class MLStatement {
+public:
+ enum class Kind {
+ For,
+ If,
+ Operation
+ };
+
+ Kind getKind() const { return kind; }
+
+ /// Returns the function that this MLStatement is part of.
+ MLFunction *getFunction() const { return function; }
+
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+protected:
+ MLStatement(Kind kind, MLFunction *function)
+ : kind(kind), function(function) {}
+
+private:
+ Kind kind;
+ MLFunction *function;
+};
+
+} //end namespace mlir
+#endif // MLIR_IR_STATEMENTS_H
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 08fe838..03871fc 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
@@ -151,10 +152,18 @@
print(llvm::errs());
}
+void MLStatement::print(raw_ostream &os) const {
+ //TODO
+}
+
+void MLStatement::dump() const {
+ print(llvm::errs());
+}
void Function::print(raw_ostream &os) const {
switch (getKind()) {
case Kind::ExtFunc: return cast<ExtFunction>(this)->print(os);
case Kind::CFGFunc: return cast<CFGFunction>(this)->print(os);
+ case Kind::MLFunc: return cast<MLFunction>(this)->print(os);
}
}
@@ -167,6 +176,18 @@
state.print();
}
+void MLFunction::print(raw_ostream &os) const {
+ os << "mlfunc ";
+ // FIXME: should print argument names rather than just signature
+ printFunctionSignature(this, os);
+ os << " {\n";
+
+ for (auto *stmt : stmtList)
+ stmt->print(os);
+ os << " return\n";
+ os << "}\n\n";
+}
+
void Module::print(raw_ostream &os) const {
for (auto *fn : functionList)
fn->print(os);
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 833c173..3af2a61 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
@@ -38,3 +39,11 @@
CFGFunction::CFGFunction(StringRef name, FunctionType *type)
: Function(name, type, Kind::CFGFunc) {
}
+
+//===----------------------------------------------------------------------===//
+// MLFunction implementation.
+//===----------------------------------------------------------------------===//
+
+MLFunction::MLFunction(StringRef name, FunctionType *type)
+ : Function(name, type, Kind::MLFunc) {
+}
diff --git a/lib/IR/MLStatements.cpp b/lib/IR/MLStatements.cpp
new file mode 100644
index 0000000..26fa708
--- /dev/null
+++ b/lib/IR/MLStatements.cpp
@@ -0,0 +1,22 @@
+//===- MLStatements.cpp - MLIR MLStatement Instruction Classes ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLStatements.h"
+using namespace mlir;
+
+// TODO: classes derived from MLStatement
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 0d24972..df952f9 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -134,9 +135,11 @@
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
+ ParseResult parseMLFunc();
ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
TerminatorInst *parseTerminator(BasicBlock *currentBB,
CFGFunctionParserState &functionState);
+ MLStatement *parseMLStatement(MLFunction *currentFunction);
};
} // end anonymous namespace
@@ -532,7 +535,6 @@
return ParseSuccess;
}
-
/// External function declarations.
///
/// ext-func ::= `extfunc` function-signature
@@ -707,6 +709,59 @@
}
}
+/// ML function declarations.
+///
+/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}`
+///
+ParseResult Parser::parseMLFunc() {
+ consumeToken(Token::kw_mlfunc);
+
+ StringRef name;
+ FunctionType *type = nullptr;
+
+ // FIXME: Parse ML function signature (args + types)
+ // by passing pointer to SmallVector<identifier> into parseFunctionSignature
+ if (parseFunctionSignature(name, type))
+ return ParseFailure;
+
+ if (!consumeIf(Token::l_brace))
+ return emitError("expected '{' in ML function");
+
+ // Okay, the ML function signature was parsed correctly, create the function.
+ auto function = new MLFunction(name, type);
+
+ // Make sure we have at least one statement.
+ if (curToken.is(Token::r_brace))
+ return emitError("ML function must end with return statement");
+
+ // Parse the list of instructions.
+ while (!consumeIf(Token::kw_return)) {
+ auto *stmt = parseMLStatement(function);
+ if (!stmt)
+ return ParseFailure;
+ function->stmtList.push_back(stmt);
+ }
+
+ // TODO: parse return statement operands
+ if (!consumeIf(Token::r_brace))
+ emitError("expected '}' in ML function");
+
+ module->functionList.push_back(function);
+
+ return ParseSuccess;
+}
+
+/// Parse an MLStatement
+/// TODO
+///
+MLStatement *Parser::parseMLStatement(MLFunction *currentFunction) {
+ switch (curToken.getKind()) {
+ default:
+ return (emitError("expected ML statement"), nullptr);
+
+ // TODO: add parsing of ML statements
+ }
+}
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
@@ -741,7 +796,11 @@
if (parseAffineMapDef()) return nullptr;
break;
- // TODO: mlfunc, affine entity declarations, etc.
+ case Token::kw_mlfunc:
+ if (parseMLFunc()) return nullptr;
+ break;
+
+ // TODO: affine entity declarations, etc.
}
}
}
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index d742b6b..e60bf78 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -47,3 +47,13 @@
bb42: ; expected-error {{expected terminator}}
return
}
+
+; -----
+
+mlfunc @foo()
+mlfunc @bar() ; expected-error {{expected '{' in ML function}}
+
+; -----
+
+mlfunc @no_return() {
+} ; expected-error {{ML function must end with return statement}}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index e307a3a..b7c28f7 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -48,3 +48,9 @@
bb4: ; CHECK: bb3:
return ; CHECK: return
} ; CHECK: }
+
+; CHECK-LABEL: mlfunc @simpleMLF() {
+mlfunc @simpleMLF() {
+ return ; CHECK: return
+} ; CHECK: }
+