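Use OperationState to simplify the create<Op> methods, move them out of line,
and simplify related code (e.g. the parser's CreateOperationFunction callback
now takes a single OperationState instead of separate name/operand/type/
attribute arguments). Change ConstantIntOp to not match affine integers, since
we now have ConstantAffineIntOp.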
PiperOrigin-RevId: 207756316
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d093064..d89c62e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -87,8 +87,7 @@
namespace {
-typedef std::function<Operation *(Identifier, ArrayRef<SSAValue *>,
- ArrayRef<Type *>, ArrayRef<NamedAttribute>)>
+typedef std::function<Operation *(const OperationState &)>
CreateOperationFunction;
/// This class implement support for parsing global entities like types and
@@ -1596,6 +1595,8 @@
consumeToken(Token::string);
+ OperationState result(builder.getIdentifier(name));
+
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@@ -1605,9 +1606,8 @@
return nullptr;
}
- SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
+ if (parseAttributeDict(result.attributes))
return nullptr;
}
@@ -1622,6 +1622,8 @@
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
+ result.types.append(fnType->getResults().begin(), fnType->getResults().end());
+
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
if (operandTypes.size() != operandInfos.size()) {
@@ -1633,15 +1635,13 @@
}
// Resolve all of the operands.
- SmallVector<SSAValue *, 8> operands;
for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
- operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
- if (!operands.back())
+ result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+ if (!result.operands.back())
return nullptr;
}
- auto nameId = builder.getIdentifier(name);
- return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+ return createOpFunc(result);
}
namespace {
@@ -1840,8 +1840,7 @@
return nullptr;
// Otherwise, we succeeded. Use the state it parsed as our op information.
- return createOpFunc(opState.name, opState.operands, opState.types,
- opState.attributes);
+ return createOpFunc(opState);
}
//===----------------------------------------------------------------------===//
@@ -1979,14 +1978,8 @@
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<CFGValue *, 8> cfgOperands;
- cfgOperands.reserve(operands.size());
- for (auto *op : operands)
- cfgOperands.push_back(cast<CFGValue>(op));
- return builder.createOperation(name, cfgOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &result) -> Operation * {
+ return builder.createOperation(result);
};
// Parse the list of operations that make up the body of the block.
@@ -2261,14 +2254,8 @@
/// Parse a list of statements ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<MLValue *, 8> stmtOperands;
- stmtOperands.reserve(operands.size());
- for (auto *op : operands)
- stmtOperands.push_back(cast<MLValue>(op));
- return builder.createOperation(name, stmtOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &state) -> Operation * {
+ return builder.createOperation(state);
};
builder.setInsertionPoint(block);