blob: 38235f04cdee88f5f9ba9da5d04072a325018330 [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070025#include "mlir/Support/STLExtras.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070026#include "llvm/Support/raw_ostream.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070027
Chris Lattnerff0d5902018-07-05 09:12:11 -070028using namespace mlir;
29
MLIR Team554a8ad2018-07-30 13:08:05 -070030static void printDimAndSymbolList(Operation::const_operand_iterator begin,
31 Operation::const_operand_iterator end,
32 unsigned numDims, OpAsmPrinter *p) {
33 *p << '(';
34 p->printOperands(begin, begin + numDims);
35 *p << ')';
36
37 if (begin + numDims != end) {
38 *p << '[';
39 p->printOperands(begin + numDims, end);
40 *p << ']';
41 }
42}
43
44// Parses dimension and symbol list, and sets 'numDims' to the number of
45// dimension operands parsed.
46// Returns 'false' on success and 'true' on error.
47static bool
48parseDimAndSymbolList(OpAsmParser *parser,
MLIR Team554a8ad2018-07-30 13:08:05 -070049 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -070050 SmallVector<OpAsmParser::OperandType, 8> opInfos;
Chris Lattner85cf26d2018-08-02 16:54:36 -070051 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070052 return true;
53 // Store number of dimensions for validation by caller.
54 numDims = opInfos.size();
55
56 // Parse the optional symbol operands.
57 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070058 if (parser->parseOperandList(opInfos, -1,
59 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070060 parser->resolveOperands(opInfos, affineIntTy, operands))
61 return true;
62 return false;
63}
64
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070065//===----------------------------------------------------------------------===//
66// AddFOp
67//===----------------------------------------------------------------------===//
68
Chris Lattner1eb77482018-08-22 19:25:49 -070069void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
70 SSAValue *rhs) {
71 assert(lhs->getType() == rhs->getType());
72 result->addOperands({lhs, rhs});
73 result->types.push_back(lhs->getType());
74}
75
Chris Lattnereed6c4d2018-08-07 09:12:35 -070076bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -070077 SmallVector<OpAsmParser::OperandType, 2> ops;
78 Type *type;
Chris Lattner8bdbebf2018-08-08 11:02:58 -070079 return parser->parseOperandList(ops, 2) ||
80 parser->parseOptionalAttributeDict(result->attributes) ||
81 parser->parseColonType(type) ||
82 parser->resolveOperands(ops, type, result->operands) ||
83 parser->addTypeToList(type, result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -070084}
85
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070086void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070087 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
88 p->printOptionalAttrDict(getAttrs());
89 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070090}
91
Chris Lattnereed6c4d2018-08-07 09:12:35 -070092// TODO: Have verify functions return std::string to enable more descriptive
93// error messages.
Chris Lattner21e67f62018-07-06 10:46:19 -070094// Return an error message on failure.
95const char *AddFOp::verify() const {
96 // TODO: Check that the types of the LHS and RHS match.
97 // TODO: This should be a refinement of TwoOperands.
98 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
99 return nullptr;
100}
101
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700102//===----------------------------------------------------------------------===//
103// AffineApplyOp
104//===----------------------------------------------------------------------===//
105
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700106bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner3164ae62018-07-28 09:36:25 -0700107 auto &builder = parser->getBuilder();
108 auto *affineIntTy = builder.getAffineIntType();
109
110 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -0700111 unsigned numDims;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700112 if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
113 parseDimAndSymbolList(parser, result->operands, numDims) ||
114 parser->parseOptionalAttributeDict(result->attributes))
115 return true;
Chris Lattner3164ae62018-07-28 09:36:25 -0700116 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700117
Chris Lattner3164ae62018-07-28 09:36:25 -0700118 if (map->getNumDims() != numDims ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700119 numDims + map->getNumSymbols() != result->operands.size()) {
120 return parser->emitError(parser->getNameLoc(),
121 "dimension or symbol index mismatch");
Chris Lattner3164ae62018-07-28 09:36:25 -0700122 }
123
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700124 result->types.append(map->getNumResults(), affineIntTy);
125 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700126}
127
128void AffineApplyOp::print(OpAsmPrinter *p) const {
129 auto *map = getAffineMap();
130 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700131 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700132 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700133}
134
135const char *AffineApplyOp::verify() const {
136 // Check that affine map attribute was specified.
137 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
138 if (!affineMapAttr)
Chris Lattner1aa46322018-08-21 17:55:22 -0700139 return "requires an affine map";
Chris Lattner3164ae62018-07-28 09:36:25 -0700140
141 // Check input and output dimensions match.
142 auto *map = affineMapAttr->getValue();
143
144 // Verify that operand count matches affine map dimension and symbol count.
145 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
146 return "operand count and affine map dimension and symbol count must match";
147
148 // Verify that result count matches affine map result count.
149 if (getNumResults() != map->getNumResults())
150 return "result count and affine map result count must match";
151
152 return nullptr;
153}
154
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700155// The result of the affine apply operation can be used as a dimension id if it
156// is a CFG value or if it is an MLValue, and all the operands are valid
157// dimension ids.
158bool AffineApplyOp::isValidDim() const {
159 for (auto *op : getOperands()) {
160 if (auto *v = dyn_cast<MLValue>(op))
161 if (!v->isValidDim())
162 return false;
163 }
164 return true;
165}
166
167// The result of the affine apply operation can be used as a symbol if it is
168// a CFG value or if it is an MLValue, and all the operands are symbols.
169bool AffineApplyOp::isValidSymbol() const {
170 for (auto *op : getOperands()) {
171 if (auto *v = dyn_cast<MLValue>(op))
172 if (!v->isValidSymbol())
173 return false;
174 }
175 return true;
176}
177
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700178//===----------------------------------------------------------------------===//
179// AllocOp
180//===----------------------------------------------------------------------===//
181
MLIR Team3802ebd2018-08-31 14:49:38 -0700182void AllocOp::build(Builder *builder, OperationState *result,
183 MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
184 result->addOperands(operands);
185 result->types.push_back(memrefType);
186}
187
MLIR Team554a8ad2018-07-30 13:08:05 -0700188void AllocOp::print(OpAsmPrinter *p) const {
189 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
190 *p << "alloc";
191 // Print dynamic dimension operands.
192 printDimAndSymbolList(operand_begin(), operand_end(),
193 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700194 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700195 *p << " : " << *type;
196}
197
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700198bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Team554a8ad2018-07-30 13:08:05 -0700199 MemRefType *type;
MLIR Team554a8ad2018-07-30 13:08:05 -0700200
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700201 // Parse the dimension operands and optional symbol operands, followed by a
202 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700203 unsigned numDimOperands;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700204 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
205 parser->parseOptionalAttributeDict(result->attributes) ||
206 parser->parseColonType(type))
207 return true;
MLIR Team554a8ad2018-07-30 13:08:05 -0700208
209 // Check numDynamicDims against number of question marks in memref type.
210 if (numDimOperands != type->getNumDynamicDims()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700211 return parser->emitError(parser->getNameLoc(),
212 "dimension operand count does not equal memref "
213 "dynamic dimension count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700214 }
215
216 // Check that the number of symbol operands matches the number of symbols in
217 // the first affinemap of the memref's affine map composition.
218 // Note that a memref must specify at least one affine map in the composition.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700219 if (result->operands.size() - numDimOperands !=
MLIR Team554a8ad2018-07-30 13:08:05 -0700220 type->getAffineMaps()[0]->getNumSymbols()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700221 return parser->emitError(
222 parser->getNameLoc(),
223 "affine map symbol operand count does not equal memref affine map "
224 "symbol count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700225 }
226
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700227 result->types.push_back(type);
228 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700229}
230
231const char *AllocOp::verify() const {
232 // TODO(andydavis): Verify alloc.
233 return nullptr;
234}
235
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700236//===----------------------------------------------------------------------===//
Chris Lattner1aa46322018-08-21 17:55:22 -0700237// CallOp
238//===----------------------------------------------------------------------===//
239
Chris Lattner1eb77482018-08-22 19:25:49 -0700240void CallOp::build(Builder *builder, OperationState *result, Function *callee,
241 ArrayRef<SSAValue *> operands) {
242 result->addOperands(operands);
243 result->addAttribute("callee", builder->getFunctionAttr(callee));
244 result->addTypes(callee->getType()->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700245}
246
247bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
248 StringRef calleeName;
249 llvm::SMLoc calleeLoc;
250 FunctionType *calleeType = nullptr;
251 SmallVector<OpAsmParser::OperandType, 4> operands;
252 Function *callee = nullptr;
253 if (parser->parseFunctionName(calleeName, calleeLoc) ||
254 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
255 OpAsmParser::Delimiter::Paren) ||
256 parser->parseOptionalAttributeDict(result->attributes) ||
257 parser->parseColonType(calleeType) ||
258 parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
259 parser->addTypesToList(calleeType->getResults(), result->types) ||
260 parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
261 result->operands))
262 return true;
263
Chris Lattner1eb77482018-08-22 19:25:49 -0700264 result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
Chris Lattner1aa46322018-08-21 17:55:22 -0700265 return false;
266}
267
268void CallOp::print(OpAsmPrinter *p) const {
269 *p << "call ";
270 p->printFunctionReference(getCallee());
271 *p << '(';
272 p->printOperands(getOperands());
273 *p << ')';
274 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
275 *p << " : " << *getCallee()->getType();
276}
277
278const char *CallOp::verify() const {
279 // Check that the callee attribute was specified.
280 auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
281 if (!fnAttr)
282 return "requires a 'callee' function attribute";
283
284 // Verify that the operand and result types match the callee.
285 auto *fnType = fnAttr->getValue()->getType();
286 if (fnType->getNumInputs() != getNumOperands())
287 return "incorrect number of operands for callee";
288
289 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
290 if (getOperand(i)->getType() != fnType->getInput(i))
291 return "operand type mismatch";
292 }
293
294 if (fnType->getNumResults() != getNumResults())
295 return "incorrect number of results for callee";
296
297 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
298 if (getResult(i)->getType() != fnType->getResult(i))
299 return "result type mismatch";
300 }
301
302 return nullptr;
303}
304
305//===----------------------------------------------------------------------===//
306// CallIndirectOp
307//===----------------------------------------------------------------------===//
308
Chris Lattner1eb77482018-08-22 19:25:49 -0700309void CallIndirectOp::build(Builder *builder, OperationState *result,
310 SSAValue *callee, ArrayRef<SSAValue *> operands) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700311 auto *fnType = cast<FunctionType>(callee->getType());
Chris Lattner1eb77482018-08-22 19:25:49 -0700312 result->operands.push_back(callee);
313 result->addOperands(operands);
314 result->addTypes(fnType->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700315}
316
317bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
318 FunctionType *calleeType = nullptr;
319 OpAsmParser::OperandType callee;
320 llvm::SMLoc operandsLoc;
321 SmallVector<OpAsmParser::OperandType, 4> operands;
322 return parser->parseOperand(callee) ||
323 parser->getCurrentLocation(&operandsLoc) ||
324 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
325 OpAsmParser::Delimiter::Paren) ||
326 parser->parseOptionalAttributeDict(result->attributes) ||
327 parser->parseColonType(calleeType) ||
328 parser->resolveOperand(callee, calleeType, result->operands) ||
329 parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
330 result->operands) ||
331 parser->addTypesToList(calleeType->getResults(), result->types);
332}
333
334void CallIndirectOp::print(OpAsmPrinter *p) const {
335 *p << "call_indirect ";
336 p->printOperand(getCallee());
337 *p << '(';
338 auto operandRange = getOperands();
339 p->printOperands(++operandRange.begin(), operandRange.end());
340 *p << ')';
341 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
342 *p << " : " << *getCallee()->getType();
343}
344
345const char *CallIndirectOp::verify() const {
346 // The callee must be a function.
347 auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
348 if (!fnType)
349 return "callee must have function type";
350
351 // Verify that the operand and result types match the callee.
352 if (fnType->getNumInputs() != getNumOperands() - 1)
353 return "incorrect number of operands for callee";
354
355 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
356 if (getOperand(i + 1)->getType() != fnType->getInput(i))
357 return "operand type mismatch";
358 }
359
360 if (fnType->getNumResults() != getNumResults())
361 return "incorrect number of results for callee";
362
363 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
364 if (getResult(i)->getType() != fnType->getResult(i))
365 return "result type mismatch";
366 }
367
368 return nullptr;
369}
370
371//===----------------------------------------------------------------------===//
372// Constant*Op
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700373//===----------------------------------------------------------------------===//
374
Chris Lattnerd4964212018-08-01 10:43:18 -0700375void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700376 *p << "constant " << *getValue();
377 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
Chris Lattner4613d9e2018-08-19 21:17:22 -0700378
379 if (!isa<FunctionAttr>(getValue()))
380 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700381}
382
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700383bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattnerd4964212018-08-01 10:43:18 -0700384 Attribute *valueAttr;
385 Type *type;
Chris Lattnerd4964212018-08-01 10:43:18 -0700386
Chris Lattner4613d9e2018-08-19 21:17:22 -0700387 if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
388 parser->parseOptionalAttributeDict(result->attributes))
389 return true;
390
391 // 'constant' taking a function reference doesn't get a redundant type
392 // specifier. The attribute itself carries it.
393 if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
394 return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
395
396 return parser->parseColonType(type) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700397 parser->addTypeToList(type, result->types);
Chris Lattnerd4964212018-08-01 10:43:18 -0700398}
399
Chris Lattner9361fb32018-07-24 08:34:58 -0700400/// The constant op requires an attribute, and furthermore requires that it
401/// matches the return type.
402const char *ConstantOp::verify() const {
403 auto *value = getValue();
404 if (!value)
405 return "requires a 'value' attribute";
406
407 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700408 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700409 if (!isa<IntegerAttr>(value))
410 return "requires 'value' to be an integer for an integer result type";
411 return nullptr;
412 }
413
Chris Lattner7ba98c62018-08-16 16:56:40 -0700414 if (isa<FloatType>(type)) {
415 if (!isa<FloatAttr>(value))
416 return "requires 'value' to be a floating point constant";
417 return nullptr;
418 }
419
420 if (type->isTFString()) {
421 if (!isa<StringAttr>(value))
422 return "requires 'value' to be a string constant";
423 return nullptr;
424 }
425
Chris Lattner9361fb32018-07-24 08:34:58 -0700426 if (isa<FunctionType>(type)) {
Chris Lattner4613d9e2018-08-19 21:17:22 -0700427 if (!isa<FunctionAttr>(value))
428 return "requires 'value' to be a function reference";
429 return nullptr;
Chris Lattner9361fb32018-07-24 08:34:58 -0700430 }
431
432 return "requires a result type that aligns with the 'value' attribute";
433}
434
Chris Lattner1eb77482018-08-22 19:25:49 -0700435void ConstantFloatOp::build(Builder *builder, OperationState *result,
436 double value, FloatType *type) {
437 result->addAttribute("value", builder->getFloatAttr(value));
438 result->types.push_back(type);
Chris Lattner7ba98c62018-08-16 16:56:40 -0700439}
440
441bool ConstantFloatOp::isClassFor(const Operation *op) {
442 return ConstantOp::isClassFor(op) &&
443 isa<FloatType>(op->getResult(0)->getType());
444}
445
Chris Lattner992a1272018-08-07 12:02:37 -0700446/// ConstantIntOp only matches values whose result type is an IntegerType.
Chris Lattner9361fb32018-07-24 08:34:58 -0700447bool ConstantIntOp::isClassFor(const Operation *op) {
448 return ConstantOp::isClassFor(op) &&
Chris Lattner992a1272018-08-07 12:02:37 -0700449 isa<IntegerType>(op->getResult(0)->getType());
450}
451
Chris Lattner1eb77482018-08-22 19:25:49 -0700452void ConstantIntOp::build(Builder *builder, OperationState *result,
453 int64_t value, unsigned width) {
454 result->addAttribute("value", builder->getIntegerAttr(value));
455 result->types.push_back(builder->getIntegerType(width));
Chris Lattner992a1272018-08-07 12:02:37 -0700456}
457
458/// ConstantAffineIntOp only matches values whose result type is AffineInt.
459bool ConstantAffineIntOp::isClassFor(const Operation *op) {
460 return ConstantOp::isClassFor(op) &&
461 op->getResult(0)->getType()->isAffineInt();
462}
463
Chris Lattner1eb77482018-08-22 19:25:49 -0700464void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
465 int64_t value) {
466 result->addAttribute("value", builder->getIntegerAttr(value));
467 result->types.push_back(builder->getAffineIntType());
Chris Lattner9361fb32018-07-24 08:34:58 -0700468}
469
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700470//===----------------------------------------------------------------------===//
Uday Bondhugula67701712018-08-21 16:01:23 -0700471// AffineApplyOp
472//===----------------------------------------------------------------------===//
473
Chris Lattner1eb77482018-08-22 19:25:49 -0700474void AffineApplyOp::build(Builder *builder, OperationState *result,
475 AffineMap *map, ArrayRef<SSAValue *> operands) {
476 result->addOperands(operands);
477 result->types.append(map->getNumResults(), builder->getAffineIntType());
478 result->addAttribute("map", builder->getAffineMapAttr(map));
Uday Bondhugula67701712018-08-21 16:01:23 -0700479}
480
481//===----------------------------------------------------------------------===//
MLIR Team1989cc12018-08-15 15:39:26 -0700482// DeallocOp
483//===----------------------------------------------------------------------===//
484
MLIR Team3802ebd2018-08-31 14:49:38 -0700485void DeallocOp::build(Builder *builder, OperationState *result,
486 SSAValue *memref) {
487 result->addOperands(memref);
488}
489
MLIR Team1989cc12018-08-15 15:39:26 -0700490void DeallocOp::print(OpAsmPrinter *p) const {
491 *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
492}
493
494bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
495 OpAsmParser::OperandType memrefInfo;
496 MemRefType *type;
497
498 return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
499 parser->resolveOperand(memrefInfo, type, result->operands);
500}
501
502const char *DeallocOp::verify() const {
503 if (!isa<MemRefType>(getMemRef()->getType()))
504 return "operand must be a memref";
505 return nullptr;
506}
507
508//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700509// DimOp
510//===----------------------------------------------------------------------===//
511
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700512void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700513 *p << "dim " << *getOperand() << ", " << getIndex();
514 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
515 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700516}
517
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700518bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700519 OpAsmParser::OperandType operandInfo;
520 IntegerAttr *indexAttr;
521 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700522
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700523 return parser->parseOperand(operandInfo) || parser->parseComma() ||
524 parser->parseAttribute(indexAttr, "index", result->attributes) ||
525 parser->parseOptionalAttributeDict(result->attributes) ||
526 parser->parseColonType(type) ||
527 parser->resolveOperand(operandInfo, type, result->operands) ||
528 parser->addTypeToList(parser->getBuilder().getAffineIntType(),
529 result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700530}
531
Chris Lattner21e67f62018-07-06 10:46:19 -0700532const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700533 // Check that we have an integer index operand.
534 auto indexAttr = getAttrOfType<IntegerAttr>("index");
535 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700536 return "requires an integer attribute named 'index'";
537 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700538
Chris Lattner9361fb32018-07-24 08:34:58 -0700539 auto *type = getOperand()->getType();
540 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
541 if (index >= tensorType->getRank())
542 return "index is out of range";
543 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
544 if (index >= memrefType->getRank())
545 return "index is out of range";
546
547 } else if (isa<UnrankedTensorType>(type)) {
548 // ok, assumed to be in-range.
549 } else {
550 return "requires an operand with tensor or memref type";
551 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700552
553 return nullptr;
554}
555
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700556//===----------------------------------------------------------------------===//
Chris Lattner8c7feba2018-08-23 09:58:23 -0700557// ExtractElementOp
558//===----------------------------------------------------------------------===//
559
560void ExtractElementOp::build(Builder *builder, OperationState *result,
561 SSAValue *aggregate,
562 ArrayRef<SSAValue *> indices) {
563 auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
564 result->addOperands(aggregate);
565 result->addOperands(indices);
566 result->types.push_back(aggregateType->getElementType());
567}
568
569void ExtractElementOp::print(OpAsmPrinter *p) const {
570 *p << "extract_element " << *getAggregate() << '[';
571 p->printOperands(getIndices());
572 *p << ']';
573 p->printOptionalAttrDict(getAttrs());
574 *p << " : " << *getAggregate()->getType();
575}
576
577bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
578 OpAsmParser::OperandType aggregateInfo;
579 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
580 VectorOrTensorType *type;
581
582 auto affineIntTy = parser->getBuilder().getAffineIntType();
583 return parser->parseOperand(aggregateInfo) ||
584 parser->parseOperandList(indexInfo, -1,
585 OpAsmParser::Delimiter::Square) ||
586 parser->parseOptionalAttributeDict(result->attributes) ||
587 parser->parseColonType(type) ||
588 parser->resolveOperand(aggregateInfo, type, result->operands) ||
589 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
590 parser->addTypeToList(type->getElementType(), result->types);
591}
592
593const char *ExtractElementOp::verify() const {
594 if (getNumOperands() == 0)
595 return "expected an aggregate to index into";
596
597 auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
598 if (!aggregateType)
599 return "first operand must be a vector or tensor";
600
601 if (getResult()->getType() != aggregateType->getElementType())
602 return "result type must match element type of aggregate";
603
604 for (auto *idx : getIndices())
605 if (!idx->getType()->isAffineInt())
606 return "index to extract_element must have 'affineint' type";
607
608 // Verify the # indices match if we have a ranked type.
609 auto aggregateRank = aggregateType->getRankIfPresent();
610 if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
611 return "incorrect number of indices for extract_element";
612
613 return nullptr;
614}
615
616//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700617// LoadOp
618//===----------------------------------------------------------------------===//
619
Chris Lattner8c7feba2018-08-23 09:58:23 -0700620void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
621 ArrayRef<SSAValue *> indices) {
622 auto *memrefType = cast<MemRefType>(memref->getType());
623 result->addOperands(memref);
624 result->addOperands(indices);
625 result->types.push_back(memrefType->getElementType());
626}
627
Chris Lattner85ee1512018-07-25 11:15:20 -0700628void LoadOp::print(OpAsmPrinter *p) const {
629 *p << "load " << *getMemRef() << '[';
630 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700631 *p << ']';
632 p->printOptionalAttrDict(getAttrs());
633 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700634}
635
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700636bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700637 OpAsmParser::OperandType memrefInfo;
638 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
639 MemRefType *type;
Chris Lattner85ee1512018-07-25 11:15:20 -0700640
641 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700642 return parser->parseOperand(memrefInfo) ||
643 parser->parseOperandList(indexInfo, -1,
644 OpAsmParser::Delimiter::Square) ||
645 parser->parseOptionalAttributeDict(result->attributes) ||
646 parser->parseColonType(type) ||
647 parser->resolveOperand(memrefInfo, type, result->operands) ||
648 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
649 parser->addTypeToList(type->getElementType(), result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700650}
651
652const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700653 if (getNumOperands() == 0)
654 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700655
Chris Lattner3164ae62018-07-28 09:36:25 -0700656 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
657 if (!memRefType)
658 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700659
Chris Lattner8c7feba2018-08-23 09:58:23 -0700660 if (getResult()->getType() != memRefType->getElementType())
661 return "result type must match element type of memref";
662
663 if (memRefType->getRank() != getNumOperands() - 1)
664 return "incorrect number of indices for load";
665
Chris Lattner3164ae62018-07-28 09:36:25 -0700666 for (auto *idx : getIndices())
667 if (!idx->getType()->isAffineInt())
668 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700669
Chris Lattner3164ae62018-07-28 09:36:25 -0700670 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700671
Chris Lattner3164ae62018-07-28 09:36:25 -0700672 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
673 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700674 return nullptr;
675}
676
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700677//===----------------------------------------------------------------------===//
678// ReturnOp
679//===----------------------------------------------------------------------===//
680
681bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
682 SmallVector<OpAsmParser::OperandType, 2> opInfo;
683 SmallVector<Type *, 2> types;
Chris Lattner1aa46322018-08-21 17:55:22 -0700684 llvm::SMLoc loc;
685 return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700686 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
Chris Lattner1aa46322018-08-21 17:55:22 -0700687 parser->resolveOperands(opInfo, types, loc, result->operands);
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700688}
689
690void ReturnOp::print(OpAsmPrinter *p) const {
691 *p << "return";
692 if (getNumOperands() > 0) {
693 *p << " ";
694 p->printOperands(operand_begin(), operand_end());
695 *p << " : ";
696 interleave(operand_begin(), operand_end(),
MLIR Team6a220a62018-09-07 12:34:19 -0700697 [&](const SSAValue *e) { p->printType(e->getType()); },
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700698 [&]() { *p << ", "; });
699 }
700}
701
702const char *ReturnOp::verify() const {
703 // ReturnOp must be part of an ML function.
704 if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
Tatiana Shpeisman3abd6bd2018-08-16 20:19:44 -0700705 StmtBlock *block = stmt->getBlock();
706 if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700707 return "must be the last statement in the ML function";
708
709 // Return success. Checking that operand types match those in the function
710 // signature is performed in the ML function verifier.
711 return nullptr;
712 }
Tatiana Shpeismanb697ac12018-08-09 23:21:19 -0700713 return "cannot occur in a CFG function";
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700714}
715
716//===----------------------------------------------------------------------===//
Chris Lattner0473e4b2018-09-06 17:31:21 -0700717// ShapeCastOp
718//===----------------------------------------------------------------------===//
719
720void ShapeCastOp::build(Builder *builder, OperationState *result,
721 SSAValue *input, Type *resultType) {
722 result->addOperands(input);
723 result->addTypes(resultType);
724}
725
726const char *ShapeCastOp::verify() const {
727 auto *opType = dyn_cast<TensorType>(getOperand()->getType());
728 auto *resType = dyn_cast<TensorType>(getResult()->getType());
729 if (!opType || !resType)
730 return "requires input and result types to be tensors";
731
732 if (opType == resType)
733 return "requires the input and result type to be different";
734
735 if (opType->getElementType() != resType->getElementType())
736 return "requires input and result element types to be the same";
737
738 // If the source or destination are unranked, then the cast is valid.
739 auto *opRType = dyn_cast<RankedTensorType>(opType);
740 auto *resRType = dyn_cast<RankedTensorType>(resType);
741 if (!opRType || !resRType)
742 return nullptr;
743
744 // If they are both ranked, they have to have the same rank, and any specified
745 // dimensions must match.
746 if (opRType->getRank() != resRType->getRank())
747 return "requires input and result ranks to match";
748
749 for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
750 int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
751 if (opDim != -1 && resultDim != -1 && opDim != resultDim)
752 return "requires static dimensions to match";
753 }
754
755 return nullptr;
756}
757
758//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700759// StoreOp
760//===----------------------------------------------------------------------===//
761
MLIR Team3802ebd2018-08-31 14:49:38 -0700762void StoreOp::build(Builder *builder, OperationState *result,
763 SSAValue *valueToStore, SSAValue *memref,
764 ArrayRef<SSAValue *> indices) {
765 result->addOperands(valueToStore);
766 result->addOperands(memref);
767 result->addOperands(indices);
768}
769
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700770void StoreOp::print(OpAsmPrinter *p) const {
771 *p << "store " << *getValueToStore();
772 *p << ", " << *getMemRef() << '[';
773 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700774 *p << ']';
775 p->printOptionalAttrDict(getAttrs());
776 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700777}
778
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700779bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700780 OpAsmParser::OperandType storeValueInfo;
781 OpAsmParser::OperandType memrefInfo;
782 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700783 MemRefType *memrefType;
784
785 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700786 return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
787 parser->parseOperand(memrefInfo) ||
788 parser->parseOperandList(indexInfo, -1,
789 OpAsmParser::Delimiter::Square) ||
790 parser->parseOptionalAttributeDict(result->attributes) ||
791 parser->parseColonType(memrefType) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700792 parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
793 result->operands) ||
794 parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700795 parser->resolveOperands(indexInfo, affineIntTy, result->operands);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700796}
797
798const char *StoreOp::verify() const {
799 if (getNumOperands() < 2)
800 return "expected a value to store and a memref";
801
802 // Second operand is a memref type.
803 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
804 if (!memRefType)
805 return "second operand must be a memref";
806
807 // First operand must have same type as memref element type.
808 if (getValueToStore()->getType() != memRefType->getElementType())
809 return "first operand must have same type memref element type ";
810
811 if (getNumOperands() != 2 + memRefType->getRank())
812 return "store index operand count not equal to memref rank";
813
814 for (auto *idx : getIndices())
815 if (!idx->getType()->isAffineInt())
816 return "index to load must have 'affineint' type";
817
818 // TODO: Verify we have the right number of indices.
819
820 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
821 // result of an affine_apply.
822 return nullptr;
823}
824
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700825//===----------------------------------------------------------------------===//
826// Register operations.
827//===----------------------------------------------------------------------===//
828
Chris Lattnerff0d5902018-07-05 09:12:11 -0700829/// Install the standard operations in the specified operation set.
830void mlir::registerStandardOperations(OperationSet &opSet) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700831 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
Chris Lattner8c7feba2018-08-23 09:58:23 -0700832 ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
Chris Lattner0473e4b2018-09-06 17:31:21 -0700833 ReturnOp, ShapeCastOp, StoreOp>(
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700834 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700835}