Parsing support for affine maps and affine expressions
A recursive descent parser for affine maps/expressions with operator precedence and
associativity. (While on this, sketch out uniqui'ing functionality for affine maps
and affine binary op expressions (partly).)
PiperOrigin-RevId: 203222063
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index b5bf8f2..a9dfc09 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -33,40 +33,34 @@
/// AffineExpression's are immutable (like Type's)
class AffineExpr {
public:
- enum class Kind {
- // Add.
- Add,
- // Mul.
- Mul,
- // Mod.
- Mod,
- // Floordiv
- FloorDiv,
- // Ceildiv
- CeilDiv,
+ enum class Kind {
+ Add,
+ Sub,
+ Mul,
+ Mod,
+ FloorDiv,
+ CeilDiv,
- /// This is a marker for the last affine binary op. The range of binary op's
- /// is expected to be this element and earlier.
- LAST_AFFINE_BINARY_OP = CeilDiv,
+ /// This is a marker for the last affine binary op. The range of binary
+ /// op's is expected to be this element and earlier.
+ LAST_AFFINE_BINARY_OP = CeilDiv,
- // Unary op negation
- Neg,
+ // Unary op negation
+ Neg,
- // Constant integer.
- Constant,
- // Dimensional identifier.
- DimId,
- // Symbolic identifier.
- SymbolId,
- };
+ // Constant integer.
+ Constant,
+ // Dimensional identifier.
+ DimId,
+ // Symbolic identifier.
+ SymbolId,
+ };
- /// Return the classification for this type.
- Kind getKind() const { return kind; }
+ /// Return the classification for this type.
+ Kind getKind() const { return kind; }
- ~AffineExpr() = default;
-
- void print(raw_ostream &os) const;
- void dump() const;
+ void print(raw_ostream &os) const;
+ void dump() const;
protected:
explicit AffineExpr(Kind kind) : kind(kind) {}
@@ -76,12 +70,14 @@
const Kind kind;
};
+inline raw_ostream &operator<<(raw_ostream &os, const AffineExpr &expr) {
+ expr.print(os);
+ return os;
+}
+
/// Binary affine expression.
class AffineBinaryOpExpr : public AffineExpr {
public:
- static AffineBinaryOpExpr *get(Kind kind, AffineExpr *lhsOperand,
- AffineExpr *rhsOperand, MLIRContext *context);
-
AffineExpr *getLeftOperand() const { return lhsOperand; }
AffineExpr *getRightOperand() const { return rhsOperand; }
@@ -91,12 +87,15 @@
}
protected:
- explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhsOperand,
- AffineExpr *rhsOperand)
- : AffineExpr(kind), lhsOperand(lhsOperand), rhsOperand(rhsOperand) {}
+ static AffineBinaryOpExpr *get(Kind kind, AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand, MLIRContext *context);
- AffineExpr *const lhsOperand;
- AffineExpr *const rhsOperand;
+ explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand)
+ : AffineExpr(kind), lhsOperand(lhsOperand), rhsOperand(rhsOperand) {}
+
+ AffineExpr *const lhsOperand;
+ AffineExpr *const rhsOperand;
};
/// Binary affine add expression.
@@ -109,34 +108,60 @@
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Add;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineAddExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
: AffineBinaryOpExpr(Kind::Add, lhsOperand, rhsOperand) {}
};
+/// Binary affine sub expression.
+class AffineSubExpr : public AffineBinaryOpExpr {
+public:
+ static AffineSubExpr *get(AffineExpr *lhsOperand, AffineExpr *rhsOperand,
+ MLIRContext *context);
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const AffineExpr *expr) {
+ return expr->getKind() == Kind::Sub;
+ }
+ void print(raw_ostream &os) const;
+
+private:
+ explicit AffineSubExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
+ : AffineBinaryOpExpr(Kind::Sub, lhsOperand, rhsOperand) {}
+};
+
/// Binary affine mul expression.
class AffineMulExpr : public AffineBinaryOpExpr {
- public:
+public:
+ static AffineMulExpr *get(AffineExpr *lhsOperand, AffineExpr *rhsOperand,
+ MLIRContext *context);
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Mul;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineMulExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
: AffineBinaryOpExpr(Kind::Mul, lhsOperand, rhsOperand) {}
};
/// Binary affine mod expression.
class AffineModExpr : public AffineBinaryOpExpr {
- public:
+public:
+ static AffineModExpr *get(AffineExpr *lhsOperand, AffineExpr *rhsOperand,
+ MLIRContext *context);
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Mod;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineModExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
: AffineBinaryOpExpr(Kind::Mod, lhsOperand, rhsOperand) {}
};
@@ -144,32 +169,40 @@
/// Binary affine floordiv expression.
class AffineFloorDivExpr : public AffineBinaryOpExpr {
public:
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const AffineExpr *expr) {
- return expr->getKind() == Kind::FloorDiv;
- }
+ static AffineFloorDivExpr *get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand, MLIRContext *context);
- private:
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const AffineExpr *expr) {
+ return expr->getKind() == Kind::FloorDiv;
+ }
+ void print(raw_ostream &os) const;
+
+private:
explicit AffineFloorDivExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
: AffineBinaryOpExpr(Kind::FloorDiv, lhsOperand, rhsOperand) {}
};
/// Binary affine ceildiv expression.
class AffineCeilDivExpr : public AffineBinaryOpExpr {
- public:
+public:
+ static AffineCeilDivExpr *get(AffineExpr *lhsOperand, AffineExpr *rhsOperand,
+ MLIRContext *context);
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::CeilDiv;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineCeilDivExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand)
: AffineBinaryOpExpr(Kind::CeilDiv, lhsOperand, rhsOperand) {}
};
/// Unary affine expression.
class AffineUnaryOpExpr : public AffineExpr {
- public:
+public:
static AffineUnaryOpExpr *get(const AffineExpr &operand,
MLIRContext *context);
@@ -180,17 +213,22 @@
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Neg;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineUnaryOpExpr(Kind kind, AffineExpr *operand)
: AffineExpr(kind), operand(operand) {}
AffineExpr *operand;
};
-/// A argument identifier appearing in an affine expression
+/// A dimensional identifier appearing in an affine expression.
+///
+/// This is a POD type of int size; so it should be passed around by
+/// value. The underlying data is owned by MLIRContext and is thus immortal for
+/// almost all clients.
class AffineDimExpr : public AffineExpr {
- public:
+public:
static AffineDimExpr *get(unsigned position, MLIRContext *context);
unsigned getPosition() const { return position; }
@@ -199,8 +237,9 @@
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::DimId;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineDimExpr(unsigned position)
: AffineExpr(Kind::DimId), position(position) {}
@@ -208,7 +247,11 @@
unsigned position;
};
-/// A symbolic identifier appearing in an affine expression
+/// A symbolic identifier appearing in an affine expression.
+//
+/// This is a POD type of int size, so it should be passed around by
+/// value. The underlying data is owned by MLIRContext and is thus immortal for
+/// almost all clients.
class AffineSymbolExpr : public AffineExpr {
public:
static AffineSymbolExpr *get(unsigned position, MLIRContext *context);
@@ -219,8 +262,9 @@
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::SymbolId;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineSymbolExpr(unsigned position)
: AffineExpr(Kind::SymbolId), position(position) {}
@@ -239,8 +283,9 @@
static bool classof(const AffineExpr *expr) {
return expr->getKind() == Kind::Constant;
}
+ void print(raw_ostream &os) const;
- private:
+private:
explicit AffineConstantExpr(int64_t constant)
: AffineExpr(Kind::Constant), constant(constant) {}
diff --git a/include/mlir/IR/AffineMap.h b/include/mlir/IR/AffineMap.h
index 6266050..4c9367a 100644
--- a/include/mlir/IR/AffineMap.h
+++ b/include/mlir/IR/AffineMap.h
@@ -39,28 +39,33 @@
/// The names used (d0, d1) don't matter - it's the mathematical function that
/// is unique to this affine map.
class AffineMap {
- public:
+public:
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
- ArrayRef<AffineExpr *> exprs,
- MLIRContext *context);
+ ArrayRef<AffineExpr *> results, MLIRContext *context);
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
void dump() const;
- unsigned dimCount() const { return numDims; }
- unsigned symbolCount() const { return numSymbols; }
+ unsigned getNumDims() const { return numDims; }
+ unsigned getNumSymbols() const { return numSymbols; }
+ unsigned getNumResults() const { return numResults; }
+
+ ArrayRef<AffineExpr *> getResults() const {
+ return ArrayRef<AffineExpr *>(results, numResults);
+ }
private:
- AffineMap(unsigned dimCount, unsigned symbolCount,
- ArrayRef<AffineExpr *> exprs);
+ AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
+ AffineExpr *const *results);
- const unsigned numDims;
- const unsigned numSymbols;
+ const unsigned numDims;
+ const unsigned numSymbols;
+ const unsigned numResults;
- /// The affine expressions for this (multi-dimensional) map.
- /// TODO: use trailing objects for these
- ArrayRef<AffineExpr *> exprs;
+ /// The affine expressions for this (multi-dimensional) map.
+ /// TODO: use trailing objects for this.
+ AffineExpr *const *const results;
};
} // end namespace mlir
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 8631751..cba3094 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -20,10 +20,7 @@
using namespace mlir;
-// TODO(clattner): make this ctor take an LLVMContext. This will eventually
-// copy the elements into the context.
-AffineMap::AffineMap(unsigned dimCount, unsigned symbolCount,
- ArrayRef<AffineExpr *> exprs)
- : numDims(dimCount), numSymbols(symbolCount), exprs(exprs) {
- // TODO(bondhugula)
-}
+AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
+ AffineExpr *const *results)
+ : numDims(numDims), numSymbols(numSymbols), numResults(numResults),
+ results(results) {}
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index b871066..b24eb2f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -251,13 +251,95 @@
print(llvm::errs());
}
+void AffineExpr::dump() const {
+ print(llvm::errs());
+ llvm::errs() << "\n";
+}
+
+void AffineAddExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " + " << *getRightOperand() << ")";
+}
+
+void AffineSubExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " - " << *getRightOperand() << ")";
+}
+
+void AffineMulExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " * " << *getRightOperand() << ")";
+}
+
+void AffineModExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " mod " << *getRightOperand() << ")";
+}
+
+void AffineFloorDivExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " floordiv " << *getRightOperand() << ")";
+}
+
+void AffineCeilDivExpr::print(raw_ostream &os) const {
+ os << "(" << *getLeftOperand() << " ceildiv " << *getRightOperand() << ")";
+}
+
+void AffineSymbolExpr::print(raw_ostream &os) const {
+ os << "s" << getPosition();
+}
+
+void AffineDimExpr::print(raw_ostream &os) const { os << "d" << getPosition(); }
+
+void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
+
void AffineExpr::print(raw_ostream &os) const {
- // TODO(bondhugula): print out affine expression
+ switch (getKind()) {
+ case Kind::SymbolId:
+ return cast<AffineSymbolExpr>(this)->print(os);
+ case Kind::DimId:
+ return cast<AffineDimExpr>(this)->print(os);
+ case Kind::Constant:
+ return cast<AffineConstantExpr>(this)->print(os);
+ case Kind::Add:
+ return cast<AffineAddExpr>(this)->print(os);
+ case Kind::Sub:
+ return cast<AffineSubExpr>(this)->print(os);
+ case Kind::Mul:
+ return cast<AffineMulExpr>(this)->print(os);
+ case Kind::FloorDiv:
+ return cast<AffineFloorDivExpr>(this)->print(os);
+ case Kind::CeilDiv:
+ return cast<AffineCeilDivExpr>(this)->print(os);
+ case Kind::Mod:
+ return cast<AffineModExpr>(this)->print(os);
+ default:
+ os << "<unimplemented expr>";
+ return;
+ }
}
void AffineMap::print(raw_ostream &os) const {
- // TODO(andydavis) Print out affine map based on dimensionCount and
- // symbolCount: (d0, d1) [S0, S1] -> (d0 + S0, d1 + S1)
+ // Dimension identifiers.
+ os << "(";
+ for (int i = 0; i < (int)getNumDims() - 1; i++)
+ os << "d" << i << ", ";
+ if (getNumDims() >= 1)
+ os << "d" << getNumDims() - 1;
+ os << ")";
+
+ // Symbolic identifiers.
+ if (getNumSymbols() >= 1) {
+ os << " [";
+ for (int i = 0; i < (int)getNumSymbols() - 1; i++)
+ os << "s" << i << ", ";
+ if (getNumSymbols() >= 1)
+ os << "s" << getNumSymbols() - 1;
+ os << "]";
+ }
+
+ // AffineMap should have at least one result.
+ assert(!getResults().empty());
+ // Result affine expressions.
+ os << " -> (";
+ interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
+ [&]() { os << ", "; });
+ os << ")\n";
}
void BasicBlock::print(raw_ostream &os) const {
@@ -300,8 +382,11 @@
}
void Module::print(raw_ostream &os) const {
- for (auto *map : affineMapList)
+ unsigned id = 0;
+ for (auto *map : affineMapList) {
+ os << "#" << id++ << " = ";
map->print(os);
+ }
for (auto *fn : functionList)
fn->print(os);
}
@@ -309,4 +394,3 @@
void Module::dump() const {
print(llvm::errs());
}
-
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 85ff432..b6e1faa 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -46,20 +46,27 @@
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
}
};
+
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
- // Affine maps are uniqued based on their arguments and affine expressions
- using KeyTy = std::pair<unsigned, unsigned>;
+ // Affine maps are uniqued based on their dim/symbol counts and affine
+ // expressions.
+ using KeyTy =
+ std::pair<std::pair<unsigned, unsigned>, ArrayRef<AffineExpr *>>;
using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
static unsigned getHashValue(KeyTy key) {
- // FIXME(bondhugula): placeholder for now
- return hash_combine(key.first, key.second);
+ return hash_combine(
+ key.first.first, key.first.second,
+ hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
- // TODO(bondhugula)
- return false;
+ static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
+ if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+ return false;
+ return lhs == KeyTy(std::pair<unsigned, unsigned>(rhs->getNumDims(),
+ rhs->getNumSymbols()),
+ rhs->getResults());
}
};
@@ -120,6 +127,15 @@
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
AffineMapSet affineMaps;
+ // Affine binary op expression uniquing. We don't need to unique dimensional
+ // or symbolic identifiers.
+ // std::tuple doesn't work with DesnseMap!
+ // DenseSet<AffineBinaryOpExpr *, AffineBinaryOpExprKeyInfo>;
+ // AffineExprSet affineExprs;
+ DenseMap<std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>,
+ AffineBinaryOpExpr *>
+ affineExprs;
+
/// Integer type uniquing.
DenseMap<unsigned, IntegerType*> integers;
@@ -325,34 +341,101 @@
return existing.first->second = result;
}
-// TODO(bondhugula,andydavis): unique affine maps based on dim list,
-// symbol list and all affine expressions contained
-AffineMap *AffineMap::get(unsigned dimCount,
- unsigned symbolCount,
- ArrayRef<AffineExpr *> exprs,
+AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
+ ArrayRef<AffineExpr *> results,
MLIRContext *context) {
- // TODO(bondhugula)
- return new AffineMap(dimCount, symbolCount, exprs);
+ // The number of results can't be zero.
+ assert(!results.empty());
+
+ auto &impl = context->getImpl();
+
+ // Check if we already have this affine map.
+ AffineMapKeyInfo::KeyTy key(
+ std::pair<unsigned, unsigned>(dimCount, symbolCount), results);
+ auto existing = impl.affineMaps.insert_as(nullptr, key);
+
+ // If we already have it, return that value.
+ if (!existing.second)
+ return *existing.first;
+
+ // On the first use, we allocate them into the bump pointer.
+ auto *res = impl.allocator.Allocate<AffineMap>();
+
+ // Copy the results into the bump pointer.
+ results = impl.copyInto(ArrayRef<AffineExpr *>(results));
+
+ // Initialize the memory using placement new.
+ new (res) AffineMap(dimCount, symbolCount, results.size(), results.data());
+
+ // Cache and return it.
+ return *existing.first = res;
}
+// TODO(bondhugula): complete uniqu'ing of remaining AffinExpr sub-classes
AffineBinaryOpExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind,
AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
MLIRContext *context) {
- // TODO(bondhugula): allocate this through context
- // FIXME
- return new AffineBinaryOpExpr(kind, lhsOperand, rhsOperand);
+ auto &impl = context->getImpl();
+
+ // Check if we already have this affine expression.
+ auto key = std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>(
+ (unsigned)kind,
+ std::pair<AffineExpr *, AffineExpr *>(lhsOperand, rhsOperand));
+ auto *&result = impl.affineExprs[key];
+
+ // If we already have it, return that value.
+ if (!result) {
+ // On the first use, we allocate them into the bump pointer.
+ result = impl.allocator.Allocate<AffineBinaryOpExpr>();
+
+ // Initialize the memory using placement new.
+ new (result) AffineBinaryOpExpr(kind, lhsOperand, rhsOperand);
+ }
+ return result;
}
AffineAddExpr *AffineAddExpr::get(AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
MLIRContext *context) {
- // TODO(bondhugula): allocate this through context
- // FIXME
- return new AffineAddExpr(lhsOperand, rhsOperand);
+ return cast<AffineAddExpr>(
+ AffineBinaryOpExpr::get(Kind::Add, lhsOperand, rhsOperand, context));
}
-// TODO(bondhugula): add functions for AffineMulExpr, mod, floordiv, ceildiv
+AffineSubExpr *AffineSubExpr::get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand,
+ MLIRContext *context) {
+ return cast<AffineSubExpr>(
+ AffineBinaryOpExpr::get(Kind::Sub, lhsOperand, rhsOperand, context));
+}
+
+AffineMulExpr *AffineMulExpr::get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand,
+ MLIRContext *context) {
+ return cast<AffineMulExpr>(
+ AffineBinaryOpExpr::get(Kind::Mul, lhsOperand, rhsOperand, context));
+}
+
+AffineFloorDivExpr *AffineFloorDivExpr::get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand,
+ MLIRContext *context) {
+ return cast<AffineFloorDivExpr>(
+ AffineBinaryOpExpr::get(Kind::FloorDiv, lhsOperand, rhsOperand, context));
+}
+
+AffineCeilDivExpr *AffineCeilDivExpr::get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand,
+ MLIRContext *context) {
+ return cast<AffineCeilDivExpr>(
+ AffineBinaryOpExpr::get(Kind::CeilDiv, lhsOperand, rhsOperand, context));
+}
+
+AffineModExpr *AffineModExpr::get(AffineExpr *lhsOperand,
+ AffineExpr *rhsOperand,
+ MLIRContext *context) {
+ return cast<AffineModExpr>(
+ AffineBinaryOpExpr::get(Kind::Mod, lhsOperand, rhsOperand, context));
+}
AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
// TODO(bondhugula): complete this
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 8943200..3872565 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -91,7 +91,7 @@
++curPtr;
return formToken(Token::arrow, tokStart);
}
- return emitError(tokStart, "unexpected character");
+ return formToken(Token::minus, tokStart);
case '?':
if (*curPtr == '?') {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index dd0112e..8c44bbd 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -44,6 +44,26 @@
ParseFailure
};
+/// Lower precedence ops (all at the same precedence level). LNoOp is false in
+/// the boolean sense.
+enum AffineLowPrecOp {
+ /// Null value.
+ LNoOp,
+ Add,
+ Sub
+};
+
+/// Higher precedence ops - all at the same precedence level. HNoOp is false in
+/// the boolean sense.
+enum AffineHighPrecOp {
+ /// Null value.
+ HNoOp,
+ Mul,
+ FloorDiv,
+ CeilDiv,
+ Mod
+};
+
/// Main parser implementation.
class Parser {
public:
@@ -109,6 +129,10 @@
return true;
}
+ // Binary affine op parsing
+ AffineLowPrecOp consumeIfLowPrecOp();
+ AffineHighPrecOp consumeIfHighPrecOp();
+
ParseResult parseCommaSeparatedList(Token::Kind rightToken,
const std::function<ParseResult()> &parseElement,
bool allowEmptyList = true);
@@ -129,19 +153,37 @@
Type *parseType();
ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
- // Identifiers
- ParseResult parseDimIdList(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols);
- ParseResult parseSymbolIdList(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols);
- StringRef parseDimOrSymbolId(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols,
- bool symbol);
+ // Parsing identifiers' lists for polyhedral structures.
+ ParseResult parseDimIdList(AffineMapParserState &state);
+ ParseResult parseSymbolIdList(AffineMapParserState &state);
+ ParseResult parseDimOrSymbolId(AffineMapParserState &state, bool dim);
- // Polyhedral structures
+ // Polyhedral structures.
ParseResult parseAffineMapDef();
- AffineMap *parseAffineMapInline(StringRef mapId);
- AffineExpr *parseAffineExpr(AffineMapParserState &state);
+ ParseResult parseAffineMapInline(StringRef mapId, AffineMap *&affineMap);
+ AffineExpr *parseAffineExpr(const AffineMapParserState &state);
+
+ AffineExpr *parseParentheticalExpr(const AffineMapParserState &state);
+ AffineExpr *parseIntegerExpr(const AffineMapParserState &state);
+ AffineExpr *parseBareIdExpr(const AffineMapParserState &state);
+
+ static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineHighPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs,
+ MLIRContext *context);
+ static AffineBinaryOpExpr *getBinaryAffineOpExpr(AffineLowPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs,
+ MLIRContext *context);
+ ParseResult parseAffineOperandExpr(const AffineMapParserState &state,
+ AffineExpr *&result);
+ ParseResult parseAffineLowPrecOpExpr(AffineExpr *llhs, AffineLowPrecOp llhsOp,
+ const AffineMapParserState &state,
+ AffineExpr *&result);
+ ParseResult parseAffineHighPrecOpExpr(AffineExpr *llhs,
+ AffineHighPrecOp llhsOp,
+ const AffineMapParserState &state,
+ AffineExpr *&result);
// Functions.
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
@@ -482,32 +524,21 @@
/// expression.
class AffineMapParserState {
public:
- explicit AffineMapParserState(ArrayRef<StringRef> dims,
- ArrayRef<StringRef> symbols) :
- dims_(dims), symbols_(symbols) {}
+ explicit AffineMapParserState() {}
- unsigned dimCount() const { return dims_.size(); }
- unsigned symbolCount() const { return symbols_.size(); }
+ void addDim(StringRef sRef) { dims.insert({sRef, dims.size()}); }
+ void addSymbol(StringRef sRef) { symbols.insert({sRef, symbols.size()}); }
- // Stack operations for affine expression parsing
- // TODO(bondhugula): all of this will be improved/made more principled
- void pushAffineExpr(AffineExpr *expr) { exprStack.push(expr); }
- AffineExpr *popAffineExpr() {
- auto *t = exprStack.top();
- exprStack.pop();
- return t;
- }
- AffineExpr *topAffineExpr() { return exprStack.top(); }
+ unsigned getNumDims() const { return dims.size(); }
+ unsigned getNumSymbols() const { return symbols.size(); }
- ArrayRef<StringRef> getDims() const { return dims_; }
- ArrayRef<StringRef> getSymbols() const { return symbols_; }
+ // TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
+ const llvm::StringMap<unsigned> &getDims() const { return dims; }
+ const llvm::StringMap<unsigned> &getSymbols() const { return symbols; }
private:
- const ArrayRef<StringRef> dims_;
- const ArrayRef<StringRef> symbols_;
-
- // TEMP: stack to hold affine expressions
- std::stack<AffineExpr *> exprStack;
+ llvm::StringMap<unsigned> dims;
+ llvm::StringMap<unsigned> symbols;
};
} // end anonymous namespace
@@ -533,16 +564,306 @@
if (!consumeIf(Token::equal))
return emitError("expected '=' in affine map outlined definition");
- auto *affineMap = parseAffineMapInline(affineMapId);
- affineMaps[affineMapId].reset(affineMap);
- if (!affineMap) return ParseFailure;
+ AffineMap *affineMap = nullptr;
+ if (parseAffineMapInline(affineMapId, affineMap))
+ return ParseFailure;
+ // TODO(bondhugula): Disable adding affineMapId to Parser::affineMaps for now;
+ // instead add to module for easy printing.
module->affineMapList.push_back(affineMap);
- return affineMap ? ParseSuccess : ParseFailure;
+
+ return ParseSuccess;
}
+/// Create an affine op expression
+AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineHighPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs,
+ MLIRContext *context) {
+ switch (op) {
+ case Mul:
+ return AffineMulExpr::get(lhs, rhs, context);
+ case FloorDiv:
+ return AffineFloorDivExpr::get(lhs, rhs, context);
+ case CeilDiv:
+ return AffineCeilDivExpr::get(lhs, rhs, context);
+ case Mod:
+ return AffineModExpr::get(lhs, rhs, context);
+ case HNoOp:
+ llvm_unreachable("can't create affine expression for null high prec op");
+ return nullptr;
+ }
+}
+
+AffineBinaryOpExpr *Parser::getBinaryAffineOpExpr(AffineLowPrecOp op,
+ AffineExpr *lhs,
+ AffineExpr *rhs,
+ MLIRContext *context) {
+ switch (op) {
+ case AffineLowPrecOp::Add:
+ return AffineAddExpr::get(lhs, rhs, context);
+ case AffineLowPrecOp::Sub:
+ return AffineSubExpr::get(lhs, rhs, context);
+ case AffineLowPrecOp::LNoOp:
+ llvm_unreachable("can't create affine expression for null low prec op");
+ return nullptr;
+ }
+}
+
+/// Parses an expression that can be a valid operand of an affine expression
+/// (where associativity may not have been specified through parentheses).
+// Eg: for an expression without parentheses (like i + j + k + l), each
+// of the four identifiers is an operand. For: i + j*k + l, j*k is not an
+// operand expression, it's an op expression and will be parsed via
+// parseAffineLowPrecOpExpression().
+ParseResult Parser::parseAffineOperandExpr(const AffineMapParserState &state,
+ AffineExpr *&result) {
+ result = parseParentheticalExpr(state);
+ if (!result)
+ result = parseBareIdExpr(state);
+ if (!result)
+ result = parseIntegerExpr(state);
+ return result ? ParseSuccess : ParseFailure;
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+/// expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs * lhs) * rhs, or (lhs * rhs) if llhs is null. If
+/// no rhs can be found, returns (llhs * lhs) or lhs if llhs is null.
+// TODO(bondhugula): check whether mul is w.r.t. a constant - otherwise, the
+/// map is semi-affine.
+ParseResult Parser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
+ AffineHighPrecOp llhsOp,
+ const AffineMapParserState &state,
+ AffineExpr *&result) {
+ // FIXME: Assume for now that llhsOp is mul.
+ AffineExpr *lhs = nullptr;
+ if (parseAffineOperandExpr(state, lhs)) {
+ return ParseFailure;
+ }
+ AffineHighPrecOp op = HNoOp;
+ // Found an LHS. Parse the remaining expression.
+ if ((op = consumeIfHighPrecOp())) {
+ if (llhs) {
+ // TODO(bondhugula): check whether 'lhs' here is a constant (for affine
+ // maps); semi-affine maps allow symbols.
+ AffineExpr *expr =
+ Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *subRes = nullptr;
+ if (parseAffineHighPrecOpExpr(expr, op, state, subRes)) {
+ if (!subRes)
+ emitError("missing right operand of multiply op");
+ // In spite of the error, setting result to prevent duplicate errors
+ // messages as the call stack unwinds. All of this due to left
+ // associativity.
+ result = expr;
+ return ParseFailure;
+ }
+ result = subRes ? subRes : expr;
+ return ParseSuccess;
+ }
+ // No LLHS, get RHS
+ AffineExpr *subRes = nullptr;
+ if (parseAffineHighPrecOpExpr(lhs, op, state, subRes)) {
+ // 'product' needs to be checked to prevent duplicate errors messages as
+ // the call stack unwinds. All of this due to left associativity.
+ if (!subRes)
+ emitError("missing right operand of multiply op");
+ return ParseFailure;
+ }
+ result = subRes;
+ return ParseSuccess;
+ }
+
+ // This is the last operand in this expression.
+ if (llhs) {
+ // TODO(bondhugula): check whether lhs here is a constant (for affine
+ // maps); semi-affine maps allow symbols.
+ result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ return ParseSuccess;
+ }
+
+ // No llhs, 'lhs' itself is the expression.
+ result = lhs;
+ return ParseSuccess;
+}
+
+/// Consume this token if it is a lower precedence affine op (there are only two
+/// precedence levels)
+AffineLowPrecOp Parser::consumeIfLowPrecOp() {
+ switch (curToken.getKind()) {
+ case Token::plus:
+ consumeToken(Token::plus);
+ return AffineLowPrecOp::Add;
+ case Token::minus:
+ consumeToken(Token::minus);
+ return AffineLowPrecOp::Sub;
+ default:
+ return AffineLowPrecOp::LNoOp;
+ }
+}
+
+/// Consume this token if it is a higher precedence affine op (there are only
+/// two precedence levels)
+AffineHighPrecOp Parser::consumeIfHighPrecOp() {
+ switch (curToken.getKind()) {
+ case Token::star:
+ consumeToken(Token::star);
+ return Mul;
+ case Token::kw_floordiv:
+ consumeToken(Token::kw_floordiv);
+ return FloorDiv;
+ case Token::kw_ceildiv:
+ consumeToken(Token::kw_ceildiv);
+ return CeilDiv;
+ case Token::kw_mod:
+ consumeToken(Token::kw_mod);
+ return Mod;
+ default:
+ return HNoOp;
+ }
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
///
-/// Parse a multi-dimensional affine expression
+/// All binary op's associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, div, and mod}.
+///
+/// Add, sub'are themselves at the same precedence level. mul, div, and mod are
+/// at the same higher precedence level.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs + lhs) + rhs) if llhs is non null, and
+/// lhs + rhs otherwise; if there is no rhs, llhs + lhs is returned if llhs is
+/// non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
+/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4.
+///
+// TODO(bondhugula): add support for unary op negation. Assuming for now that
+// the op to associate with llhs is add.
+ParseResult Parser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
+ AffineLowPrecOp llhsOp,
+ const AffineMapParserState &state,
+ AffineExpr *&result) {
+ AffineExpr *lhs = nullptr;
+ if (parseAffineOperandExpr(state, lhs))
+ return ParseFailure;
+
+ // Found an LHS. Deal with the ops.
+ AffineLowPrecOp lOp;
+ AffineHighPrecOp rOp;
+ if ((lOp = consumeIfLowPrecOp())) {
+ if (llhs) {
+ AffineExpr *sum =
+ Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ AffineExpr *recSum = nullptr;
+ parseAffineLowPrecOpExpr(sum, lOp, state, recSum);
+ result = recSum ? recSum : sum;
+ return ParseSuccess;
+ }
+ // No LLHS, get RHS and form the expression.
+ if (parseAffineLowPrecOpExpr(lhs, lOp, state, result)) {
+ if (!result)
+ emitError("missing right operand of add op");
+ return ParseFailure;
+ }
+ return ParseSuccess;
+ } else if ((rOp = consumeIfHighPrecOp())) {
+ // We have a higher precedence op here. Get the rhs operand for the llhs
+ // through parseAffineHighPrecOpExpr.
+ AffineExpr *highRes = nullptr;
+ if (parseAffineHighPrecOpExpr(lhs, rOp, state, highRes)) {
+ // 'product' needs to be checked to prevent duplicate errors messages as
+ // the call stack unwinds. All of this due to left associativity.
+ if (!highRes)
+ emitError("missing right operand of binary op");
+ return ParseFailure;
+ }
+ // If llhs is null, the product forms the first operand of the yet to be
+ // found expression. If non-null, assume for now that the op to associate
+ // with llhs is add.
+ AffineExpr *expr =
+ llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes, context) : highRes;
+ // Recurse for subsequent add's after the affine mul expression
+ AffineLowPrecOp nextOp = consumeIfLowPrecOp();
+ if (nextOp) {
+ AffineExpr *sumProd = nullptr;
+ parseAffineLowPrecOpExpr(expr, nextOp, state, sumProd);
+ result = sumProd ? sumProd : expr;
+ } else {
+ result = expr;
+ }
+ return ParseSuccess;
+ } else {
+ // Last operand in the expression list.
+ if (llhs) {
+ result = Parser::getBinaryAffineOpExpr(llhsOp, llhs, lhs, context);
+ return ParseSuccess;
+ }
+ // No llhs, 'lhs' itself is the expression.
+ result = lhs;
+ return ParseSuccess;
+ }
+}
+
+/// Parse an affine expression inside parentheses.
+/// affine-expr ::= `(` affine-expr `)`
+AffineExpr *Parser::parseParentheticalExpr(const AffineMapParserState &state) {
+ if (!consumeIf(Token::l_paren)) {
+ return nullptr;
+ }
+ auto *expr = parseAffineExpr(state);
+ if (!consumeIf(Token::r_paren)) {
+ emitError("expected ')'");
+ return nullptr;
+ }
+ if (!expr)
+ emitError("no expression inside parentheses");
+ return expr;
+}
+
+/// Parse a bare id that may appear in an affine expression.
+/// affine-expr ::= bare-id
+AffineExpr *Parser::parseBareIdExpr(const AffineMapParserState &state) {
+ if (curToken.is(Token::bare_identifier)) {
+ StringRef sRef = curToken.getSpelling();
+ const auto &dims = state.getDims();
+ const auto &symbols = state.getSymbols();
+ if (dims.count(sRef)) {
+ consumeToken(Token::bare_identifier);
+ return AffineDimExpr::get(dims.lookup(sRef), context);
+ }
+ if (symbols.count(sRef)) {
+ consumeToken(Token::bare_identifier);
+ return AffineSymbolExpr::get(symbols.lookup(sRef), context);
+ }
+ return emitError("identifier is neither dimensional nor symbolic"), nullptr;
+ }
+ return nullptr;
+}
+
+/// Parse an integral constant appearing in an affine expression.
+/// affine-expr ::= `-`? integer-literal
+/// TODO(bondhugula): handle negative numbers.
+AffineExpr *Parser::parseIntegerExpr(const AffineMapParserState &state) {
+ if (curToken.is(Token::integer)) {
+ auto *expr = AffineConstantExpr::get(
+ curToken.getUnsignedIntegerValue().getValue(), context);
+ consumeToken(Token::integer);
+ return expr;
+ }
+ return nullptr;
+}
+
+/// Parse an affine expression.
/// affine-expr ::= `(` affine-expr `)`
/// | affine-expr `+` affine-expr
/// | affine-expr `-` affine-expr
@@ -552,171 +873,118 @@
/// | affine-expr `mod` integer-literal
/// | bare-id
/// | `-`? integer-literal
-/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
-///
/// Use 'state' to check if valid identifiers appear.
-///
-AffineExpr *Parser::parseAffineExpr(AffineMapParserState &state) {
- // TODO(bondhugula): complete support for this
- // The code below is all placeholder / it is wrong / not complete
- // Operator precedence not considered; pure left to right associativity
- if (curToken.is(Token::comma)) {
- emitError("expecting affine expression");
- return nullptr;
+// TODO(bondhugula): check if mul, mod, div take integral constants
+AffineExpr *Parser::parseAffineExpr(const AffineMapParserState &state) {
+ switch (curToken.getKind()) {
+ case Token::l_paren:
+ case Token::kw_ceildiv:
+ case Token::kw_floordiv:
+ case Token::bare_identifier:
+ case Token::integer: {
+ AffineExpr *result = nullptr;
+ parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp, state, result);
+ return result;
}
- while (curToken.isNot(Token::comma, Token::r_paren,
- Token::eof, Token::error)) {
- switch (curToken.getKind()) {
- case Token::bare_identifier: {
- // TODO(bondhugula): look up state to see if it's a symbol or dim_id and
- // get its position
- AffineExpr *expr = AffineDimExpr::get(0, context);
- state.pushAffineExpr(expr);
- consumeToken(Token::bare_identifier);
- break;
- }
- case Token::plus: {
- consumeToken(Token::plus);
- if (state.topAffineExpr()) {
- auto lChild = state.popAffineExpr();
- auto rChild = parseAffineExpr(state);
- if (rChild) {
- auto binaryOpExpr = AffineAddExpr::get(lChild, rChild, context);
- state.popAffineExpr();
- state.pushAffineExpr(binaryOpExpr);
- } else {
- emitError("right operand of + missing");
- }
- } else {
- emitError("left operand of + missing");
- }
- break;
- }
- case Token::integer: {
- AffineExpr *expr = AffineConstantExpr::get(
- curToken.getUnsignedIntegerValue().getValue(), context);
- state.pushAffineExpr(expr);
- consumeToken(Token::integer);
- break;
- }
- case Token::l_paren: {
- consumeToken(Token::l_paren);
- break;
- }
- case Token::r_paren: {
- consumeToken(Token::r_paren);
- break;
- }
- default: {
- emitError("affine map expr parse impl incomplete/unexpected token");
- return nullptr;
- }
- }
- }
- if (!state.topAffineExpr()) {
- // An error will be emitted by parse comma separated list on an empty list
+ case Token::plus:
+ case Token::minus:
+ case Token::star:
+ emitError("left operand of binary op missing");
+ return nullptr;
+
+ default:
return nullptr;
}
- return state.topAffineExpr();
}
-// Return empty string if no bare id was found
-StringRef Parser::parseDimOrSymbolId(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols,
- bool symbol = false) {
- if (curToken.isNot(Token::bare_identifier)) {
- emitError("expected bare identifier");
- return StringRef();
- }
- // TODO(bondhugula): check whether the id already exists in either
- // state.symbols or state.dims; report error if it does; otherwise create a
- // new one.
- StringRef ref = curToken.getSpelling();
+/// Parse a dim or symbol from the lists appearing before the actual expressions
+/// of the affine map. Update state to store the dimensional/symbolic
+/// identifier. 'dim': whether it's the dim list or symbol list that is being
+/// parsed.
+ParseResult Parser::parseDimOrSymbolId(AffineMapParserState &state, bool dim) {
+ if (curToken.isNot(Token::bare_identifier))
+ return emitError("expected bare identifier");
+ auto sRef = curToken.getSpelling();
consumeToken(Token::bare_identifier);
- return ref;
+ if (state.getDims().count(sRef) == 1)
+ return emitError("dimensional identifier name reused");
+ if (state.getSymbols().count(sRef) == 1)
+ return emitError("symbolic identifier name reused");
+ if (dim)
+ state.addDim(sRef);
+ else
+ state.addSymbol(sRef);
+ return ParseSuccess;
}
-ParseResult Parser::parseSymbolIdList(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols) {
+/// Parse the list of symbolic identifiers to an affine map.
+ParseResult Parser::parseSymbolIdList(AffineMapParserState &state) {
if (!consumeIf(Token::l_bracket)) return emitError("expected '['");
auto parseElt = [&]() -> ParseResult {
- auto elt = parseDimOrSymbolId(dims, symbols, true);
- // FIXME(bondhugula): assuming dim arg for now
- if (!elt.empty()) {
- symbols.push_back(elt);
- return ParseSuccess;
- }
- return ParseFailure;
+ return parseDimOrSymbolId(state, false);
};
return parseCommaSeparatedList(Token::r_bracket, parseElt);
}
-// TODO(andy,bondhugula)
-ParseResult Parser::parseDimIdList(SmallVectorImpl<StringRef> &dims,
- SmallVectorImpl<StringRef> &symbols) {
+/// Parse the list of dimensional identifiers to an affine map.
+ParseResult Parser::parseDimIdList(AffineMapParserState &state) {
if (!consumeIf(Token::l_paren))
return emitError("expected '(' at start of dimensional identifiers list");
auto parseElt = [&]() -> ParseResult {
- auto elt = parseDimOrSymbolId(dims, symbols, false);
- if (!elt.empty()) {
- dims.push_back(elt);
- return ParseSuccess;
- }
- return ParseFailure;
+ return parseDimOrSymbolId(state, true);
};
-
return parseCommaSeparatedList(Token::r_paren, parseElt);
}
-/// Affine map definition.
+/// Parse an affine map definition.
///
-/// affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+/// affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
/// ( `size` `(` dim-size (`,` dim-size)* `)` )?
-/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
+/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
///
-AffineMap *Parser::parseAffineMapInline(StringRef mapId) {
- SmallVector<StringRef, 4> dims;
- SmallVector<StringRef, 4> symbols;
+/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+ParseResult Parser::parseAffineMapInline(StringRef mapId,
+ AffineMap *&affineMap) {
+ AffineMapParserState state;
// List of dimensional identifiers.
- if (parseDimIdList(dims, symbols)) return nullptr;
+ if (parseDimIdList(state))
+ return ParseFailure;
// Symbols are optional.
if (curToken.is(Token::l_bracket)) {
- if (parseSymbolIdList(dims, symbols)) return nullptr;
+ if (parseSymbolIdList(state))
+ return ParseFailure;
}
if (!consumeIf(Token::arrow)) {
- emitError("expected '->' or '['");
- return nullptr;
+ return (emitError("expected '->' or '['"), ParseFailure);
}
if (!consumeIf(Token::l_paren)) {
emitError("expected '(' at start of affine map range");
- return nullptr;
+ return ParseFailure;
}
- AffineMapParserState affState(dims, symbols);
-
SmallVector<AffineExpr *, 4> exprs;
auto parseElt = [&]() -> ParseResult {
- auto elt = parseAffineExpr(affState);
+ auto *elt = parseAffineExpr(state);
ParseResult res = elt ? ParseSuccess : ParseFailure;
exprs.push_back(elt);
return res;
};
// Parse a multi-dimensional affine expression (a comma-separated list of 1-d
- // affine expressions)
- if (parseCommaSeparatedList(Token::r_paren, parseElt, false)) return nullptr;
+ // affine expressions); the list cannot be empty.
+ // Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+ if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
+ return ParseFailure;
- // Parsed a valid affine map
- auto *affineMap =
- AffineMap::get(affState.dimCount(), affState.symbolCount(), exprs,
- context);
-
- return affineMap;
+ // Parsed a valid affine map.
+ affineMap =
+ AffineMap::get(state.getNumDims(), state.getNumSymbols(), exprs, context);
+ return ParseSuccess;
}
//===----------------------------------------------------------------------===//
@@ -767,7 +1035,6 @@
if (parseFunctionSignature(name, type))
return ParseFailure;
-
// Okay, the external function definition was parsed correctly.
module->functionList.push_back(new ExtFunction(name, type));
return ParseSuccess;
@@ -1098,7 +1365,7 @@
emitError("expected a top level entity");
return nullptr;
- // If we got to the end of the file, then we're done.
+ // If we got to the end of the file, then we're done.
case Token::eof:
return module.release();
@@ -1115,6 +1382,7 @@
case Token::kw_cfgfunc:
if (parseCFGFunc()) return nullptr;
break;
+
case Token::affine_map_identifier:
if (parseAffineMapDef()) return nullptr;
break;
@@ -1123,7 +1391,7 @@
if (parseMLFunc()) return nullptr;
break;
- // TODO: affine entity declarations, etc.
+ // TODO: affine entity declarations, etc.
}
}
}
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 53b4a56..f87499f 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -79,9 +79,8 @@
// Operators.
TOK_OPERATOR(plus, "+")
+TOK_OPERATOR(minus, "-")
TOK_OPERATOR(star, "*")
-TOK_OPERATOR(ceildiv, "ceildiv")
-TOK_OPERATOR(floordiv, "floordiv")
// TODO: More operator tokens
// Keywords. These turn "foo" into Token::kw_foo enums.
@@ -101,6 +100,9 @@
TOK_KEYWORD(return)
TOK_KEYWORD(tensor)
TOK_KEYWORD(vector)
+TOK_KEYWORD(mod)
+TOK_KEYWORD(floordiv)
+TOK_KEYWORD(ceildiv)
#undef TOK_MARKER
#undef TOK_IDENTIFIER
diff --git a/test/IR/parser-affine-map-negative.mlir b/test/IR/parser-affine-map-negative.mlir
new file mode 100644
index 0000000..b0a7bec
--- /dev/null
+++ b/test/IR/parser-affine-map-negative.mlir
@@ -0,0 +1,56 @@
+;
+; RUN: %S/../../mlir-opt %s -o - -check-parser-errors
+
+; Check different error cases.
+; -----
+
+#hello_world1 = (i, j) -> ((), j) ; expected-error {{no expression inside parentheses}}
+
+; -----
+#hello_world2 (i, j) [s0] -> (i, j) ; expected-error {{expected '=' in affine map outlined definition}}
+
+; -----
+#hello_world3a = (i, j) [s0] -> (2*i*, 3*j*i*2 + 5) ; expected-error {{missing right operand of multiply op}}
+
+; -----
+#hello_world3b = (i, j) [s0] -> (i+, i+j+2 + 5) ; expected-error {{missing right operand of add op}}
+
+; -----
+#hello_world4 = (i, j) [s0] -> ((s0 + i, j) ; expected-error {{expected ')'}}
+
+; -----
+#hello_world5 = (i, j) [s0] -> ((s0 + i, j) ; expected-error {{expected ')'}}
+
+; -----
+#hello_world6 = (i, j) [s0] -> (((s0 + (i + j) + 5), j) ; expected-error {{expected ')'}}
+
+; -----
+#hello_world8 = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
+
+; -----
+#hello_world9 = (i, j) [s0] -> (x) ; expected-error {{identifier is neither dimensional nor symbolic}}
+
+; -----
+#hello_world10 = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
+
+; -----
+#hello_world11 = (i, j) [s0, s1, s0] -> (i) ; expected-error {{symbolic identifier name reused}}
+
+; -----
+#hello_world12 = (i, j) [i, s0] -> (j) ; expected-error {{dimensional identifier name reused}}
+
+; -----
+#hello_world13 = (i, j) [s0, s1] -> () ; expected-error {{expected list element}}
+
+; -----
+#hello_world14 = (i, j) [s0, s1] -> (+i, j) ; expected-error {{left operand of binary op missing}}
+
+; -----
+#hello_world15 = (i, j) [s0, s1] -> (i, *j+5) ; expected-error {{left operand of binary op missing}}
+
+; FIXME(bondhugula) This one leads to two errors: the first on identifier being
+; neither dimensional nor symbolic and then the right operand missing.
+;-----
+; #hello_world22 = (i, j) -> (i, 3*d0 + j)
+
+; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index 50dd9e2..5b2539c 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -1,7 +1,91 @@
-#hello_world0 = (i, j) [s0] -> (i, j)
-#hello_world1 = (i, j) -> (i, j)
+; RUN: %S/../../mlir-opt %s -o - | FileCheck %s
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1)
+#hello_world0 = (i, j) -> (i, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
+#hello_world1 = (i, j) [s0] -> (i, j)
+
+; CHECK: #{{[0-9]+}} = () -> (0)
#hello_world2 = () -> (0)
-#hello_world3 = (i, j) [s0] -> (i + s0, j)
-#hello_world4 = (i, j) [s0] -> (i + s0, j + 5)
-#hello_world5 (i, j) [s0] -> i + s0, j)
-#hello_world5 = (i, j) [s0] -> i + s0, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+#hello_world3 = (i, j) -> (i+1, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
+#hello_world4 = (i, j) [s0] -> (i + s0, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> ((1 + d0), d1)
+#hello_world5 = (i, j) -> (1+i, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
+#hello_world6 = (i, j) [s0] -> (i + s0, j + 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
+#hello_world7 = (i, j) [s0] -> (i + j + s0, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((((5 + d0) + d1) + s0), d1)
+#hello_world8 = (i, j) [s0] -> (5 + i + j + s0, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
+#hello_world9 = (i, j) [s0] -> ((i + j) + 5, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
+#hello_world10 = (i, j) [s0] -> (i + (j + 5), j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((2 * d0), (3 * d1))
+#hello_world11 = (i, j) [s0] -> (2*i, 3*j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + (2 * 6)) + (5 * (d1 + (s0 * 3)))), d1)
+#hello_world12 = (i, j) [s0] -> (i + 2*6 + 5*(j+s0*3), j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((5 * d0) + d1), d1)
+#hello_world13 = (i, j) [s0] -> (5*i + j, j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
+#hello_world14 = (i, j) [s0] -> ((i + j), (j))
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
+#hello_world15 = (i, j) [s0] -> ((i + j)+5, (j)+3)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
+#hello_world16 = (i, j) [s1] -> (i, 0)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d0 * d1))
+#hello_world17 = (i, j) [s1] -> (i, i*j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, ((3 * d0) + d1))
+#hello_world19 = (i, j) -> (i, 3*i + j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (3 * d1)))
+#hello_world20 = (i, j) -> (i, i + 3*j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, ((2 + (((d1 * d0) * 9) * d0)) + 1))
+#hello_world18 = (i, j) -> (i, 2 + j*i*9*i + 1)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (3 * d1)) + 5))
+#hello_world21 = (i, j) -> (1, i + 3*j + 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((5 * s0), ((d0 + (3 * d1)) + (5 * d0)))
+#hello_world22 = (i, j) [s0] -> (5*s0, i + 3*j + 5*i)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
+#hello_world23 = (i, j) [s0, s1] -> (i*(s0*s1), j)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
+#hello_world24 = (i, j) [s0, s1] -> (i, j mod 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
+#hello_world25 = (i, j) [s0, s1] -> (i, j floordiv 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
+#hello_world26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d1 + (((d1 ceildiv 128) mod 16) * d0)) - 4))
+#hello_world27 = (i, j) [s0, s1] -> (i, j + j ceildiv 128 mod 16 * i - 4)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - d1) - 5))
+#hello_world29 = (i, j) [s0, s1] -> (i, i - j - 5)
+
+; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d0 * d1)) + 2))
+#hello_world30 = (i, j) [s0, s1] -> (i, i - i*j + 2)