Use OperationState to simplify the create<Op> methods, move them out of line,
and simplify some other things (e.g. the parser's operation-creation callbacks
now take a single OperationState). Change ConstantIntOp to not match affine
integers, since we now have ConstantAffineIntOp.
PiperOrigin-RevId: 207756316
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index cb7560e..2f67c27 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -607,10 +607,10 @@
if (intOp->getType()->isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
- specialName << 'c' << intOp->getValue();
- if (!intOp->getType()->isAffineInt())
- specialName << '_' << *intOp->getType();
+ specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
}
+ } else if (auto intOp = op->getAs<ConstantAffineIntOp>()) {
+ specialName << 'c' << intOp->getValue();
}
}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index b67c12d..95d7931 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -227,12 +227,33 @@
return "requires a result type that aligns with the 'value' attribute";
}
-/// ConstantIntOp only matches values whose result type is an IntegerType or
-/// AffineInt.
+/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- (isa<IntegerType>(op->getResult(0)->getType()) ||
- op->getResult(0)->getType()->isAffineInt());
+ isa<IntegerType>(op->getResult(0)->getType());
+}
+
+OperationState ConstantIntOp::build(Builder *builder, int64_t value,
+ unsigned width) {
+ OperationState result(builder->getIdentifier("constant"));
+ result.attributes.push_back(
+ {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
+ result.types.push_back(builder->getIntegerType(width));
+ return result;
+}
+
+/// ConstantAffineIntOp only matches values whose result type is AffineInt.
+bool ConstantAffineIntOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) &&
+ op->getResult(0)->getType()->isAffineInt();
+}
+
+OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
+ OperationState result(builder->getIdentifier("constant"));
+ result.attributes.push_back(
+ {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
+ result.types.push_back(builder->getAffineIntType());
+ return result;
}
void DimOp::print(OpAsmPrinter *p) const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d093064..d89c62e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -87,8 +87,7 @@
namespace {
-typedef std::function<Operation *(Identifier, ArrayRef<SSAValue *>,
- ArrayRef<Type *>, ArrayRef<NamedAttribute>)>
+typedef std::function<Operation *(const OperationState &)>
CreateOperationFunction;
/// This class implement support for parsing global entities like types and
@@ -1596,6 +1595,8 @@
consumeToken(Token::string);
+ OperationState result(builder.getIdentifier(name));
+
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@@ -1605,9 +1606,8 @@
return nullptr;
}
- SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
+ if (parseAttributeDict(result.attributes))
return nullptr;
}
@@ -1622,6 +1622,8 @@
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
+ result.types.append(fnType->getResults().begin(), fnType->getResults().end());
+
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
if (operandTypes.size() != operandInfos.size()) {
@@ -1633,15 +1635,13 @@
}
// Resolve all of the operands.
- SmallVector<SSAValue *, 8> operands;
for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
- operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
- if (!operands.back())
+ result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+ if (!result.operands.back())
return nullptr;
}
- auto nameId = builder.getIdentifier(name);
- return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+ return createOpFunc(result);
}
namespace {
@@ -1840,8 +1840,7 @@
return nullptr;
// Otherwise, we succeeded. Use the state it parsed as our op information.
- return createOpFunc(opState.name, opState.operands, opState.types,
- opState.attributes);
+ return createOpFunc(opState);
}
//===----------------------------------------------------------------------===//
@@ -1979,14 +1978,8 @@
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<CFGValue *, 8> cfgOperands;
- cfgOperands.reserve(operands.size());
- for (auto *op : operands)
- cfgOperands.push_back(cast<CFGValue>(op));
- return builder.createOperation(name, cfgOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &result) -> Operation * {
+ return builder.createOperation(result);
};
// Parse the list of operations that make up the body of the block.
@@ -2261,14 +2254,8 @@
/// Parse a list of statements ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<MLValue *, 8> stmtOperands;
- stmtOperands.reserve(operands.size());
- for (auto *op : operands)
- stmtOperands.push_back(cast<MLValue>(op));
- return builder.createOperation(name, stmtOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &state) -> Operation * {
+ return builder.createOperation(state);
};
builder.setInsertionPoint(block);