Implement operands for the lower and upper bounds of the for statement.
This revamps the implementation of the loop bounds in ForStmt, using a general representation that supports operands. The frequent case of constant bounds is supported
via special access methods.
This also includes:
- Operand iterators for the Statement class.
- OpPointer::is() method to query the class of the Operation.
- Support for the bound shorthand notation parsing and printing.
- Validity checks for the bound operands used as dimension ids and symbols.
I didn't mean this CL to be so large. It just happened this way, as one thing led to another.
PiperOrigin-RevId: 210204858
diff --git a/include/mlir/IR/AffineMap.h b/include/mlir/IR/AffineMap.h
index 5e575ec..46b64f2 100644
--- a/include/mlir/IR/AffineMap.h
+++ b/include/mlir/IR/AffineMap.h
@@ -55,6 +55,13 @@
/// dimensional identifiers.
bool isIdentity() const;
+ /// Returns true if this affine map is a single result constant function.
+ bool isSingleConstant() const;
+
+ /// Returns the constant value that is the result of this map.
+ /// This method asserts that the map has a single constant result.
+ int64_t getSingleConstantValue() const;
+
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
void dump() const;
@@ -62,11 +69,14 @@
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return numSymbols; }
unsigned getNumResults() const { return numResults; }
+ unsigned getNumOperands() const { return numDims + numSymbols; }
ArrayRef<AffineExpr *> getResults() const {
return ArrayRef<AffineExpr *>(results, numResults);
}
+ AffineExpr *getResult(unsigned idx) const { return results[idx]; }
+
ArrayRef<AffineExpr *> getRangeSizes() const {
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
: ArrayRef<AffineExpr *>();
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 207fefd..3cc8227 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -107,6 +107,14 @@
ArrayRef<AffineExpr *> constraints,
ArrayRef<bool> isEq);
+ // Special cases of affine maps.
+ // One constant result: () -> (val).
+ AffineMap *getConstantMap(int64_t val);
+ // One dimension id identity map: (i) -> (i).
+ AffineMap *getDimIdentityMap();
+ // One symbol identity map: ()[s] -> (s).
+ AffineMap *getSymbolIdentityMap();
+
// TODO: Helpers for affine map/exprs, etc.
protected:
MLIRContext *context;
@@ -266,14 +274,12 @@
/// Set the insertion point to the start of the specified block.
void setInsertionPointToStart(StmtBlock *block) {
- this->block = block;
- insertPoint = block->begin();
+ setInsertionPoint(block, block->begin());
}
/// Set the insertion point to the end of the specified block.
void setInsertionPointToEnd(StmtBlock *block) {
- this->block = block;
- insertPoint = block->end();
+ setInsertionPoint(block, block->end());
}
/// Get the current insertion point of the builder.
@@ -305,10 +311,14 @@
return cloneStmt;
}
- // Creates for statement. When step is not specified, it is set to 1.
- ForStmt *createFor(Attribute *location, AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound, int64_t step = 1);
+ /// Create a 'for' statement with bounds that may involve MLValue operands.
+ /// When step is not specified, it is set to 1.
+ ForStmt *createFor(Attribute *location, ArrayRef<MLValue *> lbOperands,
+ AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
+ AffineMap *ubMap, int64_t step = 1);
+ /// Create an 'if' statement.
+ /// TODO: pass operands.
IfStmt *createIf(Attribute *location, IntegerSet *condition);
private:
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 226e36b..ad0cb05 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -198,7 +198,7 @@
return getInstOperand(idx).get();
}
void setOperand(unsigned idx, CFGValue *value) {
- return getInstOperand(idx).set(value);
+ getInstOperand(idx).set(value);
}
// Support non-const operand iteration.
@@ -236,6 +236,13 @@
MutableArrayRef<InstOperand> getInstOperands() {
return {getTrailingObjects<InstOperand>(), numOperands};
}
+ // Accessors to InstOperand. Without these methods, invoking getInstOperand()
+ // calls Instruction::getInstOperands(), which executes an unnecessary
+ // switch statement.
+ InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
+ const InstOperand &getInstOperand(unsigned idx) const {
+ return getInstOperands()[idx];
+ }
//===--------------------------------------------------------------------===//
// Results
diff --git a/include/mlir/IR/MLValue.h b/include/mlir/IR/MLValue.h
index 99e4628..8443e9f 100644
--- a/include/mlir/IR/MLValue.h
+++ b/include/mlir/IR/MLValue.h
@@ -45,6 +45,11 @@
/// MLValue is the base class for SSA values in ML functions.
class MLValue : public SSAValueImpl<StmtOperand, MLValueKind> {
public:
+ /// Returns true if this MLValue can be used as a dimension id.
+ bool isValidDim() const;
+ /// Returns true if this MLValue can be used as a symbol.
+ bool isValidSymbol() const;
+
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::MLFuncArgument:
@@ -96,7 +101,7 @@
OperationStmt *getOwner() { return owner; }
const OperationStmt *getOwner() const { return owner; }
- /// Return the number of this result.
+ /// Returns the number of this result.
unsigned getResultNumber() const;
private:
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 158362e..d5b2b25 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -234,6 +234,12 @@
return ConstOpPointer<OpClass>(OpClass(isMatch ? this : nullptr));
}
+ /// The is<OpClass>() method returns true if the operation is a typed op
+ /// (like DimOp) of the given class.
+ template <typename OpClass> bool is() const {
+ return OpClass::isClassFor(this);
+ }
+
enum class OperationKind { Instruction, Statement };
// This is used to implement the dynamic casting logic, but you shouldn't
// call it directly. Use something like isa<OperationInst>(someOp) instead.
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index e688ea2..aa9457d 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -75,11 +75,17 @@
static void build(Builder *builder, OperationState *result, AffineMap *map,
ArrayRef<SSAValue *> operands);
- // Returns the affine map to be applied by this operation.
+ /// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
}
+ /// Returns true if this operation's result can be used as a dimension id.
+ bool isValidDim() const;
+
+ /// Returns true if this operation's result can be used as a symbol.
+ bool isValidSymbol() const;
+
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 9669849..d6d16e3 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -22,6 +22,8 @@
#ifndef MLIR_IR_STATEMENT_H
#define MLIR_IR_STATEMENT_H
+#include "mlir/IR/MLValue.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"
@@ -32,7 +34,6 @@
class StmtBlock;
class ForStmt;
class MLIRContext;
-class MLValue;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
@@ -68,7 +69,7 @@
/// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map.
- Statement *clone(OperandMapTy &operandMapping, MLIRContext *context) const;
+ Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const;
/// Returns the statement block that contains this statement.
StmtBlock *getBlock() const { return block; }
@@ -91,6 +92,55 @@
void print(raw_ostream &os) const;
void dump() const;
+ //===--------------------------------------------------------------------===//
+ // Operands
+ //===--------------------------------------------------------------------===//
+
+ unsigned getNumOperands() const;
+
+ MLValue *getOperand(unsigned idx);
+ const MLValue *getOperand(unsigned idx) const;
+ void setOperand(unsigned idx, MLValue *value);
+
+ // Support non-const operand iteration.
+ using operand_iterator = OperandIterator<Statement, MLValue>;
+
+ operand_iterator operand_begin() { return operand_iterator(this, 0); }
+
+ operand_iterator operand_end() {
+ return operand_iterator(this, getNumOperands());
+ }
+
+ llvm::iterator_range<operand_iterator> getOperands() {
+ return {operand_begin(), operand_end()};
+ }
+
+ // Support const operand iteration.
+ using const_operand_iterator =
+ OperandIterator<const Statement, const MLValue>;
+
+ const_operand_iterator operand_begin() const {
+ return const_operand_iterator(this, 0);
+ }
+
+ const_operand_iterator operand_end() const {
+ return const_operand_iterator(this, getNumOperands());
+ }
+
+ llvm::iterator_range<const_operand_iterator> getOperands() const {
+ return {operand_begin(), operand_end()};
+ }
+
+ MutableArrayRef<StmtOperand> getStmtOperands();
+ ArrayRef<StmtOperand> getStmtOperands() const {
+ return const_cast<Statement *>(this)->getStmtOperands();
+ }
+
+ StmtOperand &getStmtOperand(unsigned idx) { return getStmtOperands()[idx]; }
+ const StmtOperand &getStmtOperand(unsigned idx) const {
+ return getStmtOperands()[idx];
+ }
+
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening. NOTE: This may terminate
/// the containing application, only use when the IR is in an inconsistent
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 47b2427..1aa6581 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -22,16 +22,16 @@
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
-#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/Statement.h"
#include "mlir/IR/StmtBlock.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
+class AffineMap;
+class AffineBound;
/// Operation statements represent operations inside ML functions.
class OperationStmt final
@@ -55,7 +55,7 @@
using Statement::getLoc;
/// Check if this statement is a return statement.
- bool isReturn() const { return getName().str() == "return"; }
+ bool isReturn() const;
//===--------------------------------------------------------------------===//
// Operands
@@ -112,11 +112,6 @@
return getStmtOperands()[idx];
}
- /// This drops all operand uses from this instruction, which is an essential
- /// step in breaking cyclic dependences between references when they are to
- /// be deleted.
- void dropAllReferences();
-
//===--------------------------------------------------------------------===//
// Results
//===--------------------------------------------------------------------===//
@@ -203,31 +198,114 @@
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue, public StmtBlock {
public:
- // TODO: lower and upper bounds should be affine maps with
- // dimension and symbol use lists.
- explicit ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound, int64_t step,
- MLIRContext *context);
+ static ForStmt *create(Attribute *location, ArrayRef<MLValue *> lbOperands,
+ AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
+ AffineMap *ubMap, int64_t step, MLIRContext *context);
~ForStmt() {
- // Loop bounds and step are immortal objects and don't need to be deleted.
-
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before the MLValue that this
- // for stmt represents is destroyed.
+ // for stmt represents is destroyed. Affine maps are immortal objects and
+ // don't need to be deleted.
clear();
}
/// Resolve base class ambiguity.
using Statement::findFunction;
- AffineConstantExpr *getLowerBound() const { return lowerBound; }
- AffineConstantExpr *getUpperBound() const { return upperBound; }
+ /// Operand iterators.
+ using operand_iterator = OperandIterator<ForStmt, MLValue>;
+ using const_operand_iterator = OperandIterator<const ForStmt, const MLValue>;
+
+ /// Operand iterator range.
+ using operand_range = llvm::iterator_range<operand_iterator>;
+ using const_operand_range = llvm::iterator_range<const_operand_iterator>;
+
+ //===--------------------------------------------------------------------===//
+ // Bounds and step
+ //===--------------------------------------------------------------------===//
+
+ /// Returns information about the lower bound as a single object.
+ const AffineBound getLowerBound() const;
+
+ /// Returns information about the upper bound as a single object.
+ const AffineBound getUpperBound() const;
+
+ /// Returns loop step.
int64_t getStep() const { return step; }
- void setLowerBound(AffineConstantExpr *lb) { lowerBound = lb; }
- void setUpperBound(AffineConstantExpr *ub) { upperBound = ub; }
- void setStep(unsigned s) { step = s; }
+ /// Returns affine map for the lower bound.
+ AffineMap *getLowerBoundMap() const { return lbMap; }
+ /// Returns affine map for the upper bound.
+ AffineMap *getUpperBoundMap() const { return ubMap; }
+
+ /// Set lower bound.
+ void setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map);
+ /// Set upper bound.
+ void setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map);
+ /// Set loop step.
+ void setStep(int64_t step) { this->step = step; }
+
+ /// Returns true if the lower bound is constant.
+ bool hasConstantLowerBound() const;
+ /// Returns true if the upper bound is constant.
+ bool hasConstantUpperBound() const;
+ /// Returns true if both bounds are constant.
+ bool hasConstantBounds() const {
+ return hasConstantLowerBound() && hasConstantUpperBound();
+ }
+ /// Returns the value of the constant lower bound.
+ /// Fails assertion if the bound is non-constant.
+ int64_t getConstantLowerBound() const;
+ /// Returns the value of the constant upper bound.
+ /// Fails assertion if the bound is non-constant.
+ int64_t getConstantUpperBound() const;
+ /// Sets the lower bound to the given constant value.
+ void setConstantLowerBound(int64_t value);
+ /// Sets the upper bound to the given constant value.
+ void setConstantUpperBound(int64_t value);
+
+ //===--------------------------------------------------------------------===//
+ // Operands
+ //===--------------------------------------------------------------------===//
+
+ unsigned getNumOperands() const { return operands.size(); }
+
+ MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
+ const MLValue *getOperand(unsigned idx) const {
+ return getStmtOperand(idx).get();
+ }
+ void setOperand(unsigned idx, MLValue *value) {
+ getStmtOperand(idx).set(value);
+ }
+
+ operand_iterator operand_begin() { return operand_iterator(this, 0); }
+ operand_iterator operand_end() {
+ return operand_iterator(this, getNumOperands());
+ }
+
+ const_operand_iterator operand_begin() const {
+ return const_operand_iterator(this, 0);
+ }
+ const_operand_iterator operand_end() const {
+ return const_operand_iterator(this, getNumOperands());
+ }
+
+ ArrayRef<StmtOperand> getStmtOperands() const { return operands; }
+ MutableArrayRef<StmtOperand> getStmtOperands() { return operands; }
+ StmtOperand &getStmtOperand(unsigned idx) { return getStmtOperands()[idx]; }
+ const StmtOperand &getStmtOperand(unsigned idx) const {
+ return getStmtOperands()[idx];
+ }
+ // TODO: provide iterators for the lower and upper bound operands
+ // if the current access via getLowerBound(), getUpperBound() is too slow.
+
+ //===--------------------------------------------------------------------===//
+ // Other
+ //===--------------------------------------------------------------------===//
+
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const { return getType()->getContext(); }
using Statement::dump;
using Statement::print;
@@ -249,11 +327,63 @@
}
private:
- // TODO(shpeisman): please name the ForStmt's bounds encapsulating
- // an affinemap and its operands as AffineBound.
- AffineConstantExpr *lowerBound;
- AffineConstantExpr *upperBound;
+ // Affine map for the lower bound.
+ AffineMap *lbMap;
+ // Affine map for the upper bound.
+ AffineMap *ubMap;
+ // Constant step.
int64_t step;
+ // Operands for the lower and upper bounds.
+ std::vector<StmtOperand> operands;
+
+ explicit ForStmt(Attribute *location, unsigned numOperands, AffineMap *lbMap,
+ AffineMap *ubMap, int64_t step, MLIRContext *context);
+};
+
+/// AffineBound represents a lower or upper bound in the for statement.
+/// This class does not own the underlying operands. Instead, it refers
+/// to the operands stored in the ForStmt. Its lifetime should not exceed
+/// that of the for statement it refers to.
+class AffineBound {
+public:
+ const ForStmt *getOwner() const { return &stmt; }
+ AffineMap *getMap() const { return map; }
+
+ unsigned getNumOperands() const { return opEnd - opStart; }
+ const MLValue *getOperand(unsigned idx) const {
+ return stmt.getOperand(opStart + idx);
+ }
+ const StmtOperand &getStmtOperand(unsigned idx) const {
+ return stmt.getStmtOperand(opStart + idx);
+ }
+
+ using operand_iterator = ForStmt::const_operand_iterator;
+ using operand_range = ForStmt::const_operand_range;
+
+ operand_iterator operand_begin() const {
+ return operand_iterator(&stmt, opStart);
+ }
+ operand_iterator operand_end() const {
+ return operand_iterator(&stmt, opEnd);
+ }
+
+ operand_range getOperands() const { return {operand_begin(), operand_end()}; }
+ ArrayRef<StmtOperand> getStmtOperands() const {
+ auto ops = stmt.getStmtOperands();
+ return ArrayRef<StmtOperand>(ops.begin() + opStart, ops.begin() + opEnd);
+ }
+
+private:
+ const ForStmt &stmt;
+ unsigned opStart, opEnd;
+ AffineMap *map;
+
+ AffineBound(const ForStmt &stmt, const unsigned opStart, const unsigned opEnd,
+ const AffineMap *map)
+ : stmt(stmt), opStart(opStart), opEnd(opEnd),
+ map(const_cast<AffineMap *>(map)) {}
+
+ friend class ForStmt;
};
/// An if clause represents statements contained within a then or an else clause