blob: 1458ce4f53a0f4ccaf93f99faac049633366e370 [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team554a8ad2018-07-30 13:08:05 -070028static void printDimAndSymbolList(Operation::const_operand_iterator begin,
29 Operation::const_operand_iterator end,
30 unsigned numDims, OpAsmPrinter *p) {
31 *p << '(';
32 p->printOperands(begin, begin + numDims);
33 *p << ')';
34
35 if (begin + numDims != end) {
36 *p << '[';
37 p->printOperands(begin + numDims, end);
38 *p << ']';
39 }
40}
41
42// Parses dimension and symbol list, and sets 'numDims' to the number of
43// dimension operands parsed.
44// Returns 'false' on success and 'true' on error.
45static bool
46parseDimAndSymbolList(OpAsmParser *parser,
47 SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
48 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattner85cf26d2018-08-02 16:54:36 -070049 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070050 return true;
51 // Store number of dimensions for validation by caller.
52 numDims = opInfos.size();
53
54 // Parse the optional symbol operands.
55 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070056 if (parser->parseOperandList(opInfos, -1,
57 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070058 parser->resolveOperands(opInfos, affineIntTy, operands))
59 return true;
60 return false;
61}
62
MLIR Team39a3a602018-07-24 17:43:56 -070063// TODO: Have verify functions return std::string to enable more descriptive
64// error messages.
Chris Lattner85ee1512018-07-25 11:15:20 -070065OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
66 SmallVector<OpAsmParser::OperandType, 2> ops;
67 Type *type;
68 SSAValue *lhs, *rhs;
Chris Lattner85cf26d2018-08-02 16:54:36 -070069 SmallVector<NamedAttribute, 4> attrs;
70 if (parser->parseOperandList(ops, 2) ||
71 parser->parseOptionalAttributeDict(attrs) ||
72 parser->parseColonType(type) ||
Chris Lattner85ee1512018-07-25 11:15:20 -070073 parser->resolveOperand(ops[0], type, lhs) ||
74 parser->resolveOperand(ops[1], type, rhs))
75 return {};
76
Chris Lattner85cf26d2018-08-02 16:54:36 -070077 return OpAsmParserResult({lhs, rhs}, type, attrs);
Chris Lattner85ee1512018-07-25 11:15:20 -070078}
79
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070080void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070081 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
82 p->printOptionalAttrDict(getAttrs());
83 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070084}
85
Chris Lattner21e67f62018-07-06 10:46:19 -070086// Return an error message on failure.
87const char *AddFOp::verify() const {
88 // TODO: Check that the types of the LHS and RHS match.
89 // TODO: This should be a refinement of TwoOperands.
90 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
91 return nullptr;
92}
93
Chris Lattner3164ae62018-07-28 09:36:25 -070094OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
95 SmallVector<OpAsmParser::OperandType, 2> opInfos;
96 SmallVector<SSAValue *, 4> operands;
Chris Lattner85cf26d2018-08-02 16:54:36 -070097 SmallVector<NamedAttribute, 4> attrs;
Chris Lattner3164ae62018-07-28 09:36:25 -070098
99 auto &builder = parser->getBuilder();
100 auto *affineIntTy = builder.getAffineIntType();
101
102 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -0700103 unsigned numDims;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700104 if (parser->parseAttribute(mapAttr, "map", attrs) ||
105 parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
106 parser->parseOptionalAttributeDict(attrs))
MLIR Team554a8ad2018-07-30 13:08:05 -0700107 return {};
Chris Lattner3164ae62018-07-28 09:36:25 -0700108 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700109
Chris Lattner3164ae62018-07-28 09:36:25 -0700110 if (map->getNumDims() != numDims ||
111 numDims + map->getNumSymbols() != opInfos.size()) {
112 parser->emitError(parser->getNameLoc(),
113 "dimension or symbol index mismatch");
114 return {};
115 }
116
117 SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700118 return OpAsmParserResult(operands, resultTypes, attrs);
Chris Lattner3164ae62018-07-28 09:36:25 -0700119}
120
121void AffineApplyOp::print(OpAsmPrinter *p) const {
122 auto *map = getAffineMap();
123 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700124 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700125 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700126}
127
128const char *AffineApplyOp::verify() const {
129 // Check that affine map attribute was specified.
130 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
131 if (!affineMapAttr)
132 return "requires an affine map.";
133
134 // Check input and output dimensions match.
135 auto *map = affineMapAttr->getValue();
136
137 // Verify that operand count matches affine map dimension and symbol count.
138 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
139 return "operand count and affine map dimension and symbol count must match";
140
141 // Verify that result count matches affine map result count.
142 if (getNumResults() != map->getNumResults())
143 return "result count and affine map result count must match";
144
145 return nullptr;
146}
147
MLIR Team554a8ad2018-07-30 13:08:05 -0700148void AllocOp::print(OpAsmPrinter *p) const {
149 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
150 *p << "alloc";
151 // Print dynamic dimension operands.
152 printDimAndSymbolList(operand_begin(), operand_end(),
153 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700154 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700155 *p << " : " << *type;
156}
157
158OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
159 MemRefType *type;
160 SmallVector<SSAValue *, 4> operands;
161 SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700162 SmallVector<NamedAttribute, 4> attrs;
MLIR Team554a8ad2018-07-30 13:08:05 -0700163
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700164 // Parse the dimension operands and optional symbol operands, followed by a
165 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700166 unsigned numDimOperands;
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700167 if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700168 parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
MLIR Team554a8ad2018-07-30 13:08:05 -0700169 return {};
170
171 // Check numDynamicDims against number of question marks in memref type.
172 if (numDimOperands != type->getNumDynamicDims()) {
173 parser->emitError(parser->getNameLoc(),
174 "Dynamic dimensions count mismatch: dimension operand "
175 "count does not equal memref dynamic dimension count.");
176 return {};
177 }
178
179 // Check that the number of symbol operands matches the number of symbols in
180 // the first affinemap of the memref's affine map composition.
181 // Note that a memref must specify at least one affine map in the composition.
182 if ((operandsInfo.size() - numDimOperands) !=
183 type->getAffineMaps()[0]->getNumSymbols()) {
184 parser->emitError(parser->getNameLoc(),
185 "AffineMap symbol count mismatch: symbol operand "
186 "count does not equal memref affine map symbol count.");
187 return {};
188 }
189
Chris Lattner85cf26d2018-08-02 16:54:36 -0700190 return OpAsmParserResult(operands, type, attrs);
MLIR Team554a8ad2018-07-30 13:08:05 -0700191}
192
193const char *AllocOp::verify() const {
194 // TODO(andydavis): Verify alloc.
195 return nullptr;
196}
197
Chris Lattnerd4964212018-08-01 10:43:18 -0700198void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700199 *p << "constant " << *getValue();
200 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
201 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700202}
203
204OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
205 Attribute *valueAttr;
206 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700207 SmallVector<NamedAttribute, 4> attrs;
Chris Lattnerd4964212018-08-01 10:43:18 -0700208
Chris Lattner85cf26d2018-08-02 16:54:36 -0700209 if (parser->parseAttribute(valueAttr, "value", attrs) ||
210 parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
211 return {};
212 return OpAsmParserResult(/*operands=*/{}, type, attrs);
Chris Lattnerd4964212018-08-01 10:43:18 -0700213}
214
Chris Lattner9361fb32018-07-24 08:34:58 -0700215/// The constant op requires an attribute, and furthermore requires that it
216/// matches the return type.
217const char *ConstantOp::verify() const {
218 auto *value = getValue();
219 if (!value)
220 return "requires a 'value' attribute";
221
222 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700223 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700224 if (!isa<IntegerAttr>(value))
225 return "requires 'value' to be an integer for an integer result type";
226 return nullptr;
227 }
228
229 if (isa<FunctionType>(type)) {
230 // TODO: Verify a function attr.
231 }
232
233 return "requires a result type that aligns with the 'value' attribute";
234}
235
Chris Lattner3da86ad2018-07-26 09:58:23 -0700236/// ConstantIntOp only matches values whose result type is an IntegerType or
237/// AffineInt.
Chris Lattner9361fb32018-07-24 08:34:58 -0700238bool ConstantIntOp::isClassFor(const Operation *op) {
239 return ConstantOp::isClassFor(op) &&
Chris Lattner3da86ad2018-07-26 09:58:23 -0700240 (isa<IntegerType>(op->getResult(0)->getType()) ||
241 op->getResult(0)->getType()->isAffineInt());
Chris Lattner9361fb32018-07-24 08:34:58 -0700242}
243
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700244void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700245 *p << "dim " << *getOperand() << ", " << getIndex();
246 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
247 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700248}
249
Chris Lattner85ee1512018-07-25 11:15:20 -0700250OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
251 OpAsmParser::OperandType operandInfo;
252 IntegerAttr *indexAttr;
253 Type *type;
254 SSAValue *operand;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700255 SmallVector<NamedAttribute, 4> attrs;
256
Chris Lattner85ee1512018-07-25 11:15:20 -0700257 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700258 parser->parseAttribute(indexAttr, "index", attrs) ||
259 parser->parseOptionalAttributeDict(attrs) ||
260 parser->parseColonType(type) ||
Chris Lattner85ee1512018-07-25 11:15:20 -0700261 parser->resolveOperand(operandInfo, type, operand))
262 return {};
263
264 auto &builder = parser->getBuilder();
Chris Lattner85cf26d2018-08-02 16:54:36 -0700265 return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
Chris Lattner85ee1512018-07-25 11:15:20 -0700266}
267
Chris Lattner21e67f62018-07-06 10:46:19 -0700268const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700269 // Check that we have an integer index operand.
270 auto indexAttr = getAttrOfType<IntegerAttr>("index");
271 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700272 return "requires an integer attribute named 'index'";
273 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700274
Chris Lattner9361fb32018-07-24 08:34:58 -0700275 auto *type = getOperand()->getType();
276 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
277 if (index >= tensorType->getRank())
278 return "index is out of range";
279 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
280 if (index >= memrefType->getRank())
281 return "index is out of range";
282
283 } else if (isa<UnrankedTensorType>(type)) {
284 // ok, assumed to be in-range.
285 } else {
286 return "requires an operand with tensor or memref type";
287 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700288
289 return nullptr;
290}
291
Chris Lattner85ee1512018-07-25 11:15:20 -0700292void LoadOp::print(OpAsmPrinter *p) const {
293 *p << "load " << *getMemRef() << '[';
294 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700295 *p << ']';
296 p->printOptionalAttrDict(getAttrs());
297 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700298}
299
300OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
301 OpAsmParser::OperandType memrefInfo;
302 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
303 MemRefType *type;
304 SmallVector<SSAValue *, 4> operands;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700305 SmallVector<NamedAttribute, 4> attrs;
Chris Lattner85ee1512018-07-25 11:15:20 -0700306
307 auto affineIntTy = parser->getBuilder().getAffineIntType();
308 if (parser->parseOperand(memrefInfo) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700309 parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
310 parser->parseOptionalAttributeDict(attrs) ||
Chris Lattner85ee1512018-07-25 11:15:20 -0700311 parser->parseColonType(type) ||
312 parser->resolveOperands(memrefInfo, type, operands) ||
313 parser->resolveOperands(indexInfo, affineIntTy, operands))
314 return {};
315
Chris Lattner85cf26d2018-08-02 16:54:36 -0700316 return OpAsmParserResult(operands, type->getElementType(), attrs);
Chris Lattner85ee1512018-07-25 11:15:20 -0700317}
318
319const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700320 if (getNumOperands() == 0)
321 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700322
Chris Lattner3164ae62018-07-28 09:36:25 -0700323 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
324 if (!memRefType)
325 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700326
Chris Lattner3164ae62018-07-28 09:36:25 -0700327 for (auto *idx : getIndices())
328 if (!idx->getType()->isAffineInt())
329 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700330
Chris Lattner3164ae62018-07-28 09:36:25 -0700331 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700332
Chris Lattner3164ae62018-07-28 09:36:25 -0700333 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
334 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700335 return nullptr;
336}
337
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700338void StoreOp::print(OpAsmPrinter *p) const {
339 *p << "store " << *getValueToStore();
340 *p << ", " << *getMemRef() << '[';
341 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700342 *p << ']';
343 p->printOptionalAttrDict(getAttrs());
344 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700345}
346
347OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
348 OpAsmParser::OperandType storeValueInfo;
349 OpAsmParser::OperandType memrefInfo;
350 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
351 SmallVector<SSAValue *, 4> operands;
352 MemRefType *memrefType;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700353 SmallVector<NamedAttribute, 4> attrs;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700354
355 auto affineIntTy = parser->getBuilder().getAffineIntType();
356 if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
357 parser->parseOperand(memrefInfo) ||
Chris Lattner85cf26d2018-08-02 16:54:36 -0700358 parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
359 parser->parseOptionalAttributeDict(attrs) ||
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700360 parser->parseColonType(memrefType) ||
361 parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
362 operands) ||
363 parser->resolveOperands(memrefInfo, memrefType, operands) ||
364 parser->resolveOperands(indexInfo, affineIntTy, operands))
365 return {};
366
Chris Lattner85cf26d2018-08-02 16:54:36 -0700367 return OpAsmParserResult(operands, {}, attrs);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700368}
369
370const char *StoreOp::verify() const {
371 if (getNumOperands() < 2)
372 return "expected a value to store and a memref";
373
374 // Second operand is a memref type.
375 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
376 if (!memRefType)
377 return "second operand must be a memref";
378
379 // First operand must have same type as memref element type.
380 if (getValueToStore()->getType() != memRefType->getElementType())
381 return "first operand must have same type memref element type ";
382
383 if (getNumOperands() != 2 + memRefType->getRank())
384 return "store index operand count not equal to memref rank";
385
386 for (auto *idx : getIndices())
387 if (!idx->getType()->isAffineInt())
388 return "index to load must have 'affineint' type";
389
390 // TODO: Verify we have the right number of indices.
391
392 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
393 // result of an affine_apply.
394 return nullptr;
395}
396
Chris Lattnerff0d5902018-07-05 09:12:11 -0700397/// Install the standard operations in the specified operation set.
398void mlir::registerStandardOperations(OperationSet &opSet) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700399 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
400 StoreOp>(
401 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700402}