Support for affine integer sets
- introduce affine integer sets into the IR
- parse and print affine integer sets (both inline and outlined), similar to
affine maps
- use an integer set for IfStmt's condition, and implement parsing of IfStmt's
condition
- fix an affine expr paren omission bug while at it.
TODO: parse/represent/print MLValue operands to affine integer set references.
PiperOrigin-RevId: 207779408
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 02db4db..91fa8ad 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -111,7 +111,13 @@
return lexComment();
return emitError(tokStart, "unexpected character");
- case '@': return lexAtIdentifier(tokStart);
+ case '@':
+ if (*curPtr == '@') {
+ ++curPtr;
+ return lexDoubleAtIdentifier(tokStart);
+ }
+ return lexAtIdentifier(tokStart);
+
case '#':
LLVM_FALLTHROUGH;
case '%':
@@ -199,6 +205,20 @@
return formToken(Token::at_identifier, tokStart);
}
+/// Lex an '@@foo' identifier. Called once both leading '@' characters have
+/// been consumed; tokStart points at the first '@'.
+///
+///   integer-set-id ::= `@@` letter (letter | digit | `_`)*
+///
+Token Lexer::lexDoubleAtIdentifier(const char *tokStart) {
+  // These always start with a letter. Characters are converted to unsigned
+  // char before classification: passing a negative char (e.g. a byte of a
+  // UTF-8 sequence) to isalpha/isdigit is undefined behavior.
+  if (!isalpha(static_cast<unsigned char>(*curPtr++)))
+    return emitError(curPtr - 1, "expected letter in @@ identifier");
+
+  while (isalpha(static_cast<unsigned char>(*curPtr)) ||
+         isdigit(static_cast<unsigned char>(*curPtr)) || *curPtr == '_')
+    ++curPtr;
+  return formToken(Token::double_at_identifier, tokStart);
+}
+
/// Lex an identifier that starts with a prefix followed by suffix-id.
///
/// affine-map-id ::= `#` suffix-id
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index f9cdfb6..51962fa 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -60,6 +60,7 @@
Token lexComment();
Token lexBareIdentifierOrKeyword(const char *tokStart);
Token lexAtIdentifier(const char *tokStart);
+ Token lexDoubleAtIdentifier(const char *tokStart);
Token lexPrefixedIdentifier(const char *tokStart);
Token lexNumber(const char *tokStart);
Token lexString(const char *tokStart);
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d89c62e..5ca4e56 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -58,6 +58,8 @@
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap *> affineMapDefinitions;
+ // A map from integer set identifier to IntegerSet.
+ llvm::StringMap<IntegerSet *> integerSetDefinitions;
private:
ParserState(const ParserState &) = delete;
@@ -177,6 +179,8 @@
// Polyhedral structures.
AffineMap *parseAffineMapInline();
AffineMap *parseAffineMapReference();
+ IntegerSet *parseIntegerSetInline();
+ IntegerSet *parseIntegerSetReference();
private:
// The Parser is subclassed and reinstantiated. Do not add additional
@@ -717,13 +721,15 @@
};
namespace {
-/// This is a specialized parser for AffineMap's, maintaining the state
-/// transient to their bodies.
-class AffineMapParser : public Parser {
+/// This is a specialized parser for affine structures (affine maps, affine
+/// expressions, and integer sets), maintaining the state transient to their
+/// bodies.
+class AffineParser : public Parser {
public:
- explicit AffineMapParser(ParserState &state) : Parser(state) {}
+ explicit AffineParser(ParserState &state) : Parser(state) {}
AffineMap *parseAffineMapInline();
+ IntegerSet *parseIntegerSetInline();
private:
// Binary affine op parsing.
@@ -751,6 +757,7 @@
AffineExpr *parseAffineHighPrecOpExpr(AffineExpr *llhs,
AffineHighPrecOp llhsOp,
SMLoc llhsOpLoc);
+ AffineExpr *parseAffineConstraint(bool *isEq);
private:
SmallVector<std::pair<StringRef, AffineExpr *>, 4> dimsAndSymbols;
@@ -760,10 +767,9 @@
/// Create an affine binary high precedence op expression (mul's, div's, mod).
/// opLoc is the location of the op token to be used to report errors
/// for non-conforming expressions.
-AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs,
- SMLoc opLoc) {
+AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs, SMLoc opLoc) {
// TODO: make the error location info accurate.
switch (op) {
case Mul:
@@ -801,9 +807,9 @@
}
/// Create an affine binary low precedence op expression (add, sub).
-AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
- AffineExpr *lhs,
- AffineExpr *rhs) {
+AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs) {
switch (op) {
case AffineLowPrecOp::Add:
return builder.getAddExpr(lhs, rhs);
@@ -818,7 +824,7 @@
/// Consume this token if it is a lower precedence affine op (there are only two
/// precedence levels).
-AffineLowPrecOp AffineMapParser::consumeIfLowPrecOp() {
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
switch (getToken().getKind()) {
case Token::plus:
consumeToken(Token::plus);
@@ -833,7 +839,7 @@
/// Consume this token if it is a higher precedence affine op (there are only
/// two precedence levels)
-AffineHighPrecOp AffineMapParser::consumeIfHighPrecOp() {
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
switch (getToken().getKind()) {
case Token::star:
consumeToken(Token::star);
@@ -861,9 +867,9 @@
/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
/// null. llhsOpLoc is the location of the llhsOp token that will be used to
/// report an error for non-conforming expressions.
-AffineExpr *AffineMapParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
- AffineHighPrecOp llhsOp,
- SMLoc llhsOpLoc) {
+AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
+ AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc) {
AffineExpr *lhs = parseAffineOperandExpr(llhs);
if (!lhs)
return nullptr;
@@ -892,7 +898,7 @@
/// Parse an affine expression inside parentheses.
///
/// affine-expr ::= `(` affine-expr `)`
-AffineExpr *AffineMapParser::parseParentheticalExpr() {
+AffineExpr *AffineParser::parseParentheticalExpr() {
if (parseToken(Token::l_paren, "expected '('"))
return nullptr;
if (getToken().is(Token::r_paren))
@@ -910,7 +916,7 @@
/// Parse the negation expression.
///
/// affine-expr ::= `-` affine-expr
-AffineExpr *AffineMapParser::parseNegateExpression(AffineExpr *lhs) {
+AffineExpr *AffineParser::parseNegateExpression(AffineExpr *lhs) {
if (parseToken(Token::minus, "expected '-'"))
return nullptr;
@@ -929,7 +935,7 @@
/// Parse a bare id that may appear in an affine expression.
///
/// affine-expr ::= bare-id
-AffineExpr *AffineMapParser::parseBareIdExpr() {
+AffineExpr *AffineParser::parseBareIdExpr() {
if (getToken().isNot(Token::bare_identifier))
return (emitError("expected bare identifier"), nullptr);
@@ -947,7 +953,7 @@
/// Parse a positive integral constant appearing in an affine expression.
///
/// affine-expr ::= integer-literal
-AffineExpr *AffineMapParser::parseIntegerExpr() {
+AffineExpr *AffineParser::parseIntegerExpr() {
// No need to handle negative numbers separately here. They are naturally
// handled via the unary negation operator, although (FIXME) MININT_64 still
// not correctly handled.
@@ -971,7 +977,7 @@
// operand expression, it's an op expression and will be parsed via
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l
// are valid operands that will be parsed by this function.
-AffineExpr *AffineMapParser::parseAffineOperandExpr(AffineExpr *lhs) {
+AffineExpr *AffineParser::parseAffineOperandExpr(AffineExpr *lhs) {
switch (getToken().getKind()) {
case Token::bare_identifier:
return parseBareIdExpr();
@@ -1021,8 +1027,8 @@
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3)
/// will be parsed using parseAffineHighPrecOpExpr().
-AffineExpr *AffineMapParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
- AffineLowPrecOp llhsOp) {
+AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
+ AffineLowPrecOp llhsOp) {
AffineExpr *lhs;
if (!(lhs = parseAffineOperandExpr(llhs)))
return nullptr;
@@ -1077,14 +1083,14 @@
/// Additional conditions are checked depending on the production. For eg., one
/// of the operands for `*` has to be either constant/symbolic; the second
/// operand for floordiv, ceildiv, and mod has to be a positive integer.
-AffineExpr *AffineMapParser::parseAffineExpr() {
+AffineExpr *AffineParser::parseAffineExpr() {
return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
}
/// Parse a dim or symbol from the lists appearing before the actual expressions
/// of the affine map. Update our state to store the dimensional/symbolic
/// identifier.
-ParseResult AffineMapParser::parseIdentifierDefinition(AffineExpr *idExpr) {
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr *idExpr) {
if (getToken().isNot(Token::bare_identifier))
return emitError("expected bare identifier");
@@ -1100,7 +1106,7 @@
}
/// Parse the list of symbolic identifiers to an affine map.
-ParseResult AffineMapParser::parseSymbolIdList(unsigned &numSymbols) {
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
consumeToken(Token::l_square);
auto parseElt = [&]() -> ParseResult {
auto *symbol = AffineSymbolExpr::get(numSymbols++, getContext());
@@ -1110,7 +1116,7 @@
}
/// Parse the list of dimensional identifiers to an affine map.
-ParseResult AffineMapParser::parseDimIdList(unsigned &numDims) {
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
if (parseToken(Token::l_paren,
"expected '(' at start of dimensional identifiers list"))
return ParseFailure;
@@ -1129,7 +1135,7 @@
/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
///
/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
-AffineMap *AffineMapParser::parseAffineMapInline() {
+AffineMap *AffineParser::parseAffineMapInline() {
unsigned numDims = 0, numSymbols = 0;
// List of dimensional identifiers.
@@ -1201,7 +1207,7 @@
}
AffineMap *Parser::parseAffineMapInline() {
- return AffineMapParser(state).parseAffineMapInline();
+ return AffineParser(state).parseAffineMapInline();
}
AffineMap *Parser::parseAffineMapReference() {
@@ -2100,6 +2106,7 @@
AffineConstantExpr *parseIntConstant();
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
+ IntegerSet *parseCondition();
ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
};
@@ -2203,6 +2210,120 @@
return builder.getConstantExpr((int64_t)val.getValue());
}
+/// Parse the condition of an if statement.
+///
+///   ml-if-cond ::= integer-set
+///
+/// Returns null on failure; the integer set parser emits the diagnostic.
+IntegerSet *MLFunctionParser::parseCondition() {
+  // TODO: Parse operands to the integer set.
+  return parseIntegerSetReference();
+}
+
+/// Parse an affine constraint.
+///
+///   affine-constraint ::= affine-expr `>=` `0`
+///                       | affine-expr `==` `0`
+///
+/// On success, returns the parsed affine expression and sets *isEq to true if
+/// the constraint is an equality, or false if it is an inequality (greater
+/// than or equal to zero). On failure, emits a diagnostic, returns null and
+/// leaves *isEq unmodified.
+///
+AffineExpr *AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExpr *expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+
+  // '>=' and '==' are lexed as two separate punctuation tokens. Commit to an
+  // operator as soon as its first token is seen, so that a malformed operator
+  // (e.g. a lone '>') is reported where it occurs instead of falling through
+  // to the other alternative with a misleading diagnostic.
+  bool isEquality;
+  if (consumeIf(Token::greater)) {
+    if (parseToken(Token::equal, "expected '>=' in affine constraint"))
+      return nullptr;
+    isEquality = false;
+  } else if (consumeIf(Token::equal)) {
+    if (parseToken(Token::equal, "expected '==' in affine constraint"))
+      return nullptr;
+    isEquality = true;
+  } else {
+    return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+            nullptr);
+  }
+
+  // The right-hand side must be the integer literal 0.
+  if (getToken().is(Token::integer)) {
+    auto value = getToken().getUnsignedIntegerValue();
+    if (value.hasValue() && value.getValue() == 0) {
+      consumeToken(Token::integer);
+      *isEq = isEquality;
+      return expr;
+    }
+  }
+  return (emitError(isEquality ? "expected '0' after '=='"
+                               : "expected '0' after '>='"),
+          nullptr);
+}
+
+/// Parse an integer set definition.
+///
+///   integer-set-inline
+///      ::= dim-and-symbol-id-lists `:` `(` affine-constraint-conjunction `)`
+///   affine-constraint-conjunction ::= /*empty*/
+///                                   | affine-constraint (`,` affine-constraint)*
+///
+/// Returns the IntegerSet created by builder.getIntegerSet, or null (after a
+/// diagnostic has been emitted) on a parse error.
+///
+IntegerSet *AffineParser::parseIntegerSetInline() {
+  unsigned numDims = 0, numSymbols = 0;
+
+  // List of dimensional identifiers.
+  if (parseDimIdList(numDims))
+    return nullptr;
+
+  // Symbols are optional.
+  if (getToken().is(Token::l_square)) {
+    if (parseSymbolIdList(numSymbols))
+      return nullptr;
+  }
+
+  if (parseToken(Token::colon, "expected ':' or '['") ||
+      parseToken(Token::l_paren,
+                 "expected '(' at start of integer set constraint list"))
+    return nullptr;
+
+  // Each constraint is recorded together with a flag saying whether it is an
+  // equality (== 0) or an inequality (>= 0); the two vectors stay parallel.
+  SmallVector<AffineExpr *, 4> constraints;
+  SmallVector<bool, 4> isEqs;
+  auto parseElt = [&]() -> ParseResult {
+    bool isEq;
+    AffineExpr *elt = parseAffineConstraint(&isEq);
+    if (!elt)
+      return ParseFailure;
+    constraints.push_back(elt);
+    isEqs.push_back(isEq);
+    return ParseSuccess;
+  };
+
+  // Parse the comma-separated constraint list terminated by ')'. The grammar
+  // above permits the list to be empty, hence the trailing 'true'.
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+    return nullptr;
+
+  // Parsed a valid integer set.
+  return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
+}
+
+/// Parse an inline integer set by delegating to a transient AffineParser that
+/// is constructed over this parser's state.
+IntegerSet *Parser::parseIntegerSetInline() {
+  AffineParser affineParser(state);
+  return affineParser.parseIntegerSetInline();
+}
+
+/// Parse a reference to an integer set.
+///
+///   integer-set ::= integer-set-id | integer-set-inline
+///   integer-set-id ::= `@@` letter (letter | digit | `_`)*
+///
+/// An '@@' identifier must name a set recorded earlier in
+/// ParserState::integerSetDefinitions; any other token starts an inline set.
+/// Returns null (after emitting a diagnostic) on failure.
+///
+IntegerSet *Parser::parseIntegerSetReference() {
+  // Anything other than an '@@' identifier must be an inline definition.
+  if (getToken().isNot(Token::double_at_identifier))
+    return parseIntegerSetInline();
+
+  // Look the identifier (without its leading "@@") up once and verify that
+  // it has been defined.
+  StringRef integerSetId = getTokenSpelling().drop_front(2);
+  auto it = getState().integerSetDefinitions.find(integerSetId);
+  if (it == getState().integerSetDefinitions.end())
+    return (emitError("undefined integer set id '" + integerSetId + "'"),
+            nullptr);
+
+  consumeToken(Token::double_at_identifier);
+  return it->second;
+}
+
/// If statement.
///
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
@@ -2212,15 +2333,18 @@
///
ParseResult MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
- if (parseToken(Token::l_paren, "expected ("))
+
+ if (parseToken(Token::l_paren, "expected '('"))
return ParseFailure;
- // TODO: parse condition
-
- if (parseToken(Token::r_paren, "expected )"))
+ IntegerSet *condition = parseCondition();
+ if (!condition)
return ParseFailure;
- IfStmt *ifStmt = builder.createIf();
+ if (parseToken(Token::r_paren, "expected ')'"))
+ return ParseFailure;
+
+ IfStmt *ifStmt = builder.createIf(condition);
IfClause *thenClause = ifStmt->getThenClause();
// When parsing of an if statement body fails, the IR contains
@@ -2308,6 +2432,7 @@
private:
ParseResult parseAffineMapDef();
+ ParseResult parseIntegerSetDef();
// Functions.
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
@@ -2330,8 +2455,8 @@
StringRef affineMapId = getTokenSpelling().drop_front();
// Check for redefinitions.
- auto *&entry = getState().affineMapDefinitions[affineMapId];
- if (entry)
+ auto **entry = &getState().affineMapDefinitions[affineMapId];
+ if (*entry)
return emitError("redefinition of affine map id '" + affineMapId + "'");
consumeToken(Token::hash_identifier);
@@ -2341,8 +2466,36 @@
"expected '=' in affine map outlined definition"))
return ParseFailure;
- entry = parseAffineMapInline();
- if (!entry)
+ *entry = parseAffineMapInline();
+ if (!*entry)
+ return ParseFailure;
+
+ return ParseSuccess;
+}
+
+/// Integer set declaration.
+///
+/// integer-set-decl ::= integer-set-id `=` integer-set-inline
+///
+ParseResult ModuleParser::parseIntegerSetDef() {
+ assert(getToken().is(Token::double_at_identifier));
+
+ StringRef integerSetId = getTokenSpelling().drop_front(2);
+
+ // Check for redefinitions (a default entry is created if one doesn't exist)
+ auto **entry = &getState().integerSetDefinitions[integerSetId];
+ if (*entry)
+ return emitError("redefinition of integer set id '" + integerSetId + "'");
+
+ consumeToken(Token::double_at_identifier);
+
+ // Parse the '='
+ if (parseToken(Token::equal,
+ "expected '=' in outlined integer set definition"))
+ return ParseFailure;
+
+ *entry = parseIntegerSetInline();
+ if (!*entry)
return ParseFailure;
return ParseSuccess;
@@ -2509,6 +2662,11 @@
return ParseFailure;
break;
+ case Token::double_at_identifier:
+ if (parseIntegerSetDef())
+ return ParseFailure;
+ break;
+
case Token::kw_extfunc:
if (parseExtFunc())
return ParseFailure;
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 6d71884..7eff18e 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -50,10 +50,11 @@
TOK_MARKER(error)
// Identifiers.
-TOK_IDENTIFIER(bare_identifier) // foo
-TOK_IDENTIFIER(at_identifier) // @foo
-TOK_IDENTIFIER(hash_identifier) // #foo
-TOK_IDENTIFIER(percent_identifier) // %foo
+TOK_IDENTIFIER(bare_identifier)      // foo
+TOK_IDENTIFIER(at_identifier)        // @foo
+TOK_IDENTIFIER(double_at_identifier) // @@foo
+TOK_IDENTIFIER(hash_identifier)      // #foo
+TOK_IDENTIFIER(percent_identifier)   // %foo
-// TODO: @@foo, etc.
// Literals
@@ -64,6 +65,7 @@
// Punctuation.
TOK_PUNCTUATION(arrow, "->")
TOK_PUNCTUATION(colon, ":")
TOK_PUNCTUATION(comma, ",")
TOK_PUNCTUATION(question, "?")