Use OperationState to simplify the create<Op> methods, move them out of line,
and simplify some other things. Change ConstantIntOp to not match affine
integers, since we now have ConstantAffineIntOp.
PiperOrigin-RevId: 207756316
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 80b30c6..1be6c3e 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -152,34 +152,28 @@
setInsertionPoint(block, block->end());
}
+ void insert(OperationInst *opInst) {
+ block->getOperations().insert(insertPoint, opInst);
+ }
+
// Add new basic block and set the insertion point to the end of it.
BasicBlock *createBlock();
- // TODO(clattner): remove this.
- /// Create an operation at the current insertion point.
- OperationInst *createOperation(Identifier name, ArrayRef<CFGValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attributes) {
- auto op =
- OperationInst::create(name, operands, resultTypes, attributes, context);
- block->getOperations().insert(insertPoint, op);
- return op;
- }
-
/// Create an operation given the fields represented as an OperationState.
OperationInst *createOperation(const OperationState &state);
- // TODO(clattner): rework build to return an OperationState so the
- // implementations can moved out of line.
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
- return OpTy::build(this, args...);
+ auto *inst = createOperation(OpTy::build(this, args...));
+ auto result = inst->template getAs<OpTy>();
+ assert(result && "Builder didn't return the right type");
+ return result;
}
OperationInst *cloneOperation(const OperationInst &srcOpInst) {
auto *op = srcOpInst.clone();
- block->getOperations().insert(insertPoint, op);
+ insert(op);
return op;
}
@@ -262,25 +256,16 @@
/// Get the current insertion point of the builder.
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
- // TODO(clattner): remove this.
- OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attributes) {
- auto *op =
- OperationStmt::create(name, operands, resultTypes, attributes, context);
- block->getStatements().insert(insertPoint, op);
- return op;
- }
-
/// Create an operation given the fields represented as an OperationState.
OperationStmt *createOperation(const OperationState &state);
- // TODO(clattner): rework build to return an OperationState so the
- // implementations can moved out of line.
- // Create operation of specific op type at the current insertion point.
+ /// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
- return OpTy::build(this, args...);
+ auto stmt = createOperation(OpTy::build(this, args...));
+ auto result = stmt->template getAs<OpTy>();
+ assert(result && "Builder didn't return the right type");
+ return result;
}
Statement *clone(const Statement &stmt) {
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index ac5dbad..c8aaa15 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -28,6 +28,7 @@
namespace mlir {
class OperationSet;
+class Builder;
/// The "addf" operation takes two operands and returns one result, each of
/// these is required to be of the same type. This type may be a floating point
@@ -165,16 +166,7 @@
class ConstantIntOp : public ConstantOp {
public:
/// Build a constant int op producing an integer of the specified width.
- template <class Builder>
- static OpPointer<ConstantIntOp> build(Builder *builder, int64_t value,
- unsigned width) {
- std::pair<Identifier, Attribute *> namedAttr(
- builder->getIdentifier("value"), builder->getIntegerAttr(value));
- auto *type = builder->getIntegerType(width);
-
- return OpPointer<ConstantIntOp>(ConstantIntOp(builder->createOperation(
- builder->getIdentifier("constant"), {}, type, {namedAttr})));
- }
+ static OperationState build(Builder *builder, int64_t value, unsigned width);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
@@ -195,16 +187,7 @@
class ConstantAffineIntOp : public ConstantOp {
public:
/// Build a constant int op producing an affineint.
- template <class Builder>
- static OpPointer<ConstantAffineIntOp> build(Builder *builder, int64_t value) {
- std::pair<Identifier, Attribute *> namedAttr(
- builder->getIdentifier("value"), builder->getIntegerAttr(value));
- auto *type = builder->getAffineIntType();
-
- return OpPointer<ConstantAffineIntOp>(
- ConstantAffineIntOp(builder->createOperation(
- builder->getIdentifier("constant"), {}, type, {namedAttr})));
- }
+ static OperationState build(Builder *builder, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index cb7560e..2f67c27 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -607,10 +607,10 @@
if (intOp->getType()->isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
- specialName << 'c' << intOp->getValue();
- if (!intOp->getType()->isAffineInt())
- specialName << '_' << *intOp->getType();
+ specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
}
+ } else if (auto intOp = op->getAs<ConstantAffineIntOp>()) {
+ specialName << 'c' << intOp->getValue();
}
}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index b67c12d..95d7931 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -227,12 +227,33 @@
return "requires a result type that aligns with the 'value' attribute";
}
-/// ConstantIntOp only matches values whose result type is an IntegerType or
-/// AffineInt.
+/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- (isa<IntegerType>(op->getResult(0)->getType()) ||
- op->getResult(0)->getType()->isAffineInt());
+ isa<IntegerType>(op->getResult(0)->getType());
+}
+
+OperationState ConstantIntOp::build(Builder *builder, int64_t value,
+ unsigned width) {
+ OperationState result(builder->getIdentifier("constant"));
+ result.attributes.push_back(
+ {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
+ result.types.push_back(builder->getIntegerType(width));
+ return result;
+}
+
+/// ConstantAffineIntOp only matches values whose result type is AffineInt.
+bool ConstantAffineIntOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) &&
+ op->getResult(0)->getType()->isAffineInt();
+}
+
+OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
+ OperationState result(builder->getIdentifier("constant"));
+ result.attributes.push_back(
+ {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
+ result.types.push_back(builder->getAffineIntType());
+ return result;
}
void DimOp::print(OpAsmPrinter *p) const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d093064..d89c62e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -87,8 +87,7 @@
namespace {
-typedef std::function<Operation *(Identifier, ArrayRef<SSAValue *>,
- ArrayRef<Type *>, ArrayRef<NamedAttribute>)>
+typedef std::function<Operation *(const OperationState &)>
CreateOperationFunction;
/// This class implement support for parsing global entities like types and
@@ -1596,6 +1595,8 @@
consumeToken(Token::string);
+ OperationState result(builder.getIdentifier(name));
+
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@@ -1605,9 +1606,8 @@
return nullptr;
}
- SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
+ if (parseAttributeDict(result.attributes))
return nullptr;
}
@@ -1622,6 +1622,8 @@
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
+ result.types.append(fnType->getResults().begin(), fnType->getResults().end());
+
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
if (operandTypes.size() != operandInfos.size()) {
@@ -1633,15 +1635,13 @@
}
// Resolve all of the operands.
- SmallVector<SSAValue *, 8> operands;
for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
- operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
- if (!operands.back())
+ result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
+ if (!result.operands.back())
return nullptr;
}
- auto nameId = builder.getIdentifier(name);
- return createOpFunc(nameId, operands, fnType->getResults(), attributes);
+ return createOpFunc(result);
}
namespace {
@@ -1840,8 +1840,7 @@
return nullptr;
// Otherwise, we succeeded. Use the state it parsed as our op information.
- return createOpFunc(opState.name, opState.operands, opState.types,
- opState.attributes);
+ return createOpFunc(opState);
}
//===----------------------------------------------------------------------===//
@@ -1979,14 +1978,8 @@
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<CFGValue *, 8> cfgOperands;
- cfgOperands.reserve(operands.size());
- for (auto *op : operands)
- cfgOperands.push_back(cast<CFGValue>(op));
- return builder.createOperation(name, cfgOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &result) -> Operation * {
+ return builder.createOperation(result);
};
// Parse the list of operations that make up the body of the block.
@@ -2261,14 +2254,8 @@
/// Parse a list of statements ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
- auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> resultTypes,
- ArrayRef<NamedAttribute> attrs) -> Operation * {
- SmallVector<MLValue *, 8> stmtOperands;
- stmtOperands.reserve(operands.size());
- for (auto *op : operands)
- stmtOperands.push_back(cast<MLValue>(op));
- return builder.createOperation(name, stmtOperands, resultTypes, attrs);
+ auto createOpFunc = [&](const OperationState &state) -> Operation * {
+ return builder.createOperation(state);
};
builder.setInsertionPoint(block);