Fix FIXME's/TODOs:
- Enhance memref type to allow omission of mappings and address
spaces (implying a default mapping).
- Fix printing of function types to properly recurse with printType
so mappings are printed by name.
- Simplify parsing of AffineMaps a bit now that we have
isSymbolicOrConstant()
PiperOrigin-RevId: 206039755
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index fa4462d..e23964d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -295,14 +295,14 @@
case Type::Kind::Function: {
auto *func = cast<FunctionType>(type);
os << '(';
- interleaveComma(func->getInputs(), [&](Type *type) { os << *type; });
+ interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
os << ") -> ";
auto results = func->getResults();
if (results.size() == 1)
os << *results[0];
else {
os << '(';
- interleaveComma(results, [&](Type *type) { os << *type; });
+ interleaveComma(results, [&](Type *type) { printType(type); });
os << ')';
}
return;
@@ -330,7 +330,9 @@
}
case Type::Kind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(type);
- os << "tensor<??" << *v->getElementType() << '>';
+ os << "tensor<??";
+ printType(v->getElementType());
+ os << '>';
return;
}
case Type::Kind::MemRef: {
@@ -343,12 +345,14 @@
os << dim;
os << 'x';
}
- os << *v->getElementType();
+ printType(v->getElementType());
for (auto map : v->getAffineMaps()) {
os << ", ";
printAffineMapReference(map);
}
- os << ", " << v->getMemorySpace();
+ // Only print the memory space if it is the non-default one.
+ if (v->getMemorySpace())
+ os << ", " << v->getMemorySpace();
os << '>';
return;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 63ebeff..55fd260 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -443,12 +443,9 @@
if (!elementType)
return nullptr;
- if (parseToken(Token::comma, "expected ',' in memref type"))
- return nullptr;
-
// Parse semi-affine-map-composition.
SmallVector<AffineMap *, 2> affineMapComposition;
- unsigned memorySpace;
+ unsigned memorySpace = 0;
bool parsedMemorySpace = false;
auto parseElt = [&]() -> ParseResult {
@@ -474,16 +471,17 @@
return ParseSuccess;
};
- // Parse comma separated list of affine maps, followed by memory space.
- if (parseCommaSeparatedListUntil(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
- return nullptr;
+ // Parse a list of mappings and address space if present.
+ if (consumeIf(Token::comma)) {
+ // Parse comma separated list of affine maps, followed by memory space.
+ if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+ /*allowEmptyList=*/false)) {
+ return nullptr;
+ }
+ } else {
+ if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
+ return nullptr;
}
- // Check that MemRef type specifies at least one affine map in composition.
- if (affineMapComposition.empty())
- return (emitError("expected semi-affine-map in memref type"), nullptr);
- if (!parsedMemorySpace)
- return (emitError("expected memory space in memref type"), nullptr);
return MemRefType::get(dimensions, elementType, affineMapComposition,
memorySpace);
@@ -710,22 +708,14 @@
AffineMap *parseAffineMapInline();
private:
- unsigned getNumDims() const { return dims.size(); }
- unsigned getNumSymbols() const { return symbols.size(); }
-
- /// Returns true if the only identifiers the parser accepts in affine
- /// expressions are symbolic identifiers.
- bool isPureSymbolic() const { return pureSymbolic; }
- void setSymbolicParsing(bool val) { pureSymbolic = val; }
-
// Binary affine op parsing.
AffineLowPrecOp consumeIfLowPrecOp();
AffineHighPrecOp consumeIfHighPrecOp();
// Identifier lists for polyhedral structures.
- ParseResult parseDimIdList();
- ParseResult parseSymbolIdList();
- ParseResult parseDimOrSymbolId(bool isDim);
+ ParseResult parseDimIdList(unsigned &numDims);
+ ParseResult parseSymbolIdList(unsigned &numSymbols);
+ ParseResult parseIdentifierDefinition(AffineExpr *idExpr);
AffineExpr *parseAffineExpr();
AffineExpr *parseParentheticalExpr();
@@ -745,12 +735,7 @@
SMLoc llhsOpLoc);
private:
- // TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
- llvm::StringMap<unsigned> dims;
- llvm::StringMap<unsigned> symbols;
- /// True if the parser should allow only symbolic identifiers in affine
- /// expressions.
- bool pureSymbolic = false;
+ SmallVector<std::pair<StringRef, AffineExpr *>, 4> dimsAndSymbols;
};
} // end anonymous namespace
@@ -931,18 +916,12 @@
return (emitError("expected bare identifier"), nullptr);
StringRef sRef = getTokenSpelling();
- // dims, symbols are all pairwise distinct.
- if (dims.count(sRef)) {
- if (isPureSymbolic())
- return (emitError("identifier used is not a symbolic identifier"),
- nullptr);
- consumeToken(Token::bare_identifier);
- return builder.getDimExpr(dims.lookup(sRef));
- }
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first != sRef)
+ continue;
- if (symbols.count(sRef)) {
consumeToken(Token::bare_identifier);
- return builder.getSymbolExpr(symbols.lookup(sRef));
+ return entry.second;
}
return (emitError("use of undeclared identifier"), nullptr);
@@ -1087,40 +1066,42 @@
/// Parse a dim or symbol from the lists appearing before the actual expressions
/// of the affine map. Update our state to store the dimensional/symbolic
-/// identifier. 'dim': whether it's the dim list or symbol list that is being
-/// parsed.
-ParseResult AffineMapParser::parseDimOrSymbolId(bool isDim) {
+/// identifier.
+ParseResult AffineMapParser::parseIdentifierDefinition(AffineExpr *idExpr) {
if (getToken().isNot(Token::bare_identifier))
return emitError("expected bare identifier");
- auto sRef = getTokenSpelling();
+
+ auto name = getTokenSpelling();
+ for (auto entry : dimsAndSymbols) {
+ if (entry.first == name)
+ return emitError("redefinition of identifier '" + Twine(name) + "'");
+ }
consumeToken(Token::bare_identifier);
- if (dims.count(sRef))
- return emitError("dimensional identifier name reused");
- if (symbols.count(sRef))
- return emitError("symbolic identifier name reused");
- if (isDim)
- dims.insert({sRef, dims.size()});
- else
- symbols.insert({sRef, symbols.size()});
+
+ dimsAndSymbols.push_back({name, idExpr});
return ParseSuccess;
}
/// Parse the list of symbolic identifiers to an affine map.
-ParseResult AffineMapParser::parseSymbolIdList() {
- if (parseToken(Token::l_square, "expected '['"))
- return ParseFailure;
-
- auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
+ParseResult AffineMapParser::parseSymbolIdList(unsigned &numSymbols) {
+ consumeToken(Token::l_square);
+ auto parseElt = [&]() -> ParseResult {
+ auto *symbol = AffineSymbolExpr::get(numSymbols++, getContext());
+ return parseIdentifierDefinition(symbol);
+ };
return parseCommaSeparatedListUntil(Token::r_square, parseElt);
}
/// Parse the list of dimensional identifiers to an affine map.
-ParseResult AffineMapParser::parseDimIdList() {
+ParseResult AffineMapParser::parseDimIdList(unsigned &numDims) {
if (parseToken(Token::l_paren,
"expected '(' at start of dimensional identifiers list"))
return ParseFailure;
- auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(true); };
+ auto parseElt = [&]() -> ParseResult {
+ auto *dimension = AffineDimExpr::get(numDims++, getContext());
+ return parseIdentifierDefinition(dimension);
+ };
return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}
@@ -1132,13 +1113,15 @@
///
/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
AffineMap *AffineMapParser::parseAffineMapInline() {
+ unsigned numDims = 0, numSymbols = 0;
+
// List of dimensional identifiers.
- if (parseDimIdList())
+ if (parseDimIdList(numDims))
return nullptr;
// Symbols are optional.
if (getToken().is(Token::l_square)) {
- if (parseSymbolIdList())
+ if (parseSymbolIdList(numSymbols))
return nullptr;
}
@@ -1173,13 +1156,19 @@
return nullptr;
auto parseRangeSize = [&]() -> ParseResult {
+ auto loc = getToken().getLoc();
auto *elt = parseAffineExpr();
- ParseResult res = elt ? ParseSuccess : ParseFailure;
+ if (!elt)
+ return ParseFailure;
+
+ if (!elt->isSymbolicOrConstant())
+ return emitError(loc,
+ "size expressions cannot refer to dimension values");
+
rangeSizes.push_back(elt);
- return res;
+ return ParseSuccess;
};
- setSymbolicParsing(true);
if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false))
return nullptr;
if (exprs.size() > rangeSizes.size())
@@ -1191,7 +1180,7 @@
}
// Parsed a valid affine map.
- return builder.getAffineMap(dims.size(), symbols.size(), exprs, rangeSizes);
+ return builder.getAffineMap(numDims, numSymbols, exprs, rangeSizes);
}
AffineMap *Parser::parseAffineMapInline() {
@@ -2448,8 +2437,6 @@
if (parseMLFunc())
return ParseFailure;
break;
-
- // TODO: affine entity declarations, etc.
}
}
}
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 5c24ad4..6a7cd98 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -245,3 +245,8 @@
// CHECK: extfunc @f42(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f42(memref<2x4xi8, #map42, 1>)
+
+// CHECK: extfunc @f43(memref<2x4xi8, #map{{[0-9]+}}>)
+extfunc @f43(memref<2x4xi8, #map42>)
+
+
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index 5618239..37d6ec1 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -58,14 +58,13 @@
}
// CHECK-LABEL: cfgfunc @load_store
-cfgfunc @load_store(memref<4x4xi32, #id2, 0>, affineint) {
-bb0(%0: memref<4x4xi32, #id2, 0>, %1: affineint):
+cfgfunc @load_store(memref<4x4xi32>, affineint) {
+bb0(%0: memref<4x4xi32>, %1: affineint):
+ // CHECK: %2 = load %0[%1, %1] : memref<4x4xi32>
+ %2 = "load"(%0, %1, %1) : (memref<4x4xi32>, affineint, affineint)->i32
- // CHECK: %2 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
- %2 = "load"(%0, %1, %1) : (memref<4x4xi32, #id2, 0>, affineint, affineint)->i32
-
- // CHECK: %3 = load %0[%1, %1] : memref<4x4xi32, #map2, 0>
- %3 = load %0[%1, %1] : memref<4x4xi32, #id2, 0>
+ // CHECK: %3 = load %0[%1, %1] : memref<4x4xi32>
+ %3 = load %0[%1, %1] : memref<4x4xi32>
return
}
diff --git a/test/IR/invalid-affinemap.mlir b/test/IR/invalid-affinemap.mlir
index 4ef4615..7953e3f 100644
--- a/test/IR/invalid-affinemap.mlir
+++ b/test/IR/invalid-affinemap.mlir
@@ -36,13 +36,13 @@
#hello_world = (i, j) [s0] -> (x) // expected-error {{use of undeclared identifier}}
// -----
-#hello_world = (i, j, i) [s0] -> (i) // expected-error {{dimensional identifier name reused}}
+#hello_world = (i, j, i) [s0] -> (i) // expected-error {{redefinition of identifier 'i'}}
// -----
-#hello_world = (i, j) [s0, s1, s0] -> (i) // expected-error {{symbolic identifier name reused}}
+#hello_world = (i, j) [s0, s1, s0] -> (i) // expected-error {{redefinition of identifier 's0'}}
// -----
-#hello_world = (i, j) [i, s0] -> (j) // expected-error {{dimensional identifier name reused}}
+#hello_world = (i, j) [i, s0] -> (j) // expected-error {{redefinition of identifier 'i'}}
// -----
#hello_world = (i, j) [s0, s1] -> () // expected-error {{expected list element}}
@@ -104,10 +104,10 @@
#hello_world = (i, j) -> (i, j) size (10, x) // expected-error {{use of undeclared identifier}}
// -----
-#hello_world = (i, j) [M] -> (i, j) size (10, j) // expected-error {{identifier used is not a symbolic identifier}}
+#hello_world = (i, j) [M] -> (i, j) size (10, j) // expected-error {{size expressions cannot refer to dimension values}}
// -----
-#hello_world = (i, j) [M] -> (i, j) size (10, M+i) // expected-error {{identifier used is not a symbolic identifier}}
+#hello_world = (i, j) [M] -> (i, j) size (10, M+i) // expected-error {{size expressions cannot refer to dimension values}}
// -----
#hello_world = (i, j) -> (i, j) size (10) // expected-error {{fewer range sizes than range expressions}}
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 2e87540..7bf9248 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -13,45 +13,34 @@
extfunc @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{expected type}}
// -----
-// Test no comma in memref type.
-// TODO(andydavis) Fix this test if we decide to allow empty affine map to
-// imply identity affine map.
-extfunc @memrefs(memref<2x4xi8>) ; expected-error {{expected ',' in memref type}}
-
-// -----
// Test no map in memref type.
-extfunc @memrefs(memref<2x4xi8, >) ; expected-error {{expected list element}}
+extfunc @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}
// -----
// Test non-existent map in memref type.
-extfunc @memrefs(memref<2x4xi8, #map7>) ; expected-error {{undefined affine map id 'map7'}}
+extfunc @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined affine map id 'map7'}}
// -----
// Test non hash identifier in memref type.
-extfunc @memrefs(memref<2x4xi8, %map7>) ; expected-error {{expected '(' at start of dimensional identifiers list}}
+extfunc @memrefs(memref<2x4xi8, %map7>) // expected-error {{expected '(' at start of dimensional identifiers list}}
// -----
// Test non-existent map in map composition of memref type.
#map0 = (d0, d1) -> (d0, d1)
-extfunc @memrefs(memref<2x4xi8, #map0, #map8>) ; expected-error {{undefined affine map id 'map8'}}
+extfunc @memrefs(memref<2x4xi8, #map0, #map8>) // expected-error {{undefined affine map id 'map8'}}
// -----
// Test multiple memory space error.
#map0 = (d0, d1) -> (d0, d1)
-extfunc @memrefs(memref<2x4xi8, #map0, 1, 2>) ; expected-error {{multiple memory spaces specified in memref type}}
+extfunc @memrefs(memref<2x4xi8, #map0, 1, 2>) // expected-error {{multiple memory spaces specified in memref type}}
// -----
// Test affine map after memory space.
#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0, d1) -> (d0, d1)
-extfunc @memrefs(memref<2x4xi8, #map0, 1, #map1>) ; expected-error {{affine map after memory space in memref type}}
-
-// -----
-// Test no memory space error.
-#map0 = (d0, d1) -> (d0, d1)
-extfunc @memrefs(memref<2x4xi8, #map0>) ; expected-error {{expected memory space in memref type}}
+extfunc @memrefs(memref<2x4xi8, #map0, 1, #map1>) // expected-error {{affine map after memory space in memref type}}
// -----
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index eb57a1b..7276724 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -41,15 +41,15 @@
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
tensor<1x?x4x?x?xaffineint>, tensor<i8>)
-// CHECK: extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map{{[0-9]+}}, 0>, memref<i8, #map{{[0-9]+}}, 0>)
-extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>)
+// CHECK: extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map{{[0-9]+}}>, memref<i8, #map{{[0-9]+}}>)
+extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map0>, memref<i8, #map1>)
// Test memref affine map compositions.
// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
extfunc @memrefs2(memref<2x4x8xi8, #map2, 1>)
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 0>)
+// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
extfunc @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
// CHECK: extfunc @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
@@ -57,13 +57,13 @@
// Test memref inline affine map compositions.
-// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}, 0>)
-extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), 0>)
+// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}>)
+extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
-// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>) -> (), () -> ())
+// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, #map0>, memref<i8, #map1>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {