Add attribute and affine expr/map factory methods to the Builder, and switch
the parser over to use them.
This also removes "operand" from the affine expr classes (getLeftOperand()/
getRightOperand() become getLHS()/getRHS()): it is unnecessary verbosity, and
"operand" will have a very specific meaning for SSA values (there will be an
Operand type).
PiperOrigin-RevId: 203976504
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 26447a0..e894f49 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -20,6 +20,8 @@
using namespace mlir;
+/// Returns true if this expression is made out of only symbols and
+/// constants (no dimensional identifiers).
bool AffineExpr::isSymbolic() const {
switch (getKind()) {
case Kind::Constant:
@@ -34,48 +36,41 @@
case Kind::Mul:
case Kind::FloorDiv:
case Kind::CeilDiv:
- case Kind::Mod:
- return cast<AffineBinaryOpExpr>(this)->isSymbolic();
+ case Kind::Mod: {
+ auto expr = cast<AffineBinaryOpExpr>(this);
+ return expr->getLHS()->isSymbolic() && expr->getRHS()->isSymbolic();
+ }
}
}
+/// Returns true if this is a pure affine expression: multiplication needs at
+/// least one constant operand, and floordiv/ceildiv/mod need a constant RHS.
bool AffineExpr::isPureAffine() const {
switch (getKind()) {
case Kind::SymbolId:
- return cast<AffineSymbolExpr>(this)->isPureAffine();
case Kind::DimId:
- return cast<AffineDimExpr>(this)->isPureAffine();
case Kind::Constant:
- return cast<AffineConstantExpr>(this)->isPureAffine();
+ return true; // Symbols, dims, and constants are always pure affine.
case Kind::Add:
- return cast<AffineAddExpr>(this)->isPureAffine();
- case Kind::Sub:
- return cast<AffineSubExpr>(this)->isPureAffine();
- case Kind::Mul:
- return cast<AffineMulExpr>(this)->isPureAffine();
- case Kind::FloorDiv:
- return cast<AffineFloorDivExpr>(this)->isPureAffine();
- case Kind::CeilDiv:
- return cast<AffineCeilDivExpr>(this)->isPureAffine();
- case Kind::Mod:
- return cast<AffineModExpr>(this)->isPureAffine();
+ case Kind::Sub: { // Add falls through: pure iff both operands are pure.
+ auto op = cast<AffineBinaryOpExpr>(this);
+ return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine();
}
-}
-bool AffineMulExpr::isPureAffine() const {
- return lhsOperand->isPureAffine() && rhsOperand->isPureAffine() &&
- (isa<AffineConstantExpr>(lhsOperand) ||
- isa<AffineConstantExpr>(rhsOperand));
-}
-
-bool AffineFloorDivExpr::isPureAffine() const {
- return lhsOperand->isPureAffine() && isa<AffineConstantExpr>(rhsOperand);
-}
-
-bool AffineCeilDivExpr::isPureAffine() const {
- return lhsOperand->isPureAffine() && isa<AffineConstantExpr>(rhsOperand);
-}
-
-bool AffineModExpr::isPureAffine() const {
- return lhsOperand->isPureAffine() && isa<AffineConstantExpr>(rhsOperand);
+ case Kind::Mul: {
+ // TODO: Canonicalize the constants in binary operators to the RHS when
+ // possible, allowing this to merge into the next case.
+ auto op = cast<AffineBinaryOpExpr>(this);
+ return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() &&
+ (isa<AffineConstantExpr>(op->getLHS()) ||
+ isa<AffineConstantExpr>(op->getRHS()));
+ }
+ case Kind::FloorDiv:
+ case Kind::CeilDiv:
+ case Kind::Mod: { // LHS must be pure affine; divisor/modulus a constant.
+ auto op = cast<AffineBinaryOpExpr>(this);
+ return op->getLHS()->isPureAffine() &&
+ isa<AffineConstantExpr>(op->getRHS());
+ }
+ } // NOTE(review): no return after the switch; assumes all Kinds handled.
}
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 715459f..d284d0f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -313,27 +313,27 @@
}
void AffineAddExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " + " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " + " << *getRHS() << ")"; // Binary ops print fully parenthesized.
}
void AffineSubExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " - " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " - " << *getRHS() << ")";
}
void AffineMulExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " * " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " * " << *getRHS() << ")";
}
void AffineModExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " mod " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " mod " << *getRHS() << ")";
}
void AffineFloorDivExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " floordiv " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
}
void AffineCeilDivExpr::print(raw_ostream &os) const {
- os << "(" << *getLeftOperand() << " ceildiv " << *getRightOperand() << ")";
+ os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
}
void AffineSymbolExpr::print(raw_ostream &os) const {
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index a03f4b8..71aa0a3 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -16,13 +16,25 @@
// =============================================================================
#include "mlir/IR/Builders.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
using namespace mlir;
Builder::Builder(Module *module) : context(module->getContext()) {}
+Identifier Builder::getIdentifier(StringRef str) {
+ return Identifier::get(str, context);
+}
+/// Creates a new Module in this context; the builder keeps no reference to it.
+Module *Builder::createModule() { return new Module(context); }
+
+//===----------------------------------------------------------------------===//
// Types.
+//===----------------------------------------------------------------------===//
+
PrimitiveType *Builder::getAffineIntType() {
return Type::getAffineInt(context);
}
@@ -57,3 +69,72 @@
UnrankedTensorType *Builder::getTensorType(Type *elementType) {
return UnrankedTensorType::get(elementType);
}
+
+//===----------------------------------------------------------------------===//
+// Attributes. Each getter forwards to the Attr's ::get with this context.
+//===----------------------------------------------------------------------===//
+
+BoolAttr *Builder::getBoolAttr(bool value) {
+ return BoolAttr::get(value, context);
+}
+
+IntegerAttr *Builder::getIntegerAttr(int64_t value) {
+ return IntegerAttr::get(value, context);
+}
+
+FloatAttr *Builder::getFloatAttr(double value) {
+ return FloatAttr::get(value, context);
+}
+
+StringAttr *Builder::getStringAttr(StringRef bytes) {
+ return StringAttr::get(bytes, context);
+}
+
+ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
+ return ArrayAttr::get(value, context);
+}
+
+//===----------------------------------------------------------------------===//
+// Affine Expressions and Affine Map, forwarded to ::get with this context.
+//===----------------------------------------------------------------------===//
+/// Returns an affine map with the given dim/symbol counts and result exprs.
+AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
+ ArrayRef<AffineExpr *> results) {
+ return AffineMap::get(dimCount, symbolCount, results, context);
+}
+
+AffineDimExpr *Builder::getDimExpr(unsigned position) {
+ return AffineDimExpr::get(position, context);
+}
+
+AffineSymbolExpr *Builder::getSymbolExpr(unsigned position) {
+ return AffineSymbolExpr::get(position, context);
+}
+
+AffineConstantExpr *Builder::getConstantExpr(int64_t constant) {
+ return AffineConstantExpr::get(constant, context);
+}
+
+AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineAddExpr::get(lhs, rhs, context);
+}
+
+AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineSubExpr::get(lhs, rhs, context);
+}
+
+AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineMulExpr::get(lhs, rhs, context);
+}
+
+AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineModExpr::get(lhs, rhs, context);
+}
+
+AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineFloorDivExpr::get(lhs, rhs, context);
+}
+
+AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
+ return AffineCeilDivExpr::get(lhs, rhs, context);
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 407c7a3..3c8fc89 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -505,17 +505,17 @@
switch (getToken().getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
- return BoolAttr::get(true, builder.getContext());
+ return builder.getBoolAttr(true);
case Token::kw_false:
consumeToken(Token::kw_false);
- return BoolAttr::get(false, builder.getContext());
+ return builder.getBoolAttr(false);
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)val.getValue(), builder.getContext());
+ return builder.getIntegerAttr((int64_t)val.getValue());
}
case Token::minus: {
@@ -525,7 +525,7 @@
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
- return IntegerAttr::get((int64_t)-val.getValue(), builder.getContext());
+ return builder.getIntegerAttr((int64_t)-val.getValue());
}
return (emitError("expected constant integer or floating point value"),
@@ -535,7 +535,7 @@
case Token::string: {
auto val = getToken().getStringValue();
consumeToken(Token::string);
- return StringAttr::get(val, builder.getContext());
+ return builder.getStringAttr(val);
}
case Token::l_bracket: {
@@ -549,7 +549,7 @@
if (parseCommaSeparatedList(Token::r_bracket, parseElt))
return nullptr;
- return ArrayAttr::get(elements, builder.getContext());
+ return builder.getArrayAttr(elements);
}
default:
// TODO: Handle floating point.
@@ -572,7 +572,7 @@
if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
!getToken().isKeyword())
return emitError("expected attribute name");
- auto nameId = Identifier::get(getTokenSpelling(), builder.getContext());
+ auto nameId = builder.getIdentifier(getTokenSpelling());
consumeToken();
if (!consumeIf(Token::colon))
@@ -671,28 +671,28 @@
"operands has to be either a constant or symbolic");
return nullptr;
}
- return AffineMulExpr::get(lhs, rhs, builder.getContext());
+ return builder.getMulExpr(lhs, rhs);
case FloorDiv:
if (!rhs->isSymbolic()) {
emitError("non-affine expression: right operand of floordiv "
"has to be either a constant or symbolic");
return nullptr;
}
- return AffineFloorDivExpr::get(lhs, rhs, builder.getContext());
+ return builder.getFloorDivExpr(lhs, rhs);
case CeilDiv:
if (!rhs->isSymbolic()) {
emitError("non-affine expression: right operand of ceildiv "
"has to be either a constant or symbolic");
return nullptr;
}
- return AffineCeilDivExpr::get(lhs, rhs, builder.getContext());
+ return builder.getCeilDivExpr(lhs, rhs);
case Mod:
if (!rhs->isSymbolic()) {
emitError("non-affine expression: right operand of mod "
"has to be either a constant or symbolic");
return nullptr;
}
- return AffineModExpr::get(lhs, rhs, builder.getContext());
+ return builder.getModExpr(lhs, rhs);
case HNoOp:
llvm_unreachable("can't create affine expression for null high prec op");
return nullptr;
@@ -705,9 +705,9 @@
AffineExpr *rhs) {
switch (op) {
case AffineLowPrecOp::Add:
- return AffineAddExpr::get(lhs, rhs, builder.getContext());
+ return builder.getAddExpr(lhs, rhs);
case AffineLowPrecOp::Sub:
- return AffineSubExpr::get(lhs, rhs, builder.getContext());
+ return builder.getSubExpr(lhs, rhs);
case AffineLowPrecOp::LNoOp:
llvm_unreachable("can't create affine expression for null low prec op");
return nullptr;
@@ -816,8 +816,8 @@
// Extra error message although parseAffineOperandExpr would have
// complained. Leads to a better diagnostic.
return (emitError("missing operand of negation"), nullptr);
- auto *minusOne = AffineConstantExpr::get(-1, builder.getContext());
- return AffineMulExpr::get(minusOne, operand, builder.getContext());
+ auto *minusOne = builder.getConstantExpr(-1);
+ return builder.getMulExpr(minusOne, operand);
}
/// Parse a bare id that may appear in an affine expression.
@@ -830,11 +830,11 @@
StringRef sRef = getTokenSpelling();
if (dims.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineDimExpr::get(dims.lookup(sRef), builder.getContext());
+ return builder.getDimExpr(dims.lookup(sRef));
}
if (symbols.count(sRef)) {
consumeToken(Token::bare_identifier);
- return AffineSymbolExpr::get(symbols.lookup(sRef), builder.getContext());
+ return builder.getSymbolExpr(symbols.lookup(sRef));
}
return (emitError("identifier is neither dimensional nor symbolic"), nullptr);
}
@@ -854,7 +854,7 @@
return (emitError("constant too large for affineint"), nullptr);
}
consumeToken(Token::integer);
- return AffineConstantExpr::get((int64_t)val.getValue(), builder.getContext());
+ return builder.getConstantExpr((int64_t)val.getValue());
}
/// Parses an expression that can be a valid operand of an affine expression.
@@ -1054,8 +1054,7 @@
return nullptr;
// Parsed a valid affine map.
- return AffineMap::get(dims.size(), symbols.size(), exprs,
- builder.getContext());
+ return builder.getAffineMap(dims.size(), symbols.size(), exprs);
}
AffineMap *Parser::parseAffineMapInline() {
@@ -1288,7 +1287,7 @@
}
// TODO: Don't drop result name and operand names on the floor.
- auto nameId = Identifier::get(name, builder.getContext());
+ auto nameId = builder.getIdentifier(name);
return builder.createOperation(nameId, attributes);
}