blob: 95d7931d978b6693263fef07c1971452390aafa0 [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team554a8ad2018-07-30 13:08:05 -070028static void printDimAndSymbolList(Operation::const_operand_iterator begin,
29 Operation::const_operand_iterator end,
30 unsigned numDims, OpAsmPrinter *p) {
31 *p << '(';
32 p->printOperands(begin, begin + numDims);
33 *p << ')';
34
35 if (begin + numDims != end) {
36 *p << '[';
37 p->printOperands(begin + numDims, end);
38 *p << ']';
39 }
40}
41
42// Parses dimension and symbol list, and sets 'numDims' to the number of
43// dimension operands parsed.
44// Returns 'false' on success and 'true' on error.
45static bool
46parseDimAndSymbolList(OpAsmParser *parser,
MLIR Team554a8ad2018-07-30 13:08:05 -070047 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -070048 SmallVector<OpAsmParser::OperandType, 8> opInfos;
Chris Lattner85cf26d2018-08-02 16:54:36 -070049 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070050 return true;
51 // Store number of dimensions for validation by caller.
52 numDims = opInfos.size();
53
54 // Parse the optional symbol operands.
55 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070056 if (parser->parseOperandList(opInfos, -1,
57 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070058 parser->resolveOperands(opInfos, affineIntTy, operands))
59 return true;
60 return false;
61}
62
Chris Lattnereed6c4d2018-08-07 09:12:35 -070063bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -070064 SmallVector<OpAsmParser::OperandType, 2> ops;
65 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -070066 if (parser->parseOperandList(ops, 2) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -070067 parser->parseOptionalAttributeDict(result->attributes) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -070068 parser->parseColonType(type) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -070069 parser->resolveOperands(ops, type, result->operands))
70 return true;
Chris Lattner85ee1512018-07-25 11:15:20 -070071
Chris Lattnereed6c4d2018-08-07 09:12:35 -070072 // TODO(clattner): rework parseColonType to eliminate the need for this.
73 result->types.push_back(type);
74 return false;
Chris Lattner85ee1512018-07-25 11:15:20 -070075}
76
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070077void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070078 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
79 p->printOptionalAttrDict(getAttrs());
80 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070081}
82
Chris Lattnereed6c4d2018-08-07 09:12:35 -070083// TODO: Have verify functions return std::string to enable more descriptive
84// error messages.
Chris Lattner21e67f62018-07-06 10:46:19 -070085// Return an error message on failure.
86const char *AddFOp::verify() const {
87 // TODO: Check that the types of the LHS and RHS match.
88 // TODO: This should be a refinement of TwoOperands.
89 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
90 return nullptr;
91}
92
Chris Lattnereed6c4d2018-08-07 09:12:35 -070093bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner3164ae62018-07-28 09:36:25 -070094 auto &builder = parser->getBuilder();
95 auto *affineIntTy = builder.getAffineIntType();
96
97 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -070098 unsigned numDims;
Chris Lattnereed6c4d2018-08-07 09:12:35 -070099 if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
100 parseDimAndSymbolList(parser, result->operands, numDims) ||
101 parser->parseOptionalAttributeDict(result->attributes))
102 return true;
Chris Lattner3164ae62018-07-28 09:36:25 -0700103 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700104
Chris Lattner3164ae62018-07-28 09:36:25 -0700105 if (map->getNumDims() != numDims ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700106 numDims + map->getNumSymbols() != result->operands.size()) {
107 return parser->emitError(parser->getNameLoc(),
108 "dimension or symbol index mismatch");
Chris Lattner3164ae62018-07-28 09:36:25 -0700109 }
110
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700111 result->types.append(map->getNumResults(), affineIntTy);
112 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700113}
114
115void AffineApplyOp::print(OpAsmPrinter *p) const {
116 auto *map = getAffineMap();
117 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700118 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700119 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700120}
121
122const char *AffineApplyOp::verify() const {
123 // Check that affine map attribute was specified.
124 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
125 if (!affineMapAttr)
126 return "requires an affine map.";
127
128 // Check input and output dimensions match.
129 auto *map = affineMapAttr->getValue();
130
131 // Verify that operand count matches affine map dimension and symbol count.
132 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
133 return "operand count and affine map dimension and symbol count must match";
134
135 // Verify that result count matches affine map result count.
136 if (getNumResults() != map->getNumResults())
137 return "result count and affine map result count must match";
138
139 return nullptr;
140}
141
MLIR Team554a8ad2018-07-30 13:08:05 -0700142void AllocOp::print(OpAsmPrinter *p) const {
143 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
144 *p << "alloc";
145 // Print dynamic dimension operands.
146 printDimAndSymbolList(operand_begin(), operand_end(),
147 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700148 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700149 *p << " : " << *type;
150}
151
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700152bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Team554a8ad2018-07-30 13:08:05 -0700153 MemRefType *type;
MLIR Team554a8ad2018-07-30 13:08:05 -0700154
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700155 // Parse the dimension operands and optional symbol operands, followed by a
156 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700157 unsigned numDimOperands;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700158 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
159 parser->parseOptionalAttributeDict(result->attributes) ||
160 parser->parseColonType(type))
161 return true;
MLIR Team554a8ad2018-07-30 13:08:05 -0700162
163 // Check numDynamicDims against number of question marks in memref type.
164 if (numDimOperands != type->getNumDynamicDims()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700165 return parser->emitError(parser->getNameLoc(),
166 "dimension operand count does not equal memref "
167 "dynamic dimension count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700168 }
169
170 // Check that the number of symbol operands matches the number of symbols in
171 // the first affinemap of the memref's affine map composition.
172 // Note that a memref must specify at least one affine map in the composition.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700173 if (result->operands.size() - numDimOperands !=
MLIR Team554a8ad2018-07-30 13:08:05 -0700174 type->getAffineMaps()[0]->getNumSymbols()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700175 return parser->emitError(
176 parser->getNameLoc(),
177 "affine map symbol operand count does not equal memref affine map "
178 "symbol count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700179 }
180
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700181 result->types.push_back(type);
182 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700183}
184
185const char *AllocOp::verify() const {
186 // TODO(andydavis): Verify alloc.
187 return nullptr;
188}
189
Chris Lattnerd4964212018-08-01 10:43:18 -0700190void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700191 *p << "constant " << *getValue();
192 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
193 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700194}
195
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700196bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattnerd4964212018-08-01 10:43:18 -0700197 Attribute *valueAttr;
198 Type *type;
Chris Lattnerd4964212018-08-01 10:43:18 -0700199
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700200 if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
201 parser->parseOptionalAttributeDict(result->attributes) ||
202 parser->parseColonType(type))
203 return true;
204
205 result->types.push_back(type);
206 return false;
Chris Lattnerd4964212018-08-01 10:43:18 -0700207}
208
Chris Lattner9361fb32018-07-24 08:34:58 -0700209/// The constant op requires an attribute, and furthermore requires that it
210/// matches the return type.
211const char *ConstantOp::verify() const {
212 auto *value = getValue();
213 if (!value)
214 return "requires a 'value' attribute";
215
216 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700217 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700218 if (!isa<IntegerAttr>(value))
219 return "requires 'value' to be an integer for an integer result type";
220 return nullptr;
221 }
222
223 if (isa<FunctionType>(type)) {
224 // TODO: Verify a function attr.
225 }
226
227 return "requires a result type that aligns with the 'value' attribute";
228}
229
Chris Lattner992a1272018-08-07 12:02:37 -0700230/// ConstantIntOp only matches values whose result type is an IntegerType.
Chris Lattner9361fb32018-07-24 08:34:58 -0700231bool ConstantIntOp::isClassFor(const Operation *op) {
232 return ConstantOp::isClassFor(op) &&
Chris Lattner992a1272018-08-07 12:02:37 -0700233 isa<IntegerType>(op->getResult(0)->getType());
234}
235
236OperationState ConstantIntOp::build(Builder *builder, int64_t value,
237 unsigned width) {
238 OperationState result(builder->getIdentifier("constant"));
239 result.attributes.push_back(
240 {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
241 result.types.push_back(builder->getIntegerType(width));
242 return result;
243}
244
245/// ConstantAffineIntOp only matches values whose result type is AffineInt.
246bool ConstantAffineIntOp::isClassFor(const Operation *op) {
247 return ConstantOp::isClassFor(op) &&
248 op->getResult(0)->getType()->isAffineInt();
249}
250
251OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
252 OperationState result(builder->getIdentifier("constant"));
253 result.attributes.push_back(
254 {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
255 result.types.push_back(builder->getAffineIntType());
256 return result;
Chris Lattner9361fb32018-07-24 08:34:58 -0700257}
258
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700259void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700260 *p << "dim " << *getOperand() << ", " << getIndex();
261 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
262 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700263}
264
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700265bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700266 OpAsmParser::OperandType operandInfo;
267 IntegerAttr *indexAttr;
268 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700269
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700270 // TODO(clattner): remove resolveOperand or change it to push onto the
271 // operands list.
Chris Lattner85ee1512018-07-25 11:15:20 -0700272 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700273 parser->parseAttribute(indexAttr, "index", result->attributes) ||
274 parser->parseOptionalAttributeDict(result->attributes) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700275 parser->parseColonType(type) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700276 parser->resolveOperands(operandInfo, type, result->operands))
277 return true;
Chris Lattner85ee1512018-07-25 11:15:20 -0700278
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700279 result->types.push_back(parser->getBuilder().getAffineIntType());
280 return false;
Chris Lattner85ee1512018-07-25 11:15:20 -0700281}
282
Chris Lattner21e67f62018-07-06 10:46:19 -0700283const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700284 // Check that we have an integer index operand.
285 auto indexAttr = getAttrOfType<IntegerAttr>("index");
286 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700287 return "requires an integer attribute named 'index'";
288 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700289
Chris Lattner9361fb32018-07-24 08:34:58 -0700290 auto *type = getOperand()->getType();
291 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
292 if (index >= tensorType->getRank())
293 return "index is out of range";
294 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
295 if (index >= memrefType->getRank())
296 return "index is out of range";
297
298 } else if (isa<UnrankedTensorType>(type)) {
299 // ok, assumed to be in-range.
300 } else {
301 return "requires an operand with tensor or memref type";
302 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700303
304 return nullptr;
305}
306
Chris Lattner85ee1512018-07-25 11:15:20 -0700307void LoadOp::print(OpAsmPrinter *p) const {
308 *p << "load " << *getMemRef() << '[';
309 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700310 *p << ']';
311 p->printOptionalAttrDict(getAttrs());
312 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700313}
314
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700315bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700316 OpAsmParser::OperandType memrefInfo;
317 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
318 MemRefType *type;
Chris Lattner85ee1512018-07-25 11:15:20 -0700319
320 auto affineIntTy = parser->getBuilder().getAffineIntType();
321 if (parser->parseOperand(memrefInfo) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700322 parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700323 parser->parseOptionalAttributeDict(result->attributes) ||
Chris Lattner85ee1512018-07-25 11:15:20 -0700324 parser->parseColonType(type) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700325 // TODO: use a new resolveOperand()
326 parser->resolveOperands(memrefInfo, type, result->operands) ||
327 parser->resolveOperands(indexInfo, affineIntTy, result->operands))
328 return true;
Chris Lattner85ee1512018-07-25 11:15:20 -0700329
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700330 result->types.push_back(type->getElementType());
331 return false;
Chris Lattner85ee1512018-07-25 11:15:20 -0700332}
333
334const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700335 if (getNumOperands() == 0)
336 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700337
Chris Lattner3164ae62018-07-28 09:36:25 -0700338 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
339 if (!memRefType)
340 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700341
Chris Lattner3164ae62018-07-28 09:36:25 -0700342 for (auto *idx : getIndices())
343 if (!idx->getType()->isAffineInt())
344 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700345
Chris Lattner3164ae62018-07-28 09:36:25 -0700346 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700347
Chris Lattner3164ae62018-07-28 09:36:25 -0700348 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
349 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700350 return nullptr;
351}
352
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700353void StoreOp::print(OpAsmPrinter *p) const {
354 *p << "store " << *getValueToStore();
355 *p << ", " << *getMemRef() << '[';
356 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700357 *p << ']';
358 p->printOptionalAttrDict(getAttrs());
359 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700360}
361
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700362bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700363 OpAsmParser::OperandType storeValueInfo;
364 OpAsmParser::OperandType memrefInfo;
365 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700366 MemRefType *memrefType;
367
368 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700369 return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
370 parser->parseOperand(memrefInfo) ||
371 parser->parseOperandList(indexInfo, -1,
372 OpAsmParser::Delimiter::Square) ||
373 parser->parseOptionalAttributeDict(result->attributes) ||
374 parser->parseColonType(memrefType) ||
375 parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
376 result->operands) ||
377 // TODO: use a new resolveOperand().
378 parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
379 parser->resolveOperands(indexInfo, affineIntTy, result->operands);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700380}
381
382const char *StoreOp::verify() const {
383 if (getNumOperands() < 2)
384 return "expected a value to store and a memref";
385
386 // Second operand is a memref type.
387 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
388 if (!memRefType)
389 return "second operand must be a memref";
390
391 // First operand must have same type as memref element type.
392 if (getValueToStore()->getType() != memRefType->getElementType())
393 return "first operand must have same type memref element type ";
394
395 if (getNumOperands() != 2 + memRefType->getRank())
396 return "store index operand count not equal to memref rank";
397
398 for (auto *idx : getIndices())
399 if (!idx->getType()->isAffineInt())
400 return "index to load must have 'affineint' type";
401
402 // TODO: Verify we have the right number of indices.
403
404 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
405 // result of an affine_apply.
406 return nullptr;
407}
408
Chris Lattnerff0d5902018-07-05 09:12:11 -0700409/// Install the standard operations in the specified operation set.
410void mlir::registerStandardOperations(OperationSet &opSet) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700411 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
412 StoreOp>(
413 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700414}