blob: 8f4f1b30dbc1fa33785a2ff69335d494200c0edd [file] [log] [blame]
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team554a8ad2018-07-30 13:08:05 -070028static void printDimAndSymbolList(Operation::const_operand_iterator begin,
29 Operation::const_operand_iterator end,
30 unsigned numDims, OpAsmPrinter *p) {
31 *p << '(';
32 p->printOperands(begin, begin + numDims);
33 *p << ')';
34
35 if (begin + numDims != end) {
36 *p << '[';
37 p->printOperands(begin + numDims, end);
38 *p << ']';
39 }
40}
41
42// Parses dimension and symbol list, and sets 'numDims' to the number of
43// dimension operands parsed.
44// Returns 'false' on success and 'true' on error.
45static bool
46parseDimAndSymbolList(OpAsmParser *parser,
47 SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
48 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
49 if (parser->parseOperandList(opInfos, -1,
50 OpAsmParser::Delimeter::ParenDelimeter))
51 return true;
52 // Store number of dimensions for validation by caller.
53 numDims = opInfos.size();
54
55 // Parse the optional symbol operands.
56 auto *affineIntTy = parser->getBuilder().getAffineIntType();
57 if (parser->parseOperandList(
58 opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
59 parser->resolveOperands(opInfos, affineIntTy, operands))
60 return true;
61 return false;
62}
63
MLIR Team39a3a602018-07-24 17:43:56 -070064// TODO: Have verify functions return std::string to enable more descriptive
65// error messages.
Chris Lattner85ee1512018-07-25 11:15:20 -070066OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
67 SmallVector<OpAsmParser::OperandType, 2> ops;
68 Type *type;
69 SSAValue *lhs, *rhs;
70 if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
71 parser->resolveOperand(ops[0], type, lhs) ||
72 parser->resolveOperand(ops[1], type, rhs))
73 return {};
74
75 return OpAsmParserResult({lhs, rhs}, type);
76}
77
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070078void AddFOp::print(OpAsmPrinter *p) const {
79 *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
80 << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070081}
82
Chris Lattner21e67f62018-07-06 10:46:19 -070083// Return an error message on failure.
84const char *AddFOp::verify() const {
85 // TODO: Check that the types of the LHS and RHS match.
86 // TODO: This should be a refinement of TwoOperands.
87 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
88 return nullptr;
89}
90
Chris Lattner3164ae62018-07-28 09:36:25 -070091OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
92 SmallVector<OpAsmParser::OperandType, 2> opInfos;
93 SmallVector<SSAValue *, 4> operands;
94
95 auto &builder = parser->getBuilder();
96 auto *affineIntTy = builder.getAffineIntType();
97
98 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -070099 unsigned numDims;
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700100 if (parser->parseAttribute(mapAttr) ||
101 parseDimAndSymbolList(parser, opInfos, operands, numDims))
MLIR Team554a8ad2018-07-30 13:08:05 -0700102 return {};
Chris Lattner3164ae62018-07-28 09:36:25 -0700103 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700104
Chris Lattner3164ae62018-07-28 09:36:25 -0700105 if (map->getNumDims() != numDims ||
106 numDims + map->getNumSymbols() != opInfos.size()) {
107 parser->emitError(parser->getNameLoc(),
108 "dimension or symbol index mismatch");
109 return {};
110 }
111
112 SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
113 return OpAsmParserResult(
114 operands, resultTypes,
115 NamedAttribute(builder.getIdentifier("map"), mapAttr));
116}
117
118void AffineApplyOp::print(OpAsmPrinter *p) const {
119 auto *map = getAffineMap();
120 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700121 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner3164ae62018-07-28 09:36:25 -0700122}
123
124const char *AffineApplyOp::verify() const {
125 // Check that affine map attribute was specified.
126 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
127 if (!affineMapAttr)
128 return "requires an affine map.";
129
130 // Check input and output dimensions match.
131 auto *map = affineMapAttr->getValue();
132
133 // Verify that operand count matches affine map dimension and symbol count.
134 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
135 return "operand count and affine map dimension and symbol count must match";
136
137 // Verify that result count matches affine map result count.
138 if (getNumResults() != map->getNumResults())
139 return "result count and affine map result count must match";
140
141 return nullptr;
142}
143
MLIR Team554a8ad2018-07-30 13:08:05 -0700144void AllocOp::print(OpAsmPrinter *p) const {
145 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
146 *p << "alloc";
147 // Print dynamic dimension operands.
148 printDimAndSymbolList(operand_begin(), operand_end(),
149 type->getNumDynamicDims(), p);
150 // Print memref type.
151 *p << " : " << *type;
152}
153
154OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
155 MemRefType *type;
156 SmallVector<SSAValue *, 4> operands;
157 SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
158
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700159 // Parse the dimension operands and optional symbol operands, followed by a
160 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700161 unsigned numDimOperands;
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700162 if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
163 parser->parseColonType(type))
MLIR Team554a8ad2018-07-30 13:08:05 -0700164 return {};
165
166 // Check numDynamicDims against number of question marks in memref type.
167 if (numDimOperands != type->getNumDynamicDims()) {
168 parser->emitError(parser->getNameLoc(),
169 "Dynamic dimensions count mismatch: dimension operand "
170 "count does not equal memref dynamic dimension count.");
171 return {};
172 }
173
174 // Check that the number of symbol operands matches the number of symbols in
175 // the first affinemap of the memref's affine map composition.
176 // Note that a memref must specify at least one affine map in the composition.
177 if ((operandsInfo.size() - numDimOperands) !=
178 type->getAffineMaps()[0]->getNumSymbols()) {
179 parser->emitError(parser->getNameLoc(),
180 "AffineMap symbol count mismatch: symbol operand "
181 "count does not equal memref affine map symbol count.");
182 return {};
183 }
184
185 return OpAsmParserResult(operands, type);
186}
187
188const char *AllocOp::verify() const {
189 // TODO(andydavis): Verify alloc.
190 return nullptr;
191}
192
Chris Lattnerd4964212018-08-01 10:43:18 -0700193void ConstantOp::print(OpAsmPrinter *p) const {
194 *p << "constant " << *getValue() << " : " << *getType();
195}
196
197OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
198 Attribute *valueAttr;
199 Type *type;
200 if (parser->parseAttribute(valueAttr) || parser->parseColonType(type))
201 return {};
202
203 auto &builder = parser->getBuilder();
204 return OpAsmParserResult(
205 /*operands=*/{}, type,
206 NamedAttribute(builder.getIdentifier("value"), valueAttr));
207}
208
Chris Lattner9361fb32018-07-24 08:34:58 -0700209/// The constant op requires an attribute, and furthermore requires that it
210/// matches the return type.
211const char *ConstantOp::verify() const {
212 auto *value = getValue();
213 if (!value)
214 return "requires a 'value' attribute";
215
216 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700217 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700218 if (!isa<IntegerAttr>(value))
219 return "requires 'value' to be an integer for an integer result type";
220 return nullptr;
221 }
222
223 if (isa<FunctionType>(type)) {
224 // TODO: Verify a function attr.
225 }
226
227 return "requires a result type that aligns with the 'value' attribute";
228}
229
Chris Lattner3da86ad2018-07-26 09:58:23 -0700230/// ConstantIntOp only matches values whose result type is an IntegerType or
231/// AffineInt.
Chris Lattner9361fb32018-07-24 08:34:58 -0700232bool ConstantIntOp::isClassFor(const Operation *op) {
233 return ConstantOp::isClassFor(op) &&
Chris Lattner3da86ad2018-07-26 09:58:23 -0700234 (isa<IntegerType>(op->getResult(0)->getType()) ||
235 op->getResult(0)->getType()->isAffineInt());
Chris Lattner9361fb32018-07-24 08:34:58 -0700236}
237
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700238void DimOp::print(OpAsmPrinter *p) const {
239 *p << "dim " << *getOperand() << ", " << getIndex() << " : "
240 << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700241}
242
Chris Lattner85ee1512018-07-25 11:15:20 -0700243OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
244 OpAsmParser::OperandType operandInfo;
245 IntegerAttr *indexAttr;
246 Type *type;
247 SSAValue *operand;
248 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
249 parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
250 parser->resolveOperand(operandInfo, type, operand))
251 return {};
252
253 auto &builder = parser->getBuilder();
254 return OpAsmParserResult(
255 operand, builder.getAffineIntType(),
256 NamedAttribute(builder.getIdentifier("index"), indexAttr));
257}
258
Chris Lattner21e67f62018-07-06 10:46:19 -0700259const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700260 // Check that we have an integer index operand.
261 auto indexAttr = getAttrOfType<IntegerAttr>("index");
262 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700263 return "requires an integer attribute named 'index'";
264 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700265
Chris Lattner9361fb32018-07-24 08:34:58 -0700266 auto *type = getOperand()->getType();
267 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
268 if (index >= tensorType->getRank())
269 return "index is out of range";
270 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
271 if (index >= memrefType->getRank())
272 return "index is out of range";
273
274 } else if (isa<UnrankedTensorType>(type)) {
275 // ok, assumed to be in-range.
276 } else {
277 return "requires an operand with tensor or memref type";
278 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700279
280 return nullptr;
281}
282
Chris Lattner85ee1512018-07-25 11:15:20 -0700283void LoadOp::print(OpAsmPrinter *p) const {
284 *p << "load " << *getMemRef() << '[';
285 p->printOperands(getIndices());
286 *p << "] : " << *getMemRef()->getType();
287}
288
289OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
290 OpAsmParser::OperandType memrefInfo;
291 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
292 MemRefType *type;
293 SmallVector<SSAValue *, 4> operands;
294
295 auto affineIntTy = parser->getBuilder().getAffineIntType();
296 if (parser->parseOperand(memrefInfo) ||
297 parser->parseOperandList(indexInfo, -1,
298 OpAsmParser::Delimeter::SquareDelimeter) ||
299 parser->parseColonType(type) ||
300 parser->resolveOperands(memrefInfo, type, operands) ||
301 parser->resolveOperands(indexInfo, affineIntTy, operands))
302 return {};
303
304 return OpAsmParserResult(operands, type->getElementType());
305}
306
307const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700308 if (getNumOperands() == 0)
309 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700310
Chris Lattner3164ae62018-07-28 09:36:25 -0700311 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
312 if (!memRefType)
313 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700314
Chris Lattner3164ae62018-07-28 09:36:25 -0700315 for (auto *idx : getIndices())
316 if (!idx->getType()->isAffineInt())
317 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700318
Chris Lattner3164ae62018-07-28 09:36:25 -0700319 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700320
Chris Lattner3164ae62018-07-28 09:36:25 -0700321 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
322 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700323 return nullptr;
324}
325
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700326void StoreOp::print(OpAsmPrinter *p) const {
327 *p << "store " << *getValueToStore();
328 *p << ", " << *getMemRef() << '[';
329 p->printOperands(getIndices());
330 *p << "] : " << *getMemRef()->getType();
331}
332
333OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
334 OpAsmParser::OperandType storeValueInfo;
335 OpAsmParser::OperandType memrefInfo;
336 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
337 SmallVector<SSAValue *, 4> operands;
338 MemRefType *memrefType;
339
340 auto affineIntTy = parser->getBuilder().getAffineIntType();
341 if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
342 parser->parseOperand(memrefInfo) ||
343 parser->parseOperandList(indexInfo, -1,
344 OpAsmParser::Delimeter::SquareDelimeter) ||
345 parser->parseColonType(memrefType) ||
346 parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
347 operands) ||
348 parser->resolveOperands(memrefInfo, memrefType, operands) ||
349 parser->resolveOperands(indexInfo, affineIntTy, operands))
350 return {};
351
352 return OpAsmParserResult(operands, {});
353}
354
355const char *StoreOp::verify() const {
356 if (getNumOperands() < 2)
357 return "expected a value to store and a memref";
358
359 // Second operand is a memref type.
360 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
361 if (!memRefType)
362 return "second operand must be a memref";
363
364 // First operand must have same type as memref element type.
365 if (getValueToStore()->getType() != memRefType->getElementType())
366 return "first operand must have same type memref element type ";
367
368 if (getNumOperands() != 2 + memRefType->getRank())
369 return "store index operand count not equal to memref rank";
370
371 for (auto *idx : getIndices())
372 if (!idx->getType()->isAffineInt())
373 return "index to load must have 'affineint' type";
374
375 // TODO: Verify we have the right number of indices.
376
377 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
378 // result of an affine_apply.
379 return nullptr;
380}
381
Chris Lattnerff0d5902018-07-05 09:12:11 -0700382/// Install the standard operations in the specified operation set.
383void mlir::registerStandardOperations(OperationSet &opSet) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700384 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
385 StoreOp>(
386 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700387}