blob: 66d3d05c885559b414246b9dda1e4bfe46f3469e [file] [log] [blame]
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
/// Prints the operands in [begin, end) as `(dims)[symbols]`.
///
/// The first 'numDims' operands are printed inside parentheses. Any remaining
/// operands are printed inside square brackets; the bracketed list is omitted
/// entirely when there are no remaining operands.
static void printDimAndSymbolList(Operation::const_operand_iterator begin,
                                  Operation::const_operand_iterator end,
                                  unsigned numDims, OpAsmPrinter *p) {
  auto symbolsBegin = begin + numDims;

  *p << '(';
  p->printOperands(begin, symbolsBegin);
  *p << ')';

  // No trailing operands: nothing more to print.
  if (symbolsBegin == end)
    return;

  *p << '[';
  p->printOperands(symbolsBegin, end);
  *p << ']';
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
static bool
parseDimAndSymbolList(OpAsmParser *parser,
SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimeter::ParenDelimeter))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperandList(
opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
}
// TODO: Have verify functions return std::string to enable more descriptive
// error messages.
/// Parses `addf %lhs, %rhs : type`. Both operands are resolved with the single
/// trailing type, which is also used as the result type.
OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  Type *resultType;
  if (parser->parseOperandList(operandInfos, 2) ||
      parser->parseColonType(resultType))
    return {};

  SSAValue *lhs, *rhs;
  if (parser->resolveOperand(operandInfos[0], resultType, lhs) ||
      parser->resolveOperand(operandInfos[1], resultType, rhs))
    return {};

  return OpAsmParserResult({lhs, rhs}, resultType);
}
/// Prints the op as `addf %lhs, %rhs : type`, where type is the op's type.
void AddFOp::print(OpAsmPrinter *p) const {
  *p << "addf " << *getOperand(0);
  *p << ", " << *getOperand(1);
  *p << " : " << *getType();
}
/// Verifies an addf op, returning an error message on failure and nullptr on
/// success. No checks are implemented yet, so this always succeeds.
const char *AddFOp::verify() const {
  // TODO: Check that the LHS and RHS types match.
  // TODO: Make this a refinement of TwoOperands.
  // TODO: Add a OneResultWhoseTypeMatchesFirstOperand trait.
  return nullptr;
}
/// Parses `affine_apply <map> (dims)[symbols]`.
///
/// The number of dimension and symbol operands must match the map's dimension
/// and symbol counts. The op produces one affineint result per map result, and
/// the map is attached as the "map" attribute.
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
  auto &builder = parser->getBuilder();

  AffineMapAttr *mapAttr;
  SmallVector<OpAsmParser::OperandType, 2> opInfos;
  SmallVector<SSAValue *, 4> operands;
  unsigned numDims;
  if (parser->parseAttribute(mapAttr) ||
      parseDimAndSymbolList(parser, opInfos, operands, numDims))
    return {};

  auto *affineMap = mapAttr->getValue();
  bool dimsMatch = affineMap->getNumDims() == numDims;
  bool symbolsMatch = numDims + affineMap->getNumSymbols() == opInfos.size();
  if (!dimsMatch || !symbolsMatch) {
    parser->emitError(parser->getNameLoc(),
                      "dimension or symbol index mismatch");
    return {};
  }

  SmallVector<Type *, 4> resultTypes(affineMap->getNumResults(),
                                     builder.getAffineIntType());
  auto mapEntry = NamedAttribute(builder.getIdentifier("map"), mapAttr);
  return OpAsmParserResult(operands, resultTypes, mapEntry);
}
/// Prints `affine_apply <map> (dims)[symbols]`; the map's dimension count
/// decides where the dimension operands end and the symbol operands begin.
void AffineApplyOp::print(OpAsmPrinter *p) const {
  auto *affineMap = getAffineMap();
  *p << "affine_apply " << *affineMap;
  unsigned dimCount = affineMap->getNumDims();
  printDimAndSymbolList(operand_begin(), operand_end(), dimCount, p);
}
const char *AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return "requires an affine map.";
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
return "operand count and affine map dimension and symbol count must match";
// Verify that result count matches affine map result count.
if (getNumResults() != map->getNumResults())
return "result count and affine map result count must match";
return nullptr;
}
/// Prints `alloc (dynamic dims)[symbols] : memref-type`. The number of dynamic
/// dimensions of the result memref type determines how many leading operands
/// go inside the parentheses.
void AllocOp::print(OpAsmPrinter *p) const {
  auto *memrefType = cast<MemRefType>(getMemRef()->getType());
  *p << "alloc";
  printDimAndSymbolList(operand_begin(), operand_end(),
                        memrefType->getNumDynamicDims(), p);
  *p << " : " << *memrefType;
}
/// Parses `alloc (dynamic dims)[symbols] : memref-type`.
///
/// The number of dimension operands must equal the memref type's number of
/// dynamic dimensions. The number of symbol operands must equal the symbol
/// count of the first affine map in the memref's affine map composition, or
/// zero if the composition is empty.
OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
  MemRefType *type;
  SmallVector<SSAValue *, 4> operands;
  SmallVector<OpAsmParser::OperandType, 4> operandsInfo;

  // Parse the dimension operands and optional symbol operands.
  unsigned numDimOperands;
  if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands))
    return {};

  // Parse memref type.
  if (parser->parseColonType(type))
    return {};

  // Check numDynamicDims against number of question marks in memref type.
  if (numDimOperands != type->getNumDynamicDims()) {
    parser->emitError(parser->getNameLoc(),
                      "Dynamic dimensions count mismatch: dimension operand "
                      "count does not equal memref dynamic dimension count.");
    return {};
  }

  // Symbol operands bind the symbols of the first affine map in the memref's
  // affine map composition. Guard against an empty composition instead of
  // indexing past the end of it; with no map there are no symbols to bind.
  unsigned numSymbols = type->getAffineMaps().empty()
                            ? 0
                            : type->getAffineMaps()[0]->getNumSymbols();
  if ((operandsInfo.size() - numDimOperands) != numSymbols) {
    parser->emitError(parser->getNameLoc(),
                      "AffineMap symbol count mismatch: symbol operand "
                      "count does not equal memref affine map symbol count.");
    return {};
  }

  return OpAsmParserResult(operands, type);
}
/// Verifies an alloc op. Verification is not implemented yet, so this always
/// reports success (nullptr).
const char *AllocOp::verify() const {
  // TODO(andydavis): Verify alloc.
  return nullptr;
}
/// Verifies a constant op. A 'value' attribute is required, and its kind must
/// agree with the result type. Only integer and affineint result types are
/// currently accepted, and they require an IntegerAttr value. Every other
/// result type, including function types, is rejected.
const char *ConstantOp::verify() const {
  auto *valueAttr = getValue();
  if (!valueAttr)
    return "requires a 'value' attribute";

  auto *resultType = this->getType();
  bool hasIntegerResult =
      isa<IntegerType>(resultType) || resultType->isAffineInt();
  if (!hasIntegerResult) {
    // TODO: Verify a function attr when the result is a FunctionType.
    return "requires a result type that aligns with the 'value' attribute";
  }

  if (!isa<IntegerAttr>(valueAttr))
    return "requires 'value' to be an integer for an integer result type";
  return nullptr;
}
/// Returns true if 'op' is a constant op whose first result has IntegerType
/// or is affineint.
bool ConstantIntOp::isClassFor(const Operation *op) {
  // Only inspect the result once we know this is a constant op.
  if (!ConstantOp::isClassFor(op))
    return false;
  auto *resultType = op->getResult(0)->getType();
  return isa<IntegerType>(resultType) || resultType->isAffineInt();
}
/// Prints `dim %operand, index : operand-type`.
void DimOp::print(OpAsmPrinter *p) const {
  auto *source = getOperand();
  *p << "dim " << *source << ", " << getIndex();
  *p << " : " << *source->getType();
}
/// Parses `dim %operand, index-attr : operand-type`. The result has affineint
/// type, and the parsed integer is attached as the "index" attribute.
OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
  OpAsmParser::OperandType sourceInfo;
  IntegerAttr *indexAttr;
  Type *sourceType;
  SSAValue *source;

  bool failed = parser->parseOperand(sourceInfo) || parser->parseComma() ||
                parser->parseAttribute(indexAttr) ||
                parser->parseColonType(sourceType) ||
                parser->resolveOperand(sourceInfo, sourceType, source);
  if (failed)
    return {};

  auto &builder = parser->getBuilder();
  auto indexEntry = NamedAttribute(builder.getIdentifier("index"), indexAttr);
  return OpAsmParserResult(source, builder.getAffineIntType(), indexEntry);
}
/// Verifies a dim op. An IntegerAttr named "index" is required, and the
/// operand must be a ranked tensor, an unranked tensor, or a memref. For
/// ranked operands the index must be below the rank. For unranked tensors the
/// index cannot be checked and is accepted.
const char *DimOp::verify() const {
  auto *indexAttr = getAttrOfType<IntegerAttr>("index");
  if (!indexAttr)
    return "requires an integer attribute named 'index'";
  // Reinterpreted as unsigned so a negative index fails the range check.
  auto index = static_cast<uint64_t>(indexAttr->getValue());

  auto *operandType = getOperand()->getType();
  uint64_t rank;
  if (auto *tensorType = dyn_cast<RankedTensorType>(operandType))
    rank = tensorType->getRank();
  else if (auto *memrefType = dyn_cast<MemRefType>(operandType))
    rank = memrefType->getRank();
  else if (isa<UnrankedTensorType>(operandType))
    return nullptr; // Rank unknown; assume the index is in range.
  else
    return "requires an operand with tensor or memref type";

  if (index >= rank)
    return "index is out of range";
  return nullptr;
}
/// Prints `load %memref[%indices] : memref-type`.
void LoadOp::print(OpAsmPrinter *p) const {
  auto *memref = getMemRef();
  *p << "load " << *memref << '[';
  p->printOperands(getIndices());
  *p << "] : " << *memref->getType();
}
/// Parses `load %memref[%indices] : memref-type`.
///
/// Operands are ordered memref first, then the affineint indices. The result
/// type is the memref's element type.
OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  MemRefType *memrefType;
  auto affineIntTy = parser->getBuilder().getAffineIntType();

  // Syntax first...
  if (parser->parseOperand(memrefInfo) ||
      parser->parseOperandList(indexInfo, -1,
                               OpAsmParser::Delimeter::SquareDelimeter) ||
      parser->parseColonType(memrefType))
    return {};

  // ...then resolve the operands against their expected types.
  SmallVector<SSAValue *, 4> operands;
  if (parser->resolveOperands(memrefInfo, memrefType, operands) ||
      parser->resolveOperands(indexInfo, affineIntTy, operands))
    return {};

  return OpAsmParserResult(operands, memrefType->getElementType());
}
/// Verifies a load op with operand layout [memref, indices...].
///
/// The first operand must have MemRefType, there must be exactly one index
/// operand per memref dimension, and every index must have 'affineint' type.
/// Returns an error message on failure and nullptr on success.
const char *LoadOp::verify() const {
  if (getNumOperands() == 0)
    return "expected a memref to load from";

  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
  if (!memRefType)
    return "first operand must be a memref";

  // One memref operand plus one index per dimension; previously an
  // unverified TODO, so loads with too few or too many indices were accepted.
  if (getNumOperands() != 1 + memRefType->getRank())
    return "incorrect number of indices for load";

  for (auto *idx : getIndices())
    if (!idx->getType()->isAffineInt())
      return "index to load must have 'affineint' type";

  // TODO: in MLFunction verify that the indices are parameters, IV's, or the
  // result of an affine_apply.
  return nullptr;
}
/// Prints `store %value, %memref[%indices] : memref-type`.
void StoreOp::print(OpAsmPrinter *p) const {
  auto *memref = getMemRef();
  *p << "store " << *getValueToStore() << ", " << *memref << '[';
  p->printOperands(getIndices());
  *p << "] : " << *memref->getType();
}
/// Parses `store %value, %memref[%indices] : memref-type`.
///
/// Operands are ordered: stored value (typed as the memref element type),
/// memref, then the affineint indices. The op produces no results.
OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
  OpAsmParser::OperandType valueInfo, memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  MemRefType *memrefType;
  auto affineIntTy = parser->getBuilder().getAffineIntType();

  // Syntax first...
  if (parser->parseOperand(valueInfo) || parser->parseComma() ||
      parser->parseOperand(memrefInfo) ||
      parser->parseOperandList(indexInfo, -1,
                               OpAsmParser::Delimeter::SquareDelimeter) ||
      parser->parseColonType(memrefType))
    return {};

  // ...then resolve each operand group, in operand order.
  SmallVector<SSAValue *, 4> operands;
  if (parser->resolveOperands(valueInfo, memrefType->getElementType(),
                              operands) ||
      parser->resolveOperands(memrefInfo, memrefType, operands) ||
      parser->resolveOperands(indexInfo, affineIntTy, operands))
    return {};

  return OpAsmParserResult(operands, {});
}
/// Verifies a store op with operand layout [value, memref, indices...].
///
/// The second operand must have MemRefType, and the stored value must have
/// the memref's element type. There must be exactly one index per memref
/// dimension, and every index must have 'affineint' type. Returns an error
/// message on failure and nullptr on success.
const char *StoreOp::verify() const {
  if (getNumOperands() < 2)
    return "expected a value to store and a memref";

  // Second operand is a memref type.
  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
  if (!memRefType)
    return "second operand must be a memref";

  // First operand must have same type as memref element type.
  if (getValueToStore()->getType() != memRefType->getElementType())
    return "first operand must have same type as memref element type";

  // Value + memref + one index per dimension.
  if (getNumOperands() != 2 + memRefType->getRank())
    return "store index operand count not equal to memref rank";

  for (auto *idx : getIndices())
    if (!idx->getType()->isAffineInt())
      return "index to store must have 'affineint' type";

  // TODO: in MLFunction verify that the indices are parameters, IV's, or the
  // result of an affine_apply.
  return nullptr;
}
/// Registers every standard operation defined in this file with 'opSet'.
void mlir::registerStandardOperations(OperationSet &opSet) {
  // Standard ops are registered with an empty name prefix.
  opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp,
                      LoadOp, StoreOp>(/*prefix=*/"");
}