Add support for multiple results to the printer/parser, add support
for forward references to the parser, add initial support for SSA
use-list iteration and RAUW.
PiperOrigin-RevId: 205484031
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index da936db..1fd3432 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using llvm::SourceMgr;
@@ -1179,10 +1180,19 @@
public:
FunctionParser(ParserState &state) : Parser(state) {}
- /// This represents a use of an SSA value in the program. This tracks
- /// location information in case this ends up being a use of an undefined
- /// value.
- typedef std::pair<StringRef, SMLoc> SSAUseInfo;
+ /// After the function is finished parsing, this function checks to see if
+ /// there are any remaining issues.
+ ParseResult finalizeFunction();
+
+ /// This represents a use of an SSA value in the program. The name and
+ /// number fields are the value name and result number of a reference. The
+ /// loc field is the location of the reference, which is used in case this
+ /// ends up being a use of an undefined value.
+ struct SSAUseInfo {
+ StringRef name; // Value name, e.g. %42 or %abc
+ unsigned number; // Result number, specified with #12 (0 if omitted)
+ SMLoc loc; // Location of first definition or use.
+ };
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
@@ -1205,47 +1215,122 @@
private:
/// This keeps track of all of the SSA values we are tracking, indexed by
- /// their name (either an identifier or a number).
- llvm::StringMap<std::pair<SSAValue *, SMLoc>> values;
+ /// their name. This has one entry per result number.
+ llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
+
+ /// These are all of the placeholders we've made along with the location of
+ /// their first reference, to allow checking for use of undefined values.
+ DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
+
+ SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
+
+ /// Return true if value is a tracked forward reference placeholder.
+ bool isForwardReferencePlaceholder(SSAValue *value) {
+ return forwardReferencePlaceholders.count(value);
+ }
};
} // end anonymous namespace
+/// Create and remember a new placeholder for a forward reference.
+SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
+ Type *type) {
+ // Forward references are always created as instructions, even in ML
+ // functions, because we just need something with a def/use chain.
+ //
+ // The placeholder operation is named "placeholder"; placeholders are told
+ // apart from real definitions by their forwardReferencePlaceholders entry.
+ auto name = Identifier::get("placeholder", getContext());
+ auto *inst = OperationInst::create(name, /*operands*/ {}, type, /*attrs*/ {},
+ getContext());
+ forwardReferencePlaceholders[inst->getResult(0)] = loc;
+ return inst->getResult(0);
+}
+
/// Given an unbound reference to an SSA value and its type, return a the value
/// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
+ auto &entries = values[useInfo.name];
+
// If we have already seen a value of this name, return it.
- auto it = values.find(useInfo.first);
- if (it != values.end()) {
+ if (useInfo.number < entries.size() && entries[useInfo.number].first) {
+ auto *result = entries[useInfo.number].first;
// Check that the type matches the other uses.
- auto result = it->second.first;
if (result->getType() == type)
return result;
- emitError(useInfo.second, "use of value '" + useInfo.first.str() +
- "' expects different type than prior uses");
- emitError(it->second.second, "prior use here");
+ emitError(useInfo.loc, "use of value '" + useInfo.name.str() +
+ "' expects different type than prior uses");
+ emitError(entries[useInfo.number].second, "prior use here");
return nullptr;
}
- // Otherwise we have a forward reference.
- // TODO: Handle forward references.
- emitError(useInfo.second, "undeclared or forward reference");
- return nullptr;
+ // Make sure we have enough slots for this.
+ if (entries.size() <= useInfo.number)
+ entries.resize(useInfo.number + 1);
+
+ // If the value has already been defined and this is an overly large result
+ // number, diagnose that.
+ if (entries[0].first && !isForwardReferencePlaceholder(entries[0].first))
+ return (emitError(useInfo.loc, "reference to invalid result number"),
+ nullptr);
+
+ // Otherwise, this is a forward reference. Create a placeholder and remember
+ // that we did so.
+ auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
+ entries[useInfo.number].first = result;
+ entries[useInfo.number].second = useInfo.loc;
+ return result;
}
/// Register a definition of a value with the symbol table.
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
+ auto &entries = values[useInfo.name];
- // If this is the first definition of this thing, then we are trivially done.
- auto insertInfo = values.insert({useInfo.first, {value, useInfo.second}});
- if (insertInfo.second)
- return ParseSuccess;
+ // Make sure there is a slot for this value.
+ if (entries.size() <= useInfo.number)
+ entries.resize(useInfo.number + 1);
- // If we already had a value, replace it with the new one and remove the
- // placeholder, only if it was a forward ref.
- // TODO: Handle forward references.
- emitError(useInfo.second, "redefinition of SSA value " + useInfo.first.str());
- return ParseFailure;
+ // If we already have an entry for this, check to see if it was a definition
+ // or a forward reference.
+ if (auto *existing = entries[useInfo.number].first) {
+ if (!isForwardReferencePlaceholder(existing)) {
+ emitError(useInfo.loc,
+ "redefinition of SSA value '" + useInfo.name + "'");
+ return emitError(entries[useInfo.number].second,
+ "previously defined here");
+ }
+
+ // If it was a forward reference, update everything that used it to use the
+ // actual definition instead, delete the forward ref, and remove it from our
+ // set of forward references we track.
+ existing->replaceAllUsesWith(value);
+ existing->getDefiningInst()->destroy();
+ forwardReferencePlaceholders.erase(existing);
+ }
+
+ entries[useInfo.number].first = value;
+ entries[useInfo.number].second = useInfo.loc;
+ return ParseSuccess;
+}
+
+/// After the function is finished parsing, this emits an error for every
+/// forward reference that was never defined and fails if there were any.
+ParseResult FunctionParser::finalizeFunction() {
+ // Check for any forward references that are left. If we find any, error out.
+ if (!forwardReferencePlaceholders.empty()) {
+ SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
+ // Iteration over the map isn't deterministic, so sort by source location.
+ for (auto entry : forwardReferencePlaceholders)
+ errors.push_back({entry.second.getPointer(), entry.first});
+ llvm::array_pod_sort(errors.begin(), errors.end());
+
+ for (auto entry : errors)
+ emitError(SMLoc::getFromPointer(entry.first),
+ "use of undeclared SSA value name");
+ return ParseFailure;
+ }
+
+ return ParseSuccess;
}
/// Parse a SSA operand for an instruction or statement.
@@ -1254,10 +1339,21 @@
/// TODO: SSA Constants.
///
ParseResult FunctionParser::parseSSAUse(SSAUseInfo &result) {
- result.first = getTokenSpelling();
- result.second = getToken().getLoc();
+ result.name = getTokenSpelling();
+ result.number = 0;
+ result.loc = getToken().getLoc();
if (!consumeIf(Token::percent_identifier))
return emitError("expected SSA operand");
+
+ // A trailing hash identifier (the affine map ID token) is a result number.
+ if (getToken().is(Token::hash_identifier)) {
+ if (auto value = getToken().getHashIdentifierNumber())
+ result.number = value.getValue();
+ else
+ return emitError("invalid SSA value result number");
+ consumeToken(Token::hash_identifier);
+ }
+
return ParseSuccess;
}
@@ -1403,7 +1499,8 @@
if (inst->getNumResults() == 0)
return emitError(loc, "cannot name an operation with no results");
- addDefinition({resultID, loc}, inst->getResult(0));
+ for (unsigned i = 0, e = inst->getNumResults(); i != e; ++i)
+ addDefinition({resultID, i, loc}, inst->getResult(i));
}
}
@@ -1474,7 +1571,8 @@
}
getModule()->functionList.push_back(function);
- return ParseSuccess;
+
+ return finalizeFunction();
}
/// Basic block declaration.
@@ -1612,7 +1710,7 @@
getModule()->functionList.push_back(function);
- return ParseSuccess;
+ return finalizeFunction();
}
/// For statement.