Sketch out parser/IR support for OperationInst and a new Instruction base
class.

Introduce an Identifier class to MLIRContext to represent uniqued identifiers,
add string literal support to the lexer, and add parser and printer support
for the new instructions.
PiperOrigin-RevId: 202592007
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 03871fc..639000b 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -77,7 +77,11 @@
void print();
void print(const BasicBlock *block);
- void print(const TerminatorInst *inst);
+
+ void print(const Instruction *inst); // Dispatches on getKind().
+ void print(const OperationInst *inst);
+ void print(const ReturnInst *inst);
+ void print(const BranchInst *inst);
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
@@ -114,32 +118,47 @@
void CFGFunctionState::print(const BasicBlock *block) {
os << "bb" << getBBID(block) << ":\n";
- // TODO Print arguments and instructions.
+ // TODO: Print arguments. Operations print in list order, then the terminator.
+ for (auto inst : block->instList)
+ print(inst);
print(block->getTerminator());
}
-void CFGFunctionState::print(const TerminatorInst *inst) {
+void CFGFunctionState::print(const Instruction *inst) {
switch (inst->getKind()) {
+ case Instruction::Kind::Operation:
+ return print(cast<OperationInst>(inst));
case TerminatorInst::Kind::Branch:
- os << " br bb" << getBBID(cast<BranchInst>(inst)->getDest()) << "\n";
- break;
+ return print(cast<BranchInst>(inst));
case TerminatorInst::Kind::Return:
- os << " return\n";
- break;
+ return print(cast<ReturnInst>(inst));
}
}
+void CFGFunctionState::print(const OperationInst *inst) {
+ // Printed as "name"(); TODO: escape '"' and non-printable chars in the name.
+ os << " \"" << inst->getName().str() << "\"()\n";
+}
+
+void CFGFunctionState::print(const BranchInst *inst) {
+ os << " br bb" << getBBID(inst->getDest()) << "\n";
+}
+void CFGFunctionState::print(const ReturnInst *inst) {
+ os << " return\n";
+}
+
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
-void TerminatorInst::print(raw_ostream &os) const {
+
+void Instruction::print(raw_ostream &os) const {
CFGFunctionState state(getFunction(), os);
state.print(this);
}
-void TerminatorInst::dump() const {
+void Instruction::dump() const {
print(llvm::errs());
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index c32e878..2222a12 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -19,14 +19,33 @@
#include "mlir/IR/BasicBlock.h"
using namespace mlir;
-CFGFunction *TerminatorInst::getFunction() const {
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+CFGFunction *Instruction::getFunction() const {
return getBlock()->getFunction();
}
+//===----------------------------------------------------------------------===//
+// OperationInst
+//===----------------------------------------------------------------------===//
+
+OperationInst::OperationInst(Identifier name, BasicBlock *block) :
+ Instruction(Kind::Operation, block), name(name) {
+ getBlock()->instList.push_back(this); // Append to the parent's op list.
+}
+
+//===----------------------------------------------------------------------===//
+// Terminators
+//===----------------------------------------------------------------------===//
+
ReturnInst::ReturnInst(BasicBlock *parent)
: TerminatorInst(Kind::Return, parent) {
+ getBlock()->setTerminator(this); // Install as the parent's terminator.
}
BranchInst::BranchInst(BasicBlock *dest, BasicBlock *parent)
: TerminatorInst(Kind::Branch, parent), dest(dest) {
+ getBlock()->setTerminator(this); // Install as the parent's terminator.
}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 8a035b6..5f2bd8e 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -16,9 +16,11 @@
// =============================================================================
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
using namespace mlir;
using namespace llvm;
@@ -89,6 +91,9 @@
/// We put immortal objects into this allocator.
llvm::BumpPtrAllocator allocator;
+ /// Identifiers uniqued in this context; entries come from 'allocator'.
+ llvm::StringMap<char, llvm::BumpPtrAllocator&> identifiers;
+
// Primitive type uniquing.
PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
@@ -110,6 +115,8 @@
public:
+ MLIRContextImpl() : identifiers(allocator) {}
+
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
template<typename T>
@@ -128,9 +135,28 @@
}
+//===----------------------------------------------------------------------===//
+// Identifier
+//===----------------------------------------------------------------------===//
+
+/// Return the uniqued identifier for 'str' (non-empty, no nul characters).
+Identifier Identifier::get(StringRef str, const MLIRContext *context) {
+ assert(!str.empty() && "Cannot create an empty identifier");
+ assert(str.find('\0') == StringRef::npos &&
+ "Cannot create an identifier with a nul character");
+
+ auto &impl = context->getImpl();
+ auto it = impl.identifiers.insert({str, char()}).first; // Find or insert.
+ return Identifier(it->getKeyData());
+}
+
+
+//===----------------------------------------------------------------------===//
+// Types
+//===----------------------------------------------------------------------===//
+
PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
: Type(kind, context) {
-
}
PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 209f988..b6473f5 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -99,6 +99,7 @@
case ';': return lexComment();
case '@': return lexAtIdentifier(tokStart);
case '#': return lexAffineMapId(tokStart);
+ case '"': return lexString(tokStart);
case '0': case '1': case '2': case '3': case '4':
case '5': case '6': case '7': case '8': case '9':
@@ -231,3 +232,33 @@
return formToken(Token::integer, tokStart);
}
+
+/// Lex a string literal.
+///
+/// string-literal ::= '"' [^"\n\f\v\r]* '"'
+///
+/// TODO: define escaping rules.
+Token Lexer::lexString(const char *tokStart) {
+ assert(curPtr[-1] == '"');
+
+ while (1) {
+ switch (*curPtr++) {
+ case '"':
+ return formToken(Token::string, tokStart);
+ case 0:
+ // A nul in the middle of the buffer is kept as part of the string; the
+ // nul at the end of the buffer means the literal is unterminated.
+ if (curPtr-1 != curBuffer.end())
+ continue;
+ LLVM_FALLTHROUGH;
+ case '\n':
+ case '\v':
+ case '\f':
+ case '\r':
+ return emitError(curPtr-1, "expected '\"' in string literal");
+
+ default:
+ continue;
+ }
+ }
+}
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index 0301a35..f0274fe 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -62,6 +62,7 @@
Token lexAtIdentifier(const char *tokStart);
Token lexAffineMapId(const char *tokStart);
Token lexNumber(const char *tokStart);
+ Token lexString(const char *tokStart);
};
} // end namespace mlir
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index df952f9..c36d3b9 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -43,7 +43,7 @@
/// Main parser implementation.
class Parser {
- public:
+public:
Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
const SMDiagnosticHandlerTy &errorReporter)
: context(context),
@@ -137,10 +137,13 @@
ParseResult parseCFGFunc();
ParseResult parseMLFunc();
ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
- TerminatorInst *parseTerminator(BasicBlock *currentBB,
- CFGFunctionParserState &functionState);
MLStatement *parseMLStatement(MLFunction *currentFunction);
+ ParseResult parseCFGOperation(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState);
+ ParseResult parseTerminator(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState);
+
};
} // end anonymous namespace
@@ -490,7 +493,7 @@
// Check that 'affineMapId' is unique.
// TODO(andydavis) Add a unit test for this case.
if (affineMaps.count(affineMapId) > 0)
- return emitError("encountered non-unique affine map id");
+ return emitError("redefinition of affine map id '" + affineMapId + "'");
consumeToken(Token::affine_map_id);
@@ -660,26 +663,61 @@
if (!consumeIf(Token::colon))
return emitError("expected ':' after basic block name");
+ // Parse the list of operations that make up the body of the block.
+ while (curToken.isNot(Token::kw_return, Token::kw_br)) {
+ if (parseCFGOperation(block, functionState))
+ return ParseFailure;
+ }
- // TODO(clattner): Verify block hasn't already been parsed (this would be a
- // redefinition of the same name) once we have a body implementation.
-
- // TODO(clattner): Move block to the end of the list, once we have a proper
- // block list representation in CFGFunction.
-
- // TODO: parse instruction list.
-
- // TODO: Generalize this once instruction list parsing is built out.
-
- auto *termInst = parseTerminator(block, functionState);
- if (!termInst)
+ if (parseTerminator(block, functionState))
return ParseFailure;
- block->setTerminator(termInst);
return ParseSuccess;
}
+/// Parse a CFG operation.
+///
+/// TODO(clattner): This is a change from the MLIR spec as written, it is an
+/// experiment that will eliminate "builtin" instructions as a thing.
+///
+/// cfg-operation ::=
+/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
+/// `:` function-type
+///
+ParseResult Parser::
+parseCFGOperation(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState) {
+
+ // TODO: parse ssa-id.
+
+ if (curToken.isNot(Token::string))
+ return emitError("expected operation name in quotes");
+
+ auto name = curToken.getStringValue();
+ if (name.empty())
+ return emitError("empty operation name is invalid");
+ // The lexer keeps embedded nuls in strings, but Identifier::get asserts.
+ if (name.find('\0') != std::string::npos)
+ return emitError("operation name cannot contain a nul character");
+
+ consumeToken(Token::string);
+
+ if (!consumeIf(Token::l_paren))
+ return emitError("expected '(' in operation");
+
+ // TODO: Parse operands.
+ if (!consumeIf(Token::r_paren))
+ return emitError("expected ')' in operation");
+
+ auto nameId = Identifier::get(name, context);
+ new OperationInst(nameId, currentBB);
+
+ // TODO: add instruction to the per-function symbol table.
+ return ParseSuccess;
+}
+
+
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
@@ -688,23 +726,25 @@
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
///
-TerminatorInst *Parser::parseTerminator(BasicBlock *currentBB,
- CFGFunctionParserState &functionState) {
+ParseResult Parser::parseTerminator(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState) {
switch (curToken.getKind()) {
default:
- return (emitError("expected terminator at end of basic block"), nullptr);
+ return emitError("expected terminator at end of basic block");
case Token::kw_return:
consumeToken(Token::kw_return);
- return new ReturnInst(currentBB);
+ new ReturnInst(currentBB);
+ return ParseSuccess;
case Token::kw_br: {
consumeToken(Token::kw_br);
auto destBB = functionState.getBlockNamed(curToken.getSpelling(),
curToken.getLoc());
if (!consumeIf(Token::bare_identifier))
- return (emitError("expected basic block name"), nullptr);
- return new BranchInst(destBB, currentBB);
+ return emitError("expected basic block name");
+ new BranchInst(destBB, currentBB);
+ return ParseSuccess;
}
}
}
diff --git a/lib/Parser/Token.cpp b/lib/Parser/Token.cpp
index c721cf1..a8affc7 100644
--- a/lib/Parser/Token.cpp
+++ b/lib/Parser/Token.cpp
@@ -39,7 +39,7 @@
/// For an integer token, return its value as an unsigned. If it doesn't fit,
/// return None.
-Optional<unsigned> Token::getUnsignedIntegerValue() {
+Optional<unsigned> Token::getUnsignedIntegerValue() const {
bool isHex = spelling.size() > 1 && spelling[1] == 'x';
unsigned result = 0;
@@ -47,3 +47,12 @@
return None;
return result;
}
+
+/// Given a 'string' token, return its value with the surrounding quote
+/// characters removed. Escape sequences are not processed yet.
+std::string Token::getStringValue() const {
+ // TODO: Handle escaping.
+
+ // Just drop the opening and closing '"' characters for now.
+ return getSpelling().drop_front().drop_back().str();
+}
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index 8a654a1..15ce015 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -38,6 +38,7 @@
// TODO: @@foo, etc.
integer, // 42
+ string, // "foo"
// Punctuation.
arrow, // ->
@@ -105,7 +106,11 @@
/// For an integer token, return its value as an unsigned. If it doesn't fit,
/// return None.
- Optional<unsigned> getUnsignedIntegerValue();
+ Optional<unsigned> getUnsignedIntegerValue() const;
+
+ /// Given a 'string' token, return its value with the surrounding quote
+ /// characters removed. Escape sequences are not processed yet.
+ std::string getStringValue() const;
// Location processing.
llvm::SMLoc getLoc() const;