Clean up the op builder APIs, and simplify the implementation of ops by making
OperationState contain a context and have the generic builder mechanics handle
the job of initializing the OperationState and setting the op name. NFC.
PiperOrigin-RevId: 209869948
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 44fa224..95d73a4 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -175,7 +175,9 @@
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
- auto *inst = createOperation(OpTy::build(this, args...));
+ OperationState state(getContext(), OpTy::getOperationName());
+ OpTy::build(this, &state, args...);
+ auto *inst = createOperation(state);
auto result = inst->template getAs<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
@@ -279,7 +281,9 @@
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
- auto stmt = createOperation(OpTy::build(this, args...));
+ OperationState state(getContext(), OpTy::getOperationName());
+ OpTy::build(this, &state, args...);
+ auto *stmt = createOperation(state);
auto result = stmt->template getAs<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index af01a1d..2e44cd5 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -45,6 +45,7 @@
/// be used as a temporary object on the stack. It is generally unwise to put
/// this in a collection.
struct OperationState {
+ MLIRContext *const context;
Identifier name;
SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation.
@@ -52,14 +53,31 @@
SmallVector<NamedAttribute, 4> attributes;
public:
- OperationState(Identifier name) : name(name) {}
+ OperationState(MLIRContext *context, StringRef name)
+ : context(context), name(Identifier::get(name, context)) {}
- OperationState(Identifier name, ArrayRef<SSAValue *> operands,
- ArrayRef<Type *> types,
+ OperationState(MLIRContext *context, Identifier name)
+ : context(context), name(name) {}
+
+ OperationState(MLIRContext *context, StringRef name,
+ ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
ArrayRef<NamedAttribute> attributes = {})
- : name(name), operands(operands.begin(), operands.end()),
+ : context(context), name(Identifier::get(name, context)),
+ operands(operands.begin(), operands.end()),
types(types.begin(), types.end()),
attributes(attributes.begin(), attributes.end()) {}
+
+ void addOperands(ArrayRef<SSAValue *> newOperands) {
+ operands.append(newOperands.begin(), newOperands.end());
+ }
+
+ void addTypes(ArrayRef<Type *> newTypes) {
+ types.append(newTypes.begin(), newTypes.end());
+ }
+
+ void addAttribute(StringRef name, Attribute *attr) {
+ attributes.push_back({Identifier::get(name, context), attr});
+ }
};
/// Operations represent all of the arithmetic and other basic computation in
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 3bff44e..a7b671c 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -43,13 +43,8 @@
public:
static StringRef getOperationName() { return "addf"; }
- template <class Builder, class Value>
- static OpPointer<AddFOp> build(Builder *builder, Value *lhs, Value *rhs) {
- // The resultant type of a addf is the same as both the lhs and rhs.
- return OpPointer<AddFOp>(AddFOp(builder->createOperation(
- builder->getIdentifier("addf"), {lhs, rhs}, {lhs->getType()}, {})));
- }
-
+ static void build(Builder *builder, OperationState *result, SSAValue *lhs,
+ SSAValue *rhs);
const char *verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@@ -77,8 +72,8 @@
OpTrait::VariadicResults> {
public:
/// Builds an affine apply op with the specified map and operands.
- static OperationState build(Builder *builder, AffineMap *map,
- ArrayRef<SSAValue *> operands);
+ static void build(Builder *builder, OperationState *result, AffineMap *map,
+ ArrayRef<SSAValue *> operands);
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
@@ -145,8 +140,8 @@
public:
static StringRef getOperationName() { return "call"; }
- static OperationState build(Builder *builder, Function *callee,
- ArrayRef<SSAValue *> operands);
+ static void build(Builder *builder, OperationState *result, Function *callee,
+ ArrayRef<SSAValue *> operands);
Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee")->getValue();
@@ -175,8 +170,8 @@
public:
static StringRef getOperationName() { return "call_indirect"; }
- static OperationState build(Builder *builder, SSAValue *callee,
- ArrayRef<SSAValue *> operands);
+ static void build(Builder *builder, OperationState *result, SSAValue *callee,
+ ArrayRef<SSAValue *> operands);
const SSAValue *getCallee() const { return getOperand(0); }
SSAValue *getCallee() { return getOperand(0); }
@@ -222,7 +217,8 @@
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
- static OperationState build(Builder *builder, double value, FloatType *type);
+ static void build(Builder *builder, OperationState *result, double value,
+ FloatType *type);
double getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
@@ -243,7 +239,8 @@
class ConstantIntOp : public ConstantOp {
public:
/// Build a constant int op producing an integer of the specified width.
- static OperationState build(Builder *builder, int64_t value, unsigned width);
+ static void build(Builder *builder, OperationState *result, int64_t value,
+ unsigned width);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
@@ -264,7 +261,7 @@
class ConstantAffineIntOp : public ConstantOp {
public:
/// Build a constant int op producing an affineint.
- static OperationState build(Builder *builder, int64_t value);
+ static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index fb7dfe7..9f01c13 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -66,6 +66,13 @@
// AddFOp
//===----------------------------------------------------------------------===//
+void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
+ SSAValue *rhs) {
+ assert(lhs->getType() == rhs->getType());
+ result->addOperands({lhs, rhs});
+ result->types.push_back(lhs->getType());
+}
+
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
@@ -201,15 +208,11 @@
// CallOp
//===----------------------------------------------------------------------===//
-OperationState CallOp::build(Builder *builder, Function *callee,
- ArrayRef<SSAValue *> operands) {
- OperationState result(builder->getIdentifier("call"));
- result.operands.append(operands.begin(), operands.end());
- result.attributes.push_back(
- {builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
- result.types.append(callee->getType()->getResults().begin(),
- callee->getType()->getResults().end());
- return result;
+void CallOp::build(Builder *builder, OperationState *result, Function *callee,
+ ArrayRef<SSAValue *> operands) {
+ result->addOperands(operands);
+ result->addAttribute("callee", builder->getFunctionAttr(callee));
+ result->addTypes(callee->getType()->getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -229,10 +232,7 @@
result->operands))
return true;
- auto &builder = parser->getBuilder();
- result->attributes.push_back(
- {builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
-
+ result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
return false;
}
@@ -277,15 +277,12 @@
// CallIndirectOp
//===----------------------------------------------------------------------===//
-OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
- ArrayRef<SSAValue *> operands) {
+void CallIndirectOp::build(Builder *builder, OperationState *result,
+ SSAValue *callee, ArrayRef<SSAValue *> operands) {
auto *fnType = cast<FunctionType>(callee->getType());
-
- OperationState result(builder->getIdentifier("call_indirect"));
- result.operands.push_back(callee);
- result.operands.append(operands.begin(), operands.end());
- result.types.append(fnType->getResults().begin(), fnType->getResults().end());
- return result;
+ result->operands.push_back(callee);
+ result->addOperands(operands);
+ result->addTypes(fnType->getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -406,13 +403,10 @@
return "requires a result type that aligns with the 'value' attribute";
}
-OperationState ConstantFloatOp::build(Builder *builder, double value,
- FloatType *type) {
- OperationState result(builder->getIdentifier("constant"));
- result.attributes.push_back(
- {builder->getIdentifier("value"), builder->getFloatAttr(value)});
- result.types.push_back(type);
- return result;
+void ConstantFloatOp::build(Builder *builder, OperationState *result,
+ double value, FloatType *type) {
+ result->addAttribute("value", builder->getFloatAttr(value));
+ result->types.push_back(type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
@@ -426,13 +420,10 @@
isa<IntegerType>(op->getResult(0)->getType());
}
-OperationState ConstantIntOp::build(Builder *builder, int64_t value,
- unsigned width) {
- OperationState result(builder->getIdentifier("constant"));
- result.attributes.push_back(
- {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
- result.types.push_back(builder->getIntegerType(width));
- return result;
+void ConstantIntOp::build(Builder *builder, OperationState *result,
+ int64_t value, unsigned width) {
+ result->addAttribute("value", builder->getIntegerAttr(value));
+ result->types.push_back(builder->getIntegerType(width));
}
/// ConstantAffineIntOp only matches values whose result type is AffineInt.
@@ -441,28 +432,21 @@
op->getResult(0)->getType()->isAffineInt();
}
-OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
- OperationState result(builder->getIdentifier("constant"));
- result.attributes.push_back(
- {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
- result.types.push_back(builder->getAffineIntType());
- return result;
+void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
+ int64_t value) {
+ result->addAttribute("value", builder->getIntegerAttr(value));
+ result->types.push_back(builder->getAffineIntType());
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
-OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
- ArrayRef<SSAValue *> operands) {
- SmallVector<Type *, 4> resultTypes(map->getNumResults(),
- builder->getAffineIntType());
-
- OperationState result(
- builder->getIdentifier("affine_apply"), operands, resultTypes,
- {{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
-
- return result;
+void AffineApplyOp::build(Builder *builder, OperationState *result,
+ AffineMap *map, ArrayRef<SSAValue *> operands) {
+ result->addOperands(operands);
+ result->types.append(map->getNumResults(), builder->getAffineIntType());
+ result->addAttribute("map", builder->getAffineMapAttr(map));
}
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d2ea5bf..de82e45 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1648,7 +1648,7 @@
consumeToken(Token::string);
- OperationState result(builder.getIdentifier(name));
+ OperationState result(builder.getContext(), name);
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@@ -1675,7 +1675,7 @@
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
- result.types.append(fnType->getResults().begin(), fnType->getResults().end());
+ result.addTypes(fnType->getResults());
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
@@ -1916,7 +1916,7 @@
opNameStr.c_str());
// Have the op implementation take a crack and parsing this.
- OperationState opState(builder.getIdentifier(opName));
+ OperationState opState(builder.getContext(), opName);
if (opDefinition->parseAssembly(&opAsmParser, &opState))
return nullptr;