Teach RaiseControlFlow to handle IfOps with partially inferred shapes,
inserting shape_casts as necessary.
Along the way:
- Add some missing accessors to the AtLeastNOperands trait.
- Implement shape_cast / ShapeCastOp standard op.
- Improve handling of errors in mlir-opt, making it easier to understand
errors when invalid IR is rejected by the verifier.
PiperOrigin-RevId: 211897877
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 3070621..8caffee 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -714,6 +714,48 @@
}
//===----------------------------------------------------------------------===//
+// ShapeCastOp
+//===----------------------------------------------------------------------===//
+
+void ShapeCastOp::build(Builder *builder, OperationState *result,
+ SSAValue *input, Type *resultType) { // `builder` is not used in this body.
+ result->addOperands(input); // The operand is the value being cast.
+ result->addTypes(resultType); // The result has the requested type.
+}
+
+const char *ShapeCastOp::verify() const { // nullptr if valid, else an error message.
+ auto *opType = dyn_cast<TensorType>(getOperand()->getType());
+ auto *resType = dyn_cast<TensorType>(getResult()->getType());
+ if (!opType || !resType) // dyn_cast yields null when the type is not a tensor.
+ return "requires input and result types to be tensors";
+
+ if (opType == resType) // NOTE(review): pointer comparison; presumably types are uniqued - confirm.
+ return "requires the input and result type to be different";
+
+ if (opType->getElementType() != resType->getElementType())
+ return "requires input and result element types to be the same";
+
+ // If either the source or the destination is unranked, the cast is valid.
+ auto *opRType = dyn_cast<RankedTensorType>(opType);
+ auto *resRType = dyn_cast<RankedTensorType>(resType);
+ if (!opRType || !resRType)
+ return nullptr;
+
+ // If both are ranked, their ranks must be equal, and every dimension whose
+ // size is specified on both sides must have the same size.
+ if (opRType->getRank() != resRType->getRank())
+ return "requires input and result ranks to match";
+
+ for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
+ int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+ if (opDim != -1 && resultDim != -1 && opDim != resultDim) // -1 appears to mark an unspecified dimension - confirm.
+ return "requires static dimensions to match";
+ }
+
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -788,6 +830,6 @@
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
- ReturnOp, StoreOp>(
+ ReturnOp, ShapeCastOp, StoreOp>(
/*prefix=*/"");
}