blob: 55c0c252899b0f21674c6a1347b2b97294d0e63c [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team39a3a602018-07-24 17:43:56 -070028// TODO: Have verify functions return std::string to enable more descriptive
29// error messages.
Chris Lattner85ee1512018-07-25 11:15:20 -070030OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
31 SmallVector<OpAsmParser::OperandType, 2> ops;
32 Type *type;
33 SSAValue *lhs, *rhs;
34 if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
35 parser->resolveOperand(ops[0], type, lhs) ||
36 parser->resolveOperand(ops[1], type, rhs))
37 return {};
38
39 return OpAsmParserResult({lhs, rhs}, type);
40}
41
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070042void AddFOp::print(OpAsmPrinter *p) const {
43 *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
44 << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070045}
46
Chris Lattner21e67f62018-07-06 10:46:19 -070047// Return an error message on failure.
48const char *AddFOp::verify() const {
49 // TODO: Check that the types of the LHS and RHS match.
50 // TODO: This should be a refinement of TwoOperands.
51 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
52 return nullptr;
53}
54
Chris Lattner3164ae62018-07-28 09:36:25 -070055OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
56 SmallVector<OpAsmParser::OperandType, 2> opInfos;
57 SmallVector<SSAValue *, 4> operands;
58
59 auto &builder = parser->getBuilder();
60 auto *affineIntTy = builder.getAffineIntType();
61
62 AffineMapAttr *mapAttr;
63 if (parser->parseAttribute(mapAttr) ||
64 parser->parseOperandList(opInfos, -1,
65 OpAsmParser::Delimeter::ParenDelimeter))
66 return {};
67 unsigned numDims = opInfos.size();
68
69 if (parser->parseOperandList(
70 opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
71 parser->resolveOperands(opInfos, affineIntTy, operands))
72 return {};
73
74 auto *map = mapAttr->getValue();
75 if (map->getNumDims() != numDims ||
76 numDims + map->getNumSymbols() != opInfos.size()) {
77 parser->emitError(parser->getNameLoc(),
78 "dimension or symbol index mismatch");
79 return {};
80 }
81
82 SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
83 return OpAsmParserResult(
84 operands, resultTypes,
85 NamedAttribute(builder.getIdentifier("map"), mapAttr));
86}
87
88void AffineApplyOp::print(OpAsmPrinter *p) const {
89 auto *map = getAffineMap();
90 *p << "affine_apply " << *map;
91
92 auto opit = operand_begin();
93 *p << '(';
94 p->printOperands(opit, opit + map->getNumDims());
95 *p << ')';
96
97 if (map->getNumSymbols()) {
98 *p << '[';
99 p->printOperands(opit + map->getNumDims(), operand_end());
100 *p << ']';
101 }
102}
103
104const char *AffineApplyOp::verify() const {
105 // Check that affine map attribute was specified.
106 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
107 if (!affineMapAttr)
108 return "requires an affine map.";
109
110 // Check input and output dimensions match.
111 auto *map = affineMapAttr->getValue();
112
113 // Verify that operand count matches affine map dimension and symbol count.
114 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
115 return "operand count and affine map dimension and symbol count must match";
116
117 // Verify that result count matches affine map result count.
118 if (getNumResults() != map->getNumResults())
119 return "result count and affine map result count must match";
120
121 return nullptr;
122}
123
Chris Lattner9361fb32018-07-24 08:34:58 -0700124/// The constant op requires an attribute, and furthermore requires that it
125/// matches the return type.
126const char *ConstantOp::verify() const {
127 auto *value = getValue();
128 if (!value)
129 return "requires a 'value' attribute";
130
131 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700132 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700133 if (!isa<IntegerAttr>(value))
134 return "requires 'value' to be an integer for an integer result type";
135 return nullptr;
136 }
137
138 if (isa<FunctionType>(type)) {
139 // TODO: Verify a function attr.
140 }
141
142 return "requires a result type that aligns with the 'value' attribute";
143}
144
Chris Lattner3da86ad2018-07-26 09:58:23 -0700145/// ConstantIntOp only matches values whose result type is an IntegerType or
146/// AffineInt.
Chris Lattner9361fb32018-07-24 08:34:58 -0700147bool ConstantIntOp::isClassFor(const Operation *op) {
148 return ConstantOp::isClassFor(op) &&
Chris Lattner3da86ad2018-07-26 09:58:23 -0700149 (isa<IntegerType>(op->getResult(0)->getType()) ||
150 op->getResult(0)->getType()->isAffineInt());
Chris Lattner9361fb32018-07-24 08:34:58 -0700151}
152
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700153void DimOp::print(OpAsmPrinter *p) const {
154 *p << "dim " << *getOperand() << ", " << getIndex() << " : "
155 << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700156}
157
Chris Lattner85ee1512018-07-25 11:15:20 -0700158OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
159 OpAsmParser::OperandType operandInfo;
160 IntegerAttr *indexAttr;
161 Type *type;
162 SSAValue *operand;
163 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
164 parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
165 parser->resolveOperand(operandInfo, type, operand))
166 return {};
167
168 auto &builder = parser->getBuilder();
169 return OpAsmParserResult(
170 operand, builder.getAffineIntType(),
171 NamedAttribute(builder.getIdentifier("index"), indexAttr));
172}
173
Chris Lattner21e67f62018-07-06 10:46:19 -0700174const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700175 // Check that we have an integer index operand.
176 auto indexAttr = getAttrOfType<IntegerAttr>("index");
177 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700178 return "requires an integer attribute named 'index'";
179 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700180
Chris Lattner9361fb32018-07-24 08:34:58 -0700181 auto *type = getOperand()->getType();
182 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
183 if (index >= tensorType->getRank())
184 return "index is out of range";
185 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
186 if (index >= memrefType->getRank())
187 return "index is out of range";
188
189 } else if (isa<UnrankedTensorType>(type)) {
190 // ok, assumed to be in-range.
191 } else {
192 return "requires an operand with tensor or memref type";
193 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700194
195 return nullptr;
196}
197
Chris Lattner85ee1512018-07-25 11:15:20 -0700198void LoadOp::print(OpAsmPrinter *p) const {
199 *p << "load " << *getMemRef() << '[';
200 p->printOperands(getIndices());
201 *p << "] : " << *getMemRef()->getType();
202}
203
204OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
205 OpAsmParser::OperandType memrefInfo;
206 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
207 MemRefType *type;
208 SmallVector<SSAValue *, 4> operands;
209
210 auto affineIntTy = parser->getBuilder().getAffineIntType();
211 if (parser->parseOperand(memrefInfo) ||
212 parser->parseOperandList(indexInfo, -1,
213 OpAsmParser::Delimeter::SquareDelimeter) ||
214 parser->parseColonType(type) ||
215 parser->resolveOperands(memrefInfo, type, operands) ||
216 parser->resolveOperands(indexInfo, affineIntTy, operands))
217 return {};
218
219 return OpAsmParserResult(operands, type->getElementType());
220}
221
222const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700223 if (getNumOperands() == 0)
224 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700225
Chris Lattner3164ae62018-07-28 09:36:25 -0700226 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
227 if (!memRefType)
228 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700229
Chris Lattner3164ae62018-07-28 09:36:25 -0700230 for (auto *idx : getIndices())
231 if (!idx->getType()->isAffineInt())
232 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700233
Chris Lattner3164ae62018-07-28 09:36:25 -0700234 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700235
Chris Lattner3164ae62018-07-28 09:36:25 -0700236 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
237 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700238 return nullptr;
239}
240
Chris Lattnerff0d5902018-07-05 09:12:11 -0700241/// Install the standard operations in the specified operation set.
242void mlir::registerStandardOperations(OperationSet &opSet) {
Chris Lattner3164ae62018-07-28 09:36:25 -0700243 opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
Chris Lattner85ee1512018-07-25 11:15:20 -0700244 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700245}