Teach the asmprinter to print out operands for OperationInst's. This
is still limited in several ways, which i'll build out in subsequent patches.
Rename the accessor for inst operands/results to make the Operand/Result
versions of these more obscure, allowing getOperand/getResult to traffic
in values (which is what - by far - most clients actually care about).
PiperOrigin-RevId: 205408439
diff --git a/include/mlir/IR/CFGValue.h b/include/mlir/IR/CFGValue.h
index 16ae003..95956f3 100644
--- a/include/mlir/IR/CFGValue.h
+++ b/include/mlir/IR/CFGValue.h
@@ -40,7 +40,7 @@
/// The operand of a CFG Instruction contains a CFGValue.
using InstOperand = SSAOperandImpl<CFGValue, Instruction>;
-/// CFGValue is the base class for CFG value types.
+/// CFGValue is the base class for SSA values in CFG functions.
class CFGValue : public SSAValueImpl<InstOperand, CFGValueKind> {
public:
static bool classof(const SSAValue *value) {
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 6817824..16073dd 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -97,19 +97,44 @@
MLIRContext *context);
~OperationInst();
- ArrayRef<InstOperand> getOperands() const {
+ unsigned getNumOperands() const { return numOperands; }
+
+ // TODO: Add a getOperands() custom sequence that provides a value projection
+ // of the operand list.
+ CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+ const CFGValue *getOperand(unsigned idx) const {
+ return getInstOperand(idx).get();
+ }
+
+ unsigned getNumResults() const { return numResults; }
+
+ // TODO: Add a getResults() custom sequence that provides a value projection
+ // of the result list.
+ CFGValue *getResult(unsigned idx) { return &getInstResult(idx); }
+ const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); }
+
+ ArrayRef<InstOperand> getInstOperands() const {
return {getTrailingObjects<InstOperand>(), numOperands};
}
- MutableArrayRef<InstOperand> getOperands() {
+ MutableArrayRef<InstOperand> getInstOperands() {
return {getTrailingObjects<InstOperand>(), numOperands};
}
- ArrayRef<InstResult> getResults() const {
+ InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
+ const InstOperand &getInstOperand(unsigned idx) const {
+ return getInstOperands()[idx];
+ }
+
+ ArrayRef<InstResult> getInstResults() const {
return {getTrailingObjects<InstResult>(), numResults};
}
- MutableArrayRef<InstResult> getResults() {
+ MutableArrayRef<InstResult> getInstResults() {
return {getTrailingObjects<InstResult>(), numResults};
}
+ InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; }
+ const InstResult &getInstResult(unsigned idx) const {
+ return getInstResults()[idx];
+ }
/// Unlink this instruction from its BasicBlock and delete it.
void eraseFromBlock();
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 03d9b1d..31fd05c 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -38,6 +38,12 @@
void Identifier::dump() const { print(llvm::errs()); }
+template <typename Container, typename UnaryFunctor>
+inline void interleaveComma(raw_ostream &os, const Container &c,
+ UnaryFunctor each_fn) {
+ interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
+}
+
//===----------------------------------------------------------------------===//
// Module printing
//===----------------------------------------------------------------------===//
@@ -222,9 +228,7 @@
case Attribute::Kind::Array: {
auto elts = cast<ArrayAttr>(attr)->getValue();
os << '[';
- interleave(elts,
- [&](Attribute *attr) { print(attr); },
- [&]() { os << ", "; });
+ interleaveComma(os, elts, [&](Attribute *attr) { print(attr); });
os << ']';
break;
}
@@ -260,16 +264,14 @@
case Type::Kind::Function: {
auto *func = cast<FunctionType>(type);
os << '(';
- interleave(func->getInputs(), [&](Type *type) { os << *type; },
- [&]() { os << ", "; });
+ interleaveComma(os, func->getInputs(), [&](Type *type) { os << *type; });
os << ") -> ";
auto results = func->getResults();
if (results.size() == 1)
os << *results[0];
else {
os << '(';
- interleave(results, [&](Type *type) { os << *type; },
- [&]() { os << ", "; });
+ interleaveComma(os, results, [&](Type *type) { os << *type; });
os << ')';
}
return;
@@ -331,9 +333,8 @@
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleave(type->getInputs(),
- [&](Type *eltType) { moduleState->print(eltType); },
- [&]() { os << ", "; });
+ interleaveComma(os, type->getInputs(),
+ [&](Type *eltType) { moduleState->print(eltType); });
os << ')';
switch (type->getResults().size()) {
@@ -345,9 +346,8 @@
break;
default:
os << " -> (";
- interleave(type->getResults(),
- [&](Type *eltType) { moduleState->print(eltType); },
- [&]() { os << ", "; });
+ interleaveComma(os, type->getResults(),
+ [&](Type *eltType) { moduleState->print(eltType); });
os << ')';
break;
}
@@ -374,6 +374,26 @@
raw_ostream &os;
const ModuleState *moduleState;
const OperationSet &operationSet;
+
+ void numberValueID(const SSAValue *value) {
+ assert(!valueIDs.count(value) && "Value numbered multiple times");
+ valueIDs[value] = nextValueID++;
+ }
+
+ void printValueID(const SSAValue *value) const {
+ // TODO: If this is the result of an operation with multiple results, look
+ // up the first result, and print the #32 syntax.
+ auto it = valueIDs.find(value);
+ if (it != valueIDs.end())
+ os << '%' << it->getSecond();
+ else
+ os << "<<INVALID SSA VALUE>>";
+ }
+
+private:
+ /// This is the value ID for each SSA value in the current function.
+ DenseMap<const SSAValue *, unsigned> valueIDs;
+ unsigned nextValueID = 0;
};
} // end anonymous namespace
@@ -384,32 +404,72 @@
operationSet(OperationSet::get(context)) {}
void FunctionState::printOperation(const Operation *op) {
+ os << " ";
+
+ // TODO: When we have SSAValue version of operands & results wired into
+ // Operation this check can go away.
+ if (auto *inst = dyn_cast<OperationInst>(op)) {
+ if (inst->getNumResults()) {
+ printValueID(inst->getResult(0));
+ os << " = ";
+ }
+ }
+
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
if (auto opInfo = operationSet.lookup(op->getName().str())) {
- os << " ";
opInfo->printAssembly(op, os);
return;
}
+ // Otherwise use the standard verbose printing approach.
+
// TODO: escape name if necessary.
- os << " \"" << op->getName().str() << "\"()";
+ os << "\"" << op->getName().str() << "\"(";
- // FIXME: Print operand references.
+ // TODO: When we have SSAValue version of operands & results wired into
+ // Operation this check can go away.
+ if (auto *inst = dyn_cast<OperationInst>(op)) {
+ // TODO: Use getOperands() when we have it.
+ interleaveComma(
+ os, inst->getInstOperands(),
+ [&](const InstOperand &operand) { printValueID(operand.get()); });
+ }
+ os << ')';
auto attrs = op->getAttrs();
if (!attrs.empty()) {
os << '{';
- interleave(
- attrs,
- [&](NamedAttribute attr) {
- os << attr.first << ": ";
- moduleState->print(attr.second); },
- [&]() { os << ", "; });
+ interleaveComma(os, attrs, [&](NamedAttribute attr) {
+ os << attr.first << ": ";
+ moduleState->print(attr.second);
+ });
os << '}';
}
- // TODO: Print signature type once that is plumbed through to Operation.
+ // TODO: When we have SSAValue version of operands & results wired into
+ // Operation this check can go away.
+ if (auto *inst = dyn_cast<OperationInst>(op)) {
+ // Print the type signature of the operation.
+ os << " : (";
+ // TODO: Switch to getOperands() when we have it.
+ interleaveComma(os, inst->getInstOperands(), [&](const InstOperand &op) {
+ moduleState->print(op.get()->getType());
+ });
+ os << ") -> ";
+
+ // TODO: Switch to getResults() when we have it.
+ if (inst->getNumResults() == 1) {
+ moduleState->print(inst->getInstResult(0).getType());
+ } else {
+ os << '(';
+ interleaveComma(os, inst->getInstResults(),
+ [&](const InstResult &result) {
+ moduleState->print(result.getType());
+ });
+ os << ')';
+ }
+ }
}
//===----------------------------------------------------------------------===//
@@ -441,6 +501,8 @@
private:
const CFGFunction *function;
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
+
+ void numberBlock(const BasicBlock *block);
};
} // end anonymous namespace
@@ -451,7 +513,23 @@
function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
- for (auto &block : *function) basicBlockIDs[&block] = blockID++;
+ for (auto &block : *function) {
+ basicBlockIDs[&block] = blockID++;
+ numberBlock(&block);
+ }
+}
+
+/// Number all of the SSA values in the specified basic block.
+void CFGFunctionState::numberBlock(const BasicBlock *block) {
+ // TODO: basic block arguments.
+ for (auto &op : *block) {
+ // We number instruction that have results, and we only number the first
+ // result.
+ if (op.getNumResults() != 0)
+ numberValueID(op.getResult(0));
+ }
+
+ // Terminators do not define values.
}
void CFGFunctionState::print() {
@@ -489,29 +567,6 @@
void CFGFunctionState::print(const OperationInst *inst) {
printOperation(inst);
-
- // FIXME: Move this into printOperation when Operation has operands and
- // results
-
- // Print the type signature of the operation.
- os << " : (";
- interleave(
- inst->getOperands(),
- [&](const InstOperand &op) { moduleState->print(op.get()->getType()); },
- [&]() { os << ", "; });
- os << ") -> ";
-
- auto resultList = inst->getResults();
- if (resultList.size() == 1) {
- moduleState->print(resultList[0].getType());
- } else {
- os << '(';
- interleave(
- resultList,
- [&](const InstResult &result) { moduleState->print(result.getType()); },
- [&]() { os << ", "; });
- os << ')';
- }
}
void CFGFunctionState::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
@@ -753,8 +808,7 @@
assert(!getResults().empty());
// Result affine expressions.
os << " -> (";
- interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
- [&]() { os << ", "; });
+ interleaveComma(os, getResults(), [&](AffineExpr *expr) { os << *expr; });
os << ")";
if (!isBounded()) {
@@ -763,8 +817,7 @@
// Print range sizes for bounded affine maps.
os << " size (";
- interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
- [&]() { os << ", "; });
+ interleaveComma(os, getRangeSizes(), [&](AffineExpr *expr) { os << *expr; });
os << ")";
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 23aaa4a..57cedfe 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -67,11 +67,11 @@
name, operands.size(), resultTypes.size(), attributes, context);
// Initialize the operands and results.
- auto instOperands = inst->getOperands();
+ auto instOperands = inst->getInstOperands();
for (unsigned i = 0, e = operands.size(); i != e; ++i)
new (&instOperands[i]) InstOperand(inst, operands[i]);
- auto instResults = inst->getResults();
+ auto instResults = inst->getInstResults();
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
new (&instResults[i]) InstResult(resultTypes[i], inst);
return inst;
@@ -87,10 +87,10 @@
OperationInst::~OperationInst() {
// Explicitly run the destructors for the operands and results.
- for (auto &operand : getOperands())
+ for (auto &operand : getInstOperands())
operand.~InstOperand();
- for (auto &result : getResults())
+ for (auto &result : getInstResults())
result.~InstResult();
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7b7ee89..b8bae40 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1364,8 +1364,7 @@
if (operandTypes.size() != operandInfos.size()) {
auto plural = "s"[operandInfos.size() == 1];
return emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
- " type" + plural +
- " in operand list but had " +
+ " operand type" + plural + " but had " +
llvm::utostr(operandTypes.size()));
}
@@ -1395,11 +1394,10 @@
// FIXME: Add result infra to handle Stmt results as well to make this
// generic.
if (auto *inst = dyn_cast<OperationInst>(op)) {
- if (inst->getResults().empty())
+ if (inst->getNumResults() == 0)
return emitError(loc, "cannot name an operation with no results");
- // TODO: This should be getResult(0)
- addDefinition({resultID, loc}, &inst->getResults()[0]);
+ addDefinition({resultID, loc}, inst->getResult(0));
}
}
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 49fd2da..b2c7432 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -168,7 +168,7 @@
cfgfunc @test() {
bb40:
- %1 = "foo"() : (i32)->i64 // expected-error {{expected 0 types in operand list but had 1}}
+ %1 = "foo"() : (i32)->i64 // expected-error {{expected 0 operand types but had 1}}
return
}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 9b384b4..222fa12 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -70,10 +70,10 @@
cfgfunc @simpleCFG(i32, f32) {
// CHECK: bb0:
bb42: // (%0: i32, %f: f32): TODO(clattner): implement bbargs.
- // CHECK: "foo"() : () -> i64
+ // CHECK: %0 = "foo"() : () -> i64
%1 = "foo"() : ()->i64
- // CHECK: "bar"() : (i64) -> (i1, i1, i1)
- "bar"(%1) : (i64) -> (i1,i1,i1)
+ // CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
+ %2 = "bar"(%1) : (i64) -> (i1,i1,i1)
// CHECK: return
return
// CHECK: }
@@ -104,6 +104,7 @@
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
cfgfunc @cfgfunc_with_ops() {
bb0:
+ // CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
%t = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: dim xxx, 2 : sometype
@@ -166,6 +167,7 @@
// CHECK-LABEL: cfgfunc @standard_instrs() {
cfgfunc @standard_instrs() {
bb42: // CHECK: bb0:
+ // CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
%42 = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: dim xxx, 2 : sometype