blob: 427de7ab845a7e62171e0458400545ab3779fcf8 [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070025#include "mlir/Support/STLExtras.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070026#include "llvm/Support/raw_ostream.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070027
Chris Lattnerff0d5902018-07-05 09:12:11 -070028using namespace mlir;
29
MLIR Team554a8ad2018-07-30 13:08:05 -070030static void printDimAndSymbolList(Operation::const_operand_iterator begin,
31 Operation::const_operand_iterator end,
32 unsigned numDims, OpAsmPrinter *p) {
33 *p << '(';
34 p->printOperands(begin, begin + numDims);
35 *p << ')';
36
37 if (begin + numDims != end) {
38 *p << '[';
39 p->printOperands(begin + numDims, end);
40 *p << ']';
41 }
42}
43
44// Parses dimension and symbol list, and sets 'numDims' to the number of
45// dimension operands parsed.
46// Returns 'false' on success and 'true' on error.
47static bool
48parseDimAndSymbolList(OpAsmParser *parser,
MLIR Team554a8ad2018-07-30 13:08:05 -070049 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -070050 SmallVector<OpAsmParser::OperandType, 8> opInfos;
Chris Lattner85cf26d2018-08-02 16:54:36 -070051 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070052 return true;
53 // Store number of dimensions for validation by caller.
54 numDims = opInfos.size();
55
56 // Parse the optional symbol operands.
57 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070058 if (parser->parseOperandList(opInfos, -1,
59 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070060 parser->resolveOperands(opInfos, affineIntTy, operands))
61 return true;
62 return false;
63}
64
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070065//===----------------------------------------------------------------------===//
66// AddFOp
67//===----------------------------------------------------------------------===//
68
Chris Lattner1eb77482018-08-22 19:25:49 -070069void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
70 SSAValue *rhs) {
71 assert(lhs->getType() == rhs->getType());
72 result->addOperands({lhs, rhs});
73 result->types.push_back(lhs->getType());
74}
75
Chris Lattnereed6c4d2018-08-07 09:12:35 -070076bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -070077 SmallVector<OpAsmParser::OperandType, 2> ops;
78 Type *type;
Chris Lattner8bdbebf2018-08-08 11:02:58 -070079 return parser->parseOperandList(ops, 2) ||
80 parser->parseOptionalAttributeDict(result->attributes) ||
81 parser->parseColonType(type) ||
82 parser->resolveOperands(ops, type, result->operands) ||
83 parser->addTypeToList(type, result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -070084}
85
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070086void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070087 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
88 p->printOptionalAttrDict(getAttrs());
89 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070090}
91
Chris Lattnereed6c4d2018-08-07 09:12:35 -070092// TODO: Have verify functions return std::string to enable more descriptive
93// error messages.
Chris Lattner21e67f62018-07-06 10:46:19 -070094// Return an error message on failure.
95const char *AddFOp::verify() const {
96 // TODO: Check that the types of the LHS and RHS match.
97 // TODO: This should be a refinement of TwoOperands.
98 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
99 return nullptr;
100}
101
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700102//===----------------------------------------------------------------------===//
103// AffineApplyOp
104//===----------------------------------------------------------------------===//
105
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700106bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner3164ae62018-07-28 09:36:25 -0700107 auto &builder = parser->getBuilder();
108 auto *affineIntTy = builder.getAffineIntType();
109
110 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -0700111 unsigned numDims;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700112 if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
113 parseDimAndSymbolList(parser, result->operands, numDims) ||
114 parser->parseOptionalAttributeDict(result->attributes))
115 return true;
Chris Lattner3164ae62018-07-28 09:36:25 -0700116 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700117
Chris Lattner3164ae62018-07-28 09:36:25 -0700118 if (map->getNumDims() != numDims ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700119 numDims + map->getNumSymbols() != result->operands.size()) {
120 return parser->emitError(parser->getNameLoc(),
121 "dimension or symbol index mismatch");
Chris Lattner3164ae62018-07-28 09:36:25 -0700122 }
123
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700124 result->types.append(map->getNumResults(), affineIntTy);
125 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700126}
127
128void AffineApplyOp::print(OpAsmPrinter *p) const {
129 auto *map = getAffineMap();
130 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700131 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700132 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700133}
134
135const char *AffineApplyOp::verify() const {
136 // Check that affine map attribute was specified.
137 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
138 if (!affineMapAttr)
Chris Lattner1aa46322018-08-21 17:55:22 -0700139 return "requires an affine map";
Chris Lattner3164ae62018-07-28 09:36:25 -0700140
141 // Check input and output dimensions match.
142 auto *map = affineMapAttr->getValue();
143
144 // Verify that operand count matches affine map dimension and symbol count.
145 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
146 return "operand count and affine map dimension and symbol count must match";
147
148 // Verify that result count matches affine map result count.
149 if (getNumResults() != map->getNumResults())
150 return "result count and affine map result count must match";
151
152 return nullptr;
153}
154
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700155// The result of the affine apply operation can be used as a dimension id if it
156// is a CFG value or if it is an MLValue, and all the operands are valid
157// dimension ids.
158bool AffineApplyOp::isValidDim() const {
159 for (auto *op : getOperands()) {
160 if (auto *v = dyn_cast<MLValue>(op))
161 if (!v->isValidDim())
162 return false;
163 }
164 return true;
165}
166
167// The result of the affine apply operation can be used as a symbol if it is
168// a CFG value or if it is an MLValue, and all the operands are symbols.
169bool AffineApplyOp::isValidSymbol() const {
170 for (auto *op : getOperands()) {
171 if (auto *v = dyn_cast<MLValue>(op))
172 if (!v->isValidSymbol())
173 return false;
174 }
175 return true;
176}
177
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700178//===----------------------------------------------------------------------===//
179// AllocOp
180//===----------------------------------------------------------------------===//
181
MLIR Team554a8ad2018-07-30 13:08:05 -0700182void AllocOp::print(OpAsmPrinter *p) const {
183 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
184 *p << "alloc";
185 // Print dynamic dimension operands.
186 printDimAndSymbolList(operand_begin(), operand_end(),
187 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700188 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700189 *p << " : " << *type;
190}
191
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700192bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Team554a8ad2018-07-30 13:08:05 -0700193 MemRefType *type;
MLIR Team554a8ad2018-07-30 13:08:05 -0700194
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700195 // Parse the dimension operands and optional symbol operands, followed by a
196 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700197 unsigned numDimOperands;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700198 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
199 parser->parseOptionalAttributeDict(result->attributes) ||
200 parser->parseColonType(type))
201 return true;
MLIR Team554a8ad2018-07-30 13:08:05 -0700202
203 // Check numDynamicDims against number of question marks in memref type.
204 if (numDimOperands != type->getNumDynamicDims()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700205 return parser->emitError(parser->getNameLoc(),
206 "dimension operand count does not equal memref "
207 "dynamic dimension count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700208 }
209
210 // Check that the number of symbol operands matches the number of symbols in
211 // the first affinemap of the memref's affine map composition.
212 // Note that a memref must specify at least one affine map in the composition.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700213 if (result->operands.size() - numDimOperands !=
MLIR Team554a8ad2018-07-30 13:08:05 -0700214 type->getAffineMaps()[0]->getNumSymbols()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700215 return parser->emitError(
216 parser->getNameLoc(),
217 "affine map symbol operand count does not equal memref affine map "
218 "symbol count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700219 }
220
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700221 result->types.push_back(type);
222 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700223}
224
225const char *AllocOp::verify() const {
226 // TODO(andydavis): Verify alloc.
227 return nullptr;
228}
229
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700230//===----------------------------------------------------------------------===//
Chris Lattner1aa46322018-08-21 17:55:22 -0700231// CallOp
232//===----------------------------------------------------------------------===//
233
Chris Lattner1eb77482018-08-22 19:25:49 -0700234void CallOp::build(Builder *builder, OperationState *result, Function *callee,
235 ArrayRef<SSAValue *> operands) {
236 result->addOperands(operands);
237 result->addAttribute("callee", builder->getFunctionAttr(callee));
238 result->addTypes(callee->getType()->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700239}
240
241bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
242 StringRef calleeName;
243 llvm::SMLoc calleeLoc;
244 FunctionType *calleeType = nullptr;
245 SmallVector<OpAsmParser::OperandType, 4> operands;
246 Function *callee = nullptr;
247 if (parser->parseFunctionName(calleeName, calleeLoc) ||
248 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
249 OpAsmParser::Delimiter::Paren) ||
250 parser->parseOptionalAttributeDict(result->attributes) ||
251 parser->parseColonType(calleeType) ||
252 parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
253 parser->addTypesToList(calleeType->getResults(), result->types) ||
254 parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
255 result->operands))
256 return true;
257
Chris Lattner1eb77482018-08-22 19:25:49 -0700258 result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
Chris Lattner1aa46322018-08-21 17:55:22 -0700259 return false;
260}
261
262void CallOp::print(OpAsmPrinter *p) const {
263 *p << "call ";
264 p->printFunctionReference(getCallee());
265 *p << '(';
266 p->printOperands(getOperands());
267 *p << ')';
268 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
269 *p << " : " << *getCallee()->getType();
270}
271
272const char *CallOp::verify() const {
273 // Check that the callee attribute was specified.
274 auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
275 if (!fnAttr)
276 return "requires a 'callee' function attribute";
277
278 // Verify that the operand and result types match the callee.
279 auto *fnType = fnAttr->getValue()->getType();
280 if (fnType->getNumInputs() != getNumOperands())
281 return "incorrect number of operands for callee";
282
283 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
284 if (getOperand(i)->getType() != fnType->getInput(i))
285 return "operand type mismatch";
286 }
287
288 if (fnType->getNumResults() != getNumResults())
289 return "incorrect number of results for callee";
290
291 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
292 if (getResult(i)->getType() != fnType->getResult(i))
293 return "result type mismatch";
294 }
295
296 return nullptr;
297}
298
299//===----------------------------------------------------------------------===//
300// CallIndirectOp
301//===----------------------------------------------------------------------===//
302
Chris Lattner1eb77482018-08-22 19:25:49 -0700303void CallIndirectOp::build(Builder *builder, OperationState *result,
304 SSAValue *callee, ArrayRef<SSAValue *> operands) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700305 auto *fnType = cast<FunctionType>(callee->getType());
Chris Lattner1eb77482018-08-22 19:25:49 -0700306 result->operands.push_back(callee);
307 result->addOperands(operands);
308 result->addTypes(fnType->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700309}
310
311bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
312 FunctionType *calleeType = nullptr;
313 OpAsmParser::OperandType callee;
314 llvm::SMLoc operandsLoc;
315 SmallVector<OpAsmParser::OperandType, 4> operands;
316 return parser->parseOperand(callee) ||
317 parser->getCurrentLocation(&operandsLoc) ||
318 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
319 OpAsmParser::Delimiter::Paren) ||
320 parser->parseOptionalAttributeDict(result->attributes) ||
321 parser->parseColonType(calleeType) ||
322 parser->resolveOperand(callee, calleeType, result->operands) ||
323 parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
324 result->operands) ||
325 parser->addTypesToList(calleeType->getResults(), result->types);
326}
327
328void CallIndirectOp::print(OpAsmPrinter *p) const {
329 *p << "call_indirect ";
330 p->printOperand(getCallee());
331 *p << '(';
332 auto operandRange = getOperands();
333 p->printOperands(++operandRange.begin(), operandRange.end());
334 *p << ')';
335 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
336 *p << " : " << *getCallee()->getType();
337}
338
339const char *CallIndirectOp::verify() const {
340 // The callee must be a function.
341 auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
342 if (!fnType)
343 return "callee must have function type";
344
345 // Verify that the operand and result types match the callee.
346 if (fnType->getNumInputs() != getNumOperands() - 1)
347 return "incorrect number of operands for callee";
348
349 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
350 if (getOperand(i + 1)->getType() != fnType->getInput(i))
351 return "operand type mismatch";
352 }
353
354 if (fnType->getNumResults() != getNumResults())
355 return "incorrect number of results for callee";
356
357 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
358 if (getResult(i)->getType() != fnType->getResult(i))
359 return "result type mismatch";
360 }
361
362 return nullptr;
363}
364
365//===----------------------------------------------------------------------===//
366// Constant*Op
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700367//===----------------------------------------------------------------------===//
368
Chris Lattnerd4964212018-08-01 10:43:18 -0700369void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700370 *p << "constant " << *getValue();
371 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
Chris Lattner4613d9e2018-08-19 21:17:22 -0700372
373 if (!isa<FunctionAttr>(getValue()))
374 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700375}
376
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700377bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattnerd4964212018-08-01 10:43:18 -0700378 Attribute *valueAttr;
379 Type *type;
Chris Lattnerd4964212018-08-01 10:43:18 -0700380
Chris Lattner4613d9e2018-08-19 21:17:22 -0700381 if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
382 parser->parseOptionalAttributeDict(result->attributes))
383 return true;
384
385 // 'constant' taking a function reference doesn't get a redundant type
386 // specifier. The attribute itself carries it.
387 if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
388 return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
389
390 return parser->parseColonType(type) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700391 parser->addTypeToList(type, result->types);
Chris Lattnerd4964212018-08-01 10:43:18 -0700392}
393
Chris Lattner9361fb32018-07-24 08:34:58 -0700394/// The constant op requires an attribute, and furthermore requires that it
395/// matches the return type.
396const char *ConstantOp::verify() const {
397 auto *value = getValue();
398 if (!value)
399 return "requires a 'value' attribute";
400
401 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700402 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700403 if (!isa<IntegerAttr>(value))
404 return "requires 'value' to be an integer for an integer result type";
405 return nullptr;
406 }
407
Chris Lattner7ba98c62018-08-16 16:56:40 -0700408 if (isa<FloatType>(type)) {
409 if (!isa<FloatAttr>(value))
410 return "requires 'value' to be a floating point constant";
411 return nullptr;
412 }
413
414 if (type->isTFString()) {
415 if (!isa<StringAttr>(value))
416 return "requires 'value' to be a string constant";
417 return nullptr;
418 }
419
Chris Lattner9361fb32018-07-24 08:34:58 -0700420 if (isa<FunctionType>(type)) {
Chris Lattner4613d9e2018-08-19 21:17:22 -0700421 if (!isa<FunctionAttr>(value))
422 return "requires 'value' to be a function reference";
423 return nullptr;
Chris Lattner9361fb32018-07-24 08:34:58 -0700424 }
425
426 return "requires a result type that aligns with the 'value' attribute";
427}
428
Chris Lattner1eb77482018-08-22 19:25:49 -0700429void ConstantFloatOp::build(Builder *builder, OperationState *result,
430 double value, FloatType *type) {
431 result->addAttribute("value", builder->getFloatAttr(value));
432 result->types.push_back(type);
Chris Lattner7ba98c62018-08-16 16:56:40 -0700433}
434
435bool ConstantFloatOp::isClassFor(const Operation *op) {
436 return ConstantOp::isClassFor(op) &&
437 isa<FloatType>(op->getResult(0)->getType());
438}
439
Chris Lattner992a1272018-08-07 12:02:37 -0700440/// ConstantIntOp only matches values whose result type is an IntegerType.
Chris Lattner9361fb32018-07-24 08:34:58 -0700441bool ConstantIntOp::isClassFor(const Operation *op) {
442 return ConstantOp::isClassFor(op) &&
Chris Lattner992a1272018-08-07 12:02:37 -0700443 isa<IntegerType>(op->getResult(0)->getType());
444}
445
Chris Lattner1eb77482018-08-22 19:25:49 -0700446void ConstantIntOp::build(Builder *builder, OperationState *result,
447 int64_t value, unsigned width) {
448 result->addAttribute("value", builder->getIntegerAttr(value));
449 result->types.push_back(builder->getIntegerType(width));
Chris Lattner992a1272018-08-07 12:02:37 -0700450}
451
452/// ConstantAffineIntOp only matches values whose result type is AffineInt.
453bool ConstantAffineIntOp::isClassFor(const Operation *op) {
454 return ConstantOp::isClassFor(op) &&
455 op->getResult(0)->getType()->isAffineInt();
456}
457
Chris Lattner1eb77482018-08-22 19:25:49 -0700458void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
459 int64_t value) {
460 result->addAttribute("value", builder->getIntegerAttr(value));
461 result->types.push_back(builder->getAffineIntType());
Chris Lattner9361fb32018-07-24 08:34:58 -0700462}
463
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700464//===----------------------------------------------------------------------===//
Uday Bondhugula67701712018-08-21 16:01:23 -0700465// AffineApplyOp
466//===----------------------------------------------------------------------===//
467
Chris Lattner1eb77482018-08-22 19:25:49 -0700468void AffineApplyOp::build(Builder *builder, OperationState *result,
469 AffineMap *map, ArrayRef<SSAValue *> operands) {
470 result->addOperands(operands);
471 result->types.append(map->getNumResults(), builder->getAffineIntType());
472 result->addAttribute("map", builder->getAffineMapAttr(map));
Uday Bondhugula67701712018-08-21 16:01:23 -0700473}
474
475//===----------------------------------------------------------------------===//
MLIR Team1989cc12018-08-15 15:39:26 -0700476// DeallocOp
477//===----------------------------------------------------------------------===//
478
479void DeallocOp::print(OpAsmPrinter *p) const {
480 *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
481}
482
483bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
484 OpAsmParser::OperandType memrefInfo;
485 MemRefType *type;
486
487 return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
488 parser->resolveOperand(memrefInfo, type, result->operands);
489}
490
491const char *DeallocOp::verify() const {
492 if (!isa<MemRefType>(getMemRef()->getType()))
493 return "operand must be a memref";
494 return nullptr;
495}
496
497//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700498// DimOp
499//===----------------------------------------------------------------------===//
500
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700501void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700502 *p << "dim " << *getOperand() << ", " << getIndex();
503 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
504 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700505}
506
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700507bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700508 OpAsmParser::OperandType operandInfo;
509 IntegerAttr *indexAttr;
510 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700511
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700512 return parser->parseOperand(operandInfo) || parser->parseComma() ||
513 parser->parseAttribute(indexAttr, "index", result->attributes) ||
514 parser->parseOptionalAttributeDict(result->attributes) ||
515 parser->parseColonType(type) ||
516 parser->resolveOperand(operandInfo, type, result->operands) ||
517 parser->addTypeToList(parser->getBuilder().getAffineIntType(),
518 result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700519}
520
Chris Lattner21e67f62018-07-06 10:46:19 -0700521const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700522 // Check that we have an integer index operand.
523 auto indexAttr = getAttrOfType<IntegerAttr>("index");
524 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700525 return "requires an integer attribute named 'index'";
526 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700527
Chris Lattner9361fb32018-07-24 08:34:58 -0700528 auto *type = getOperand()->getType();
529 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
530 if (index >= tensorType->getRank())
531 return "index is out of range";
532 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
533 if (index >= memrefType->getRank())
534 return "index is out of range";
535
536 } else if (isa<UnrankedTensorType>(type)) {
537 // ok, assumed to be in-range.
538 } else {
539 return "requires an operand with tensor or memref type";
540 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700541
542 return nullptr;
543}
544
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700545//===----------------------------------------------------------------------===//
Chris Lattner8c7feba2018-08-23 09:58:23 -0700546// ExtractElementOp
547//===----------------------------------------------------------------------===//
548
549void ExtractElementOp::build(Builder *builder, OperationState *result,
550 SSAValue *aggregate,
551 ArrayRef<SSAValue *> indices) {
552 auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
553 result->addOperands(aggregate);
554 result->addOperands(indices);
555 result->types.push_back(aggregateType->getElementType());
556}
557
558void ExtractElementOp::print(OpAsmPrinter *p) const {
559 *p << "extract_element " << *getAggregate() << '[';
560 p->printOperands(getIndices());
561 *p << ']';
562 p->printOptionalAttrDict(getAttrs());
563 *p << " : " << *getAggregate()->getType();
564}
565
566bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
567 OpAsmParser::OperandType aggregateInfo;
568 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
569 VectorOrTensorType *type;
570
571 auto affineIntTy = parser->getBuilder().getAffineIntType();
572 return parser->parseOperand(aggregateInfo) ||
573 parser->parseOperandList(indexInfo, -1,
574 OpAsmParser::Delimiter::Square) ||
575 parser->parseOptionalAttributeDict(result->attributes) ||
576 parser->parseColonType(type) ||
577 parser->resolveOperand(aggregateInfo, type, result->operands) ||
578 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
579 parser->addTypeToList(type->getElementType(), result->types);
580}
581
582const char *ExtractElementOp::verify() const {
583 if (getNumOperands() == 0)
584 return "expected an aggregate to index into";
585
586 auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
587 if (!aggregateType)
588 return "first operand must be a vector or tensor";
589
590 if (getResult()->getType() != aggregateType->getElementType())
591 return "result type must match element type of aggregate";
592
593 for (auto *idx : getIndices())
594 if (!idx->getType()->isAffineInt())
595 return "index to extract_element must have 'affineint' type";
596
597 // Verify the # indices match if we have a ranked type.
598 auto aggregateRank = aggregateType->getRankIfPresent();
599 if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
600 return "incorrect number of indices for extract_element";
601
602 return nullptr;
603}
604
605//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700606// LoadOp
607//===----------------------------------------------------------------------===//
608
Chris Lattner8c7feba2018-08-23 09:58:23 -0700609void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
610 ArrayRef<SSAValue *> indices) {
611 auto *memrefType = cast<MemRefType>(memref->getType());
612 result->addOperands(memref);
613 result->addOperands(indices);
614 result->types.push_back(memrefType->getElementType());
615}
616
Chris Lattner85ee1512018-07-25 11:15:20 -0700617void LoadOp::print(OpAsmPrinter *p) const {
618 *p << "load " << *getMemRef() << '[';
619 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700620 *p << ']';
621 p->printOptionalAttrDict(getAttrs());
622 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700623}
624
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700625bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700626 OpAsmParser::OperandType memrefInfo;
627 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
628 MemRefType *type;
Chris Lattner85ee1512018-07-25 11:15:20 -0700629
630 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700631 return parser->parseOperand(memrefInfo) ||
632 parser->parseOperandList(indexInfo, -1,
633 OpAsmParser::Delimiter::Square) ||
634 parser->parseOptionalAttributeDict(result->attributes) ||
635 parser->parseColonType(type) ||
636 parser->resolveOperand(memrefInfo, type, result->operands) ||
637 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
638 parser->addTypeToList(type->getElementType(), result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700639}
640
641const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700642 if (getNumOperands() == 0)
643 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700644
Chris Lattner3164ae62018-07-28 09:36:25 -0700645 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
646 if (!memRefType)
647 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700648
Chris Lattner8c7feba2018-08-23 09:58:23 -0700649 if (getResult()->getType() != memRefType->getElementType())
650 return "result type must match element type of memref";
651
652 if (memRefType->getRank() != getNumOperands() - 1)
653 return "incorrect number of indices for load";
654
Chris Lattner3164ae62018-07-28 09:36:25 -0700655 for (auto *idx : getIndices())
656 if (!idx->getType()->isAffineInt())
657 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700658
Chris Lattner3164ae62018-07-28 09:36:25 -0700659 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700660
Chris Lattner3164ae62018-07-28 09:36:25 -0700661 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
662 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700663 return nullptr;
664}
665
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700666//===----------------------------------------------------------------------===//
667// ReturnOp
668//===----------------------------------------------------------------------===//
669
670bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
671 SmallVector<OpAsmParser::OperandType, 2> opInfo;
672 SmallVector<Type *, 2> types;
Chris Lattner1aa46322018-08-21 17:55:22 -0700673 llvm::SMLoc loc;
674 return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700675 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
Chris Lattner1aa46322018-08-21 17:55:22 -0700676 parser->resolveOperands(opInfo, types, loc, result->operands);
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700677}
678
679void ReturnOp::print(OpAsmPrinter *p) const {
680 *p << "return";
681 if (getNumOperands() > 0) {
682 *p << " ";
683 p->printOperands(operand_begin(), operand_end());
684 *p << " : ";
685 interleave(operand_begin(), operand_end(),
686 [&](auto *e) { p->printType(e->getType()); },
687 [&]() { *p << ", "; });
688 }
689}
690
691const char *ReturnOp::verify() const {
692 // ReturnOp must be part of an ML function.
693 if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
Tatiana Shpeisman3abd6bd2018-08-16 20:19:44 -0700694 StmtBlock *block = stmt->getBlock();
695 if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700696 return "must be the last statement in the ML function";
697
698 // Return success. Checking that operand types match those in the function
699 // signature is performed in the ML function verifier.
700 return nullptr;
701 }
Tatiana Shpeismanb697ac12018-08-09 23:21:19 -0700702 return "cannot occur in a CFG function";
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700703}
704
705//===----------------------------------------------------------------------===//
706// StoreOp
707//===----------------------------------------------------------------------===//
708
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700709void StoreOp::print(OpAsmPrinter *p) const {
710 *p << "store " << *getValueToStore();
711 *p << ", " << *getMemRef() << '[';
712 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700713 *p << ']';
714 p->printOptionalAttrDict(getAttrs());
715 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700716}
717
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700718bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700719 OpAsmParser::OperandType storeValueInfo;
720 OpAsmParser::OperandType memrefInfo;
721 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700722 MemRefType *memrefType;
723
724 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700725 return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
726 parser->parseOperand(memrefInfo) ||
727 parser->parseOperandList(indexInfo, -1,
728 OpAsmParser::Delimiter::Square) ||
729 parser->parseOptionalAttributeDict(result->attributes) ||
730 parser->parseColonType(memrefType) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700731 parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
732 result->operands) ||
733 parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700734 parser->resolveOperands(indexInfo, affineIntTy, result->operands);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700735}
736
737const char *StoreOp::verify() const {
738 if (getNumOperands() < 2)
739 return "expected a value to store and a memref";
740
741 // Second operand is a memref type.
742 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
743 if (!memRefType)
744 return "second operand must be a memref";
745
746 // First operand must have same type as memref element type.
747 if (getValueToStore()->getType() != memRefType->getElementType())
748 return "first operand must have same type memref element type ";
749
750 if (getNumOperands() != 2 + memRefType->getRank())
751 return "store index operand count not equal to memref rank";
752
753 for (auto *idx : getIndices())
754 if (!idx->getType()->isAffineInt())
755 return "index to load must have 'affineint' type";
756
757 // TODO: Verify we have the right number of indices.
758
759 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
760 // result of an affine_apply.
761 return nullptr;
762}
763
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700764//===----------------------------------------------------------------------===//
765// Register operations.
766//===----------------------------------------------------------------------===//
767
Chris Lattnerff0d5902018-07-05 09:12:11 -0700768/// Install the standard operations in the specified operation set.
769void mlir::registerStandardOperations(OperationSet &opSet) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700770 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
Chris Lattner8c7feba2018-08-23 09:58:23 -0700771 ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
772 ReturnOp, StoreOp>(
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700773 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700774}