Implement operands for the 'if' statement.
This CL also includes two other minor changes:
- change the implemented syntax from 'if (cond)' to 'if cond', as specified by the MLIR spec.
- a minor fix to the ForStmt constructor: it now reserves space for its operands up front.
PiperOrigin-RevId: 210618122
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 2cc20ac..21dd4c6 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -186,7 +186,7 @@
}
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
- recordIntegerSetReference(ifStmt->getCondition());
+ recordIntegerSetReference(ifStmt->getIntegerSet());
for (auto &childStmt : *ifStmt->getThen())
visitStatement(&childStmt);
if (ifStmt->hasElse())
@@ -1406,9 +1406,11 @@
}
void MLFunctionPrinter::print(const IfStmt *stmt) {
- os.indent(numSpaces) << "if (";
- printIntegerSetReference(stmt->getCondition());
- os << ") {\n";
+ os.indent(numSpaces) << "if ";
+ IntegerSet *set = stmt->getIntegerSet();
+ printIntegerSetReference(set);
+ printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims());
+ os << " {\n";
print(stmt->getThen());
os.indent(numSpaces) << "}";
if (stmt->hasElse()) {
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 8d0f223..a03de2f 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -254,8 +254,9 @@
return stmt;
}
-IfStmt *MLFuncBuilder::createIf(Location *location, IntegerSet *condition) {
- auto *stmt = new IfStmt(location, condition);
+/// Creates an 'if' statement whose condition is the integer set 'set' applied
+/// to 'operands' (dimension values first, then symbol values; the count must
+/// equal set->getNumOperands()) and inserts it at the builder's current
+/// insertion point.
+IfStmt *MLFuncBuilder::createIf(Location *location,
+ ArrayRef<MLValue *> operands, IntegerSet *set) {
+ auto *stmt = IfStmt::create(location, operands, set);
block->getStatements().insert(insertPoint, stmt);
return stmt;
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 3f95a94..673aa43 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -15,7 +15,9 @@
// limitations under the License.
// =============================================================================
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardOps.h"
@@ -115,8 +117,7 @@
case Kind::For:
return cast<ForStmt>(this)->getNumOperands();
case Kind::If:
- // TODO: query IfStmt once it has operands.
- return 0;
+ return cast<IfStmt>(this)->getNumOperands();
}
}
@@ -127,8 +128,7 @@
case Kind::For:
return cast<ForStmt>(this)->getStmtOperands();
case Kind::If:
- // TODO: query IfStmt once it has operands.
- return {};
+ return cast<IfStmt>(this)->getStmtOperands();
}
}
@@ -301,7 +301,9 @@
AffineMap *ubMap, int64_t step, MLIRContext *context)
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
- StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {}
+ StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+ operands.reserve(numOperands);
+}
const AffineBound ForStmt::getLowerBound() const {
return AffineBound(*this, 0, lbMap->getNumOperands(), lbMap);
@@ -357,18 +359,47 @@
// IfStmt
//===----------------------------------------------------------------------===//
-IfStmt::IfStmt(Location *location, IntegerSet *condition)
+/// Constructs an IfStmt over the integer set 'set'. The then clause is always
+/// allocated; the else clause starts out absent. The operands themselves are
+/// appended afterwards by IfStmt::create.
+IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set)
: Statement(Kind::If, location), thenClause(new IfClause(this)),
- elseClause(nullptr), condition(condition) {}
+ elseClause(nullptr), set(set) {
+ // Reserve exactly numOperands slots so that appending the operands one by
+ // one never reallocates the operand vector.
+ operands.reserve(numOperands);
+}
+/// Destroys the then clause and, if present, the else clause. The integer set
+/// is not owned by the statement and is deliberately left alone.
IfStmt::~IfStmt() {
delete thenClause;
if (elseClause)
delete elseClause;
- // An IfStmt's IntegerSet 'condition' should not be deleted since it is
+ // An IfStmt's IntegerSet 'set' should not be deleted since it is
// allocated through MLIRContext's bump pointer allocator.
}
+/// Creates a new IfStmt whose condition is the integer set 'set' applied to
+/// 'operands'. The operands are the set's dimension values followed by its
+/// symbol values, so their number must equal set->getNumOperands(). The
+/// returned statement is not inserted into any block.
+IfStmt *IfStmt::create(Location *location, ArrayRef<MLValue *> operands,
+                       IntegerSet *set) {
+  unsigned numOperands = operands.size();
+  assert(numOperands == set->getNumOperands() &&
+         "operand count does not match the integer set operand count");
+
+  // The constructor reserves exactly numOperands slots, so the loop below
+  // never reallocates the operand vector.
+  IfStmt *stmt = new IfStmt(location, numOperands, set);
+
+  // Construct each StmtOperand in place instead of building a temporary and
+  // moving it into the vector.
+  for (auto *op : operands)
+    stmt->operands.emplace_back(stmt, op);
+
+  return stmt;
+}
+
+/// Returns the condition of this 'if' as an AffineCondition built from this
+/// statement together with its integer set.
+const AffineCondition IfStmt::getCondition() const {
+  return AffineCondition(*this, this->set);
+}
+
+/// Returns the MLIRContext this statement belongs to.
+///
+/// When the statement has operands, the context comes from the type of the
+/// first operand. An 'if' with no operands is unlikely but legal; in that case
+/// the enclosing function's context is used instead.
+/// NOTE(review): findFunction() is dereferenced without a null check, so the
+/// operand-less case presumably requires the statement to already be attached
+/// to a function — confirm.
+MLIRContext *IfStmt::getContext() const {
+  if (!operands.empty())
+    return getOperand(0)->getType()->getContext();
+
+  return findFunction()->getContext();
+}
+
//===----------------------------------------------------------------------===//
// Statement Cloning
//===----------------------------------------------------------------------===//
@@ -428,9 +459,7 @@
// Otherwise, we must have an If statement.
auto *ifStmt = cast<IfStmt>(this);
- auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
-
- // TODO: remap operands with remapOperand when if statements have them.
+ auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
auto *resultThen = newIf->getThen();
for (auto &childStmt : *ifStmt->getThen())
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index deda932..48f3643 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -372,7 +372,7 @@
// TODO: check that operation is not a return statement unless it's
// the last one in the function.
- // TODO: check that loop bounds are properly formed.
+ // TODO: check that loop bounds and if conditions are properly formed.
if (verifyReturn())
return true;
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index ebeb196..0935090 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
@@ -2196,12 +2197,12 @@
ParseResult parseForStmt();
ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
- const AffineMap *map);
+ unsigned numDims, unsigned numOperands,
+ const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap *&map,
bool isLower);
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
- IntegerSet *parseCondition();
ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
};
@@ -2306,7 +2307,8 @@
///
ParseResult
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
- const AffineMap *map) {
+ unsigned numDims, unsigned numOperands,
+ const char *affineStructName) {
if (parseToken(Token::l_paren, "expected '('"))
return ParseFailure;
@@ -2316,8 +2318,9 @@
if (parseToken(Token::r_paren, "expected ')'"))
return ParseFailure;
- if (map->getNumDims() != opInfo.size())
- return emitError("dim operand count and affine map dim count must match");
+ if (numDims != opInfo.size())
+ return emitError("dim operand count and " + Twine(affineStructName) +
+ " dim count must match");
if (consumeIf(Token::l_square)) {
parseOptionalSSAUseList(opInfo);
@@ -2325,13 +2328,12 @@
return ParseFailure;
}
- if (map->getNumOperands() != opInfo.size())
- return emitError(
- "symbol operand count and affine map symbol count must match");
+ if (numOperands != opInfo.size())
+ return emitError("symbol operand count and " + Twine(affineStructName) +
+ " symbol count must match");
// Resolve SSA uses.
Type *affineIntType = builder.getAffineIntType();
- unsigned numDims = map->getNumDims();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], affineIntType);
if (!sval)
@@ -2340,10 +2342,10 @@
auto *v = cast<MLValue>(sval);
if (i < numDims && !v->isValidDim())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
- "' cannot be used as dimension id");
+ "' cannot be used as a dimension id");
if (i >= numDims && !v->isValidSymbol())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
- "' cannot be used as symbol");
+ "' cannot be used as a symbol");
operands.push_back(v);
}
@@ -2370,7 +2372,8 @@
if (!map)
return ParseFailure;
- if (parseDimAndSymbolList(operands, map))
+ if (parseDimAndSymbolList(operands, map->getNumDims(),
+ map->getNumOperands(), "affine map"))
return ParseFailure;
return ParseSuccess;
}
@@ -2408,13 +2411,6 @@
return ParseSuccess;
}
-/// Parse condition.
-IntegerSet *MLFunctionParser::parseCondition() {
- return parseIntegerSetReference();
-
- // TODO: Parse operands to the integer set.
-}
-
/// Parse an affine constraint.
/// affine-constraint ::= affine-expr `>=` `0`
/// | affine-expr `==` `0`
@@ -2509,6 +2505,7 @@
/// integer-set-id ::= `@@` suffix-id
///
IntegerSet *Parser::parseIntegerSetReference() {
+ // TODO: change '@@' integer set prefix to '#'.
if (getToken().is(Token::double_at_identifier)) {
// Parse integer set identifier and verify that it exists.
StringRef integerSetId = getTokenSpelling().drop_front(2);
@@ -2533,17 +2530,18 @@
auto loc = getToken().getLoc();
consumeToken(Token::kw_if);
- if (parseToken(Token::l_paren, "expected '('"))
+ IntegerSet *set = parseIntegerSetReference();
+ if (!set)
return ParseFailure;
- IntegerSet *condition = parseCondition();
- if (!condition)
+ SmallVector<MLValue *, 4> operands;
+ if (parseDimAndSymbolList(operands, set->getNumDims(), set->getNumOperands(),
+ "integer set"))
return ParseFailure;
- if (parseToken(Token::r_paren, "expected ')'"))
- return ParseFailure;
+ IfStmt *ifStmt =
+ builder.createIf(getEncodedSourceLocation(loc), operands, set);
- IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), condition);
IfClause *thenClause = ifStmt->getThen();
// When parsing of an if statement body fails, the IR contains