blob: fb7dfe76d16f1098ca8fcfd5fe3bd9eb1fbf513c [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070025#include "mlir/Support/STLExtras.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070026#include "llvm/Support/raw_ostream.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070027
Chris Lattnerff0d5902018-07-05 09:12:11 -070028using namespace mlir;
29
MLIR Team554a8ad2018-07-30 13:08:05 -070030static void printDimAndSymbolList(Operation::const_operand_iterator begin,
31 Operation::const_operand_iterator end,
32 unsigned numDims, OpAsmPrinter *p) {
33 *p << '(';
34 p->printOperands(begin, begin + numDims);
35 *p << ')';
36
37 if (begin + numDims != end) {
38 *p << '[';
39 p->printOperands(begin + numDims, end);
40 *p << ']';
41 }
42}
43
44// Parses dimension and symbol list, and sets 'numDims' to the number of
45// dimension operands parsed.
46// Returns 'false' on success and 'true' on error.
47static bool
48parseDimAndSymbolList(OpAsmParser *parser,
MLIR Team554a8ad2018-07-30 13:08:05 -070049 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -070050 SmallVector<OpAsmParser::OperandType, 8> opInfos;
Chris Lattner85cf26d2018-08-02 16:54:36 -070051 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070052 return true;
53 // Store number of dimensions for validation by caller.
54 numDims = opInfos.size();
55
56 // Parse the optional symbol operands.
57 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070058 if (parser->parseOperandList(opInfos, -1,
59 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070060 parser->resolveOperands(opInfos, affineIntTy, operands))
61 return true;
62 return false;
63}
64
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070065//===----------------------------------------------------------------------===//
66// AddFOp
67//===----------------------------------------------------------------------===//
68
Chris Lattnereed6c4d2018-08-07 09:12:35 -070069bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -070070 SmallVector<OpAsmParser::OperandType, 2> ops;
71 Type *type;
Chris Lattner8bdbebf2018-08-08 11:02:58 -070072 return parser->parseOperandList(ops, 2) ||
73 parser->parseOptionalAttributeDict(result->attributes) ||
74 parser->parseColonType(type) ||
75 parser->resolveOperands(ops, type, result->operands) ||
76 parser->addTypeToList(type, result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -070077}
78
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070079void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070080 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
81 p->printOptionalAttrDict(getAttrs());
82 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070083}
84
Chris Lattnereed6c4d2018-08-07 09:12:35 -070085// TODO: Have verify functions return std::string to enable more descriptive
86// error messages.
Chris Lattner21e67f62018-07-06 10:46:19 -070087// Return an error message on failure.
88const char *AddFOp::verify() const {
89 // TODO: Check that the types of the LHS and RHS match.
90 // TODO: This should be a refinement of TwoOperands.
91 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
92 return nullptr;
93}
94
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070095//===----------------------------------------------------------------------===//
96// AffineApplyOp
97//===----------------------------------------------------------------------===//
98
Chris Lattnereed6c4d2018-08-07 09:12:35 -070099bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner3164ae62018-07-28 09:36:25 -0700100 auto &builder = parser->getBuilder();
101 auto *affineIntTy = builder.getAffineIntType();
102
103 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -0700104 unsigned numDims;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700105 if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
106 parseDimAndSymbolList(parser, result->operands, numDims) ||
107 parser->parseOptionalAttributeDict(result->attributes))
108 return true;
Chris Lattner3164ae62018-07-28 09:36:25 -0700109 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700110
Chris Lattner3164ae62018-07-28 09:36:25 -0700111 if (map->getNumDims() != numDims ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700112 numDims + map->getNumSymbols() != result->operands.size()) {
113 return parser->emitError(parser->getNameLoc(),
114 "dimension or symbol index mismatch");
Chris Lattner3164ae62018-07-28 09:36:25 -0700115 }
116
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700117 result->types.append(map->getNumResults(), affineIntTy);
118 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700119}
120
121void AffineApplyOp::print(OpAsmPrinter *p) const {
122 auto *map = getAffineMap();
123 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700124 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700125 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700126}
127
128const char *AffineApplyOp::verify() const {
129 // Check that affine map attribute was specified.
130 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
131 if (!affineMapAttr)
Chris Lattner1aa46322018-08-21 17:55:22 -0700132 return "requires an affine map";
Chris Lattner3164ae62018-07-28 09:36:25 -0700133
134 // Check input and output dimensions match.
135 auto *map = affineMapAttr->getValue();
136
137 // Verify that operand count matches affine map dimension and symbol count.
138 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
139 return "operand count and affine map dimension and symbol count must match";
140
141 // Verify that result count matches affine map result count.
142 if (getNumResults() != map->getNumResults())
143 return "result count and affine map result count must match";
144
145 return nullptr;
146}
147
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700148//===----------------------------------------------------------------------===//
149// AllocOp
150//===----------------------------------------------------------------------===//
151
MLIR Team554a8ad2018-07-30 13:08:05 -0700152void AllocOp::print(OpAsmPrinter *p) const {
153 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
154 *p << "alloc";
155 // Print dynamic dimension operands.
156 printDimAndSymbolList(operand_begin(), operand_end(),
157 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700158 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700159 *p << " : " << *type;
160}
161
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700162bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Team554a8ad2018-07-30 13:08:05 -0700163 MemRefType *type;
MLIR Team554a8ad2018-07-30 13:08:05 -0700164
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700165 // Parse the dimension operands and optional symbol operands, followed by a
166 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700167 unsigned numDimOperands;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700168 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
169 parser->parseOptionalAttributeDict(result->attributes) ||
170 parser->parseColonType(type))
171 return true;
MLIR Team554a8ad2018-07-30 13:08:05 -0700172
173 // Check numDynamicDims against number of question marks in memref type.
174 if (numDimOperands != type->getNumDynamicDims()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700175 return parser->emitError(parser->getNameLoc(),
176 "dimension operand count does not equal memref "
177 "dynamic dimension count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700178 }
179
180 // Check that the number of symbol operands matches the number of symbols in
181 // the first affinemap of the memref's affine map composition.
182 // Note that a memref must specify at least one affine map in the composition.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700183 if (result->operands.size() - numDimOperands !=
MLIR Team554a8ad2018-07-30 13:08:05 -0700184 type->getAffineMaps()[0]->getNumSymbols()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700185 return parser->emitError(
186 parser->getNameLoc(),
187 "affine map symbol operand count does not equal memref affine map "
188 "symbol count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700189 }
190
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700191 result->types.push_back(type);
192 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700193}
194
195const char *AllocOp::verify() const {
196 // TODO(andydavis): Verify alloc.
197 return nullptr;
198}
199
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700200//===----------------------------------------------------------------------===//
Chris Lattner1aa46322018-08-21 17:55:22 -0700201// CallOp
202//===----------------------------------------------------------------------===//
203
204OperationState CallOp::build(Builder *builder, Function *callee,
205 ArrayRef<SSAValue *> operands) {
206 OperationState result(builder->getIdentifier("call"));
207 result.operands.append(operands.begin(), operands.end());
208 result.attributes.push_back(
209 {builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
210 result.types.append(callee->getType()->getResults().begin(),
211 callee->getType()->getResults().end());
212 return result;
213}
214
215bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
216 StringRef calleeName;
217 llvm::SMLoc calleeLoc;
218 FunctionType *calleeType = nullptr;
219 SmallVector<OpAsmParser::OperandType, 4> operands;
220 Function *callee = nullptr;
221 if (parser->parseFunctionName(calleeName, calleeLoc) ||
222 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
223 OpAsmParser::Delimiter::Paren) ||
224 parser->parseOptionalAttributeDict(result->attributes) ||
225 parser->parseColonType(calleeType) ||
226 parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
227 parser->addTypesToList(calleeType->getResults(), result->types) ||
228 parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
229 result->operands))
230 return true;
231
232 auto &builder = parser->getBuilder();
233 result->attributes.push_back(
234 {builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
235
236 return false;
237}
238
239void CallOp::print(OpAsmPrinter *p) const {
240 *p << "call ";
241 p->printFunctionReference(getCallee());
242 *p << '(';
243 p->printOperands(getOperands());
244 *p << ')';
245 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
246 *p << " : " << *getCallee()->getType();
247}
248
249const char *CallOp::verify() const {
250 // Check that the callee attribute was specified.
251 auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
252 if (!fnAttr)
253 return "requires a 'callee' function attribute";
254
255 // Verify that the operand and result types match the callee.
256 auto *fnType = fnAttr->getValue()->getType();
257 if (fnType->getNumInputs() != getNumOperands())
258 return "incorrect number of operands for callee";
259
260 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
261 if (getOperand(i)->getType() != fnType->getInput(i))
262 return "operand type mismatch";
263 }
264
265 if (fnType->getNumResults() != getNumResults())
266 return "incorrect number of results for callee";
267
268 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
269 if (getResult(i)->getType() != fnType->getResult(i))
270 return "result type mismatch";
271 }
272
273 return nullptr;
274}
275
276//===----------------------------------------------------------------------===//
277// CallIndirectOp
278//===----------------------------------------------------------------------===//
279
280OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
281 ArrayRef<SSAValue *> operands) {
282 auto *fnType = cast<FunctionType>(callee->getType());
283
284 OperationState result(builder->getIdentifier("call_indirect"));
285 result.operands.push_back(callee);
286 result.operands.append(operands.begin(), operands.end());
287 result.types.append(fnType->getResults().begin(), fnType->getResults().end());
288 return result;
289}
290
291bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
292 FunctionType *calleeType = nullptr;
293 OpAsmParser::OperandType callee;
294 llvm::SMLoc operandsLoc;
295 SmallVector<OpAsmParser::OperandType, 4> operands;
296 return parser->parseOperand(callee) ||
297 parser->getCurrentLocation(&operandsLoc) ||
298 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
299 OpAsmParser::Delimiter::Paren) ||
300 parser->parseOptionalAttributeDict(result->attributes) ||
301 parser->parseColonType(calleeType) ||
302 parser->resolveOperand(callee, calleeType, result->operands) ||
303 parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
304 result->operands) ||
305 parser->addTypesToList(calleeType->getResults(), result->types);
306}
307
308void CallIndirectOp::print(OpAsmPrinter *p) const {
309 *p << "call_indirect ";
310 p->printOperand(getCallee());
311 *p << '(';
312 auto operandRange = getOperands();
313 p->printOperands(++operandRange.begin(), operandRange.end());
314 *p << ')';
315 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
316 *p << " : " << *getCallee()->getType();
317}
318
319const char *CallIndirectOp::verify() const {
320 // The callee must be a function.
321 auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
322 if (!fnType)
323 return "callee must have function type";
324
325 // Verify that the operand and result types match the callee.
326 if (fnType->getNumInputs() != getNumOperands() - 1)
327 return "incorrect number of operands for callee";
328
329 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
330 if (getOperand(i + 1)->getType() != fnType->getInput(i))
331 return "operand type mismatch";
332 }
333
334 if (fnType->getNumResults() != getNumResults())
335 return "incorrect number of results for callee";
336
337 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
338 if (getResult(i)->getType() != fnType->getResult(i))
339 return "result type mismatch";
340 }
341
342 return nullptr;
343}
344
345//===----------------------------------------------------------------------===//
346// Constant*Op
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700347//===----------------------------------------------------------------------===//
348
Chris Lattnerd4964212018-08-01 10:43:18 -0700349void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700350 *p << "constant " << *getValue();
351 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
Chris Lattner4613d9e2018-08-19 21:17:22 -0700352
353 if (!isa<FunctionAttr>(getValue()))
354 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700355}
356
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700357bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattnerd4964212018-08-01 10:43:18 -0700358 Attribute *valueAttr;
359 Type *type;
Chris Lattnerd4964212018-08-01 10:43:18 -0700360
Chris Lattner4613d9e2018-08-19 21:17:22 -0700361 if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
362 parser->parseOptionalAttributeDict(result->attributes))
363 return true;
364
365 // 'constant' taking a function reference doesn't get a redundant type
366 // specifier. The attribute itself carries it.
367 if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
368 return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
369
370 return parser->parseColonType(type) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700371 parser->addTypeToList(type, result->types);
Chris Lattnerd4964212018-08-01 10:43:18 -0700372}
373
Chris Lattner9361fb32018-07-24 08:34:58 -0700374/// The constant op requires an attribute, and furthermore requires that it
375/// matches the return type.
376const char *ConstantOp::verify() const {
377 auto *value = getValue();
378 if (!value)
379 return "requires a 'value' attribute";
380
381 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700382 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700383 if (!isa<IntegerAttr>(value))
384 return "requires 'value' to be an integer for an integer result type";
385 return nullptr;
386 }
387
Chris Lattner7ba98c62018-08-16 16:56:40 -0700388 if (isa<FloatType>(type)) {
389 if (!isa<FloatAttr>(value))
390 return "requires 'value' to be a floating point constant";
391 return nullptr;
392 }
393
394 if (type->isTFString()) {
395 if (!isa<StringAttr>(value))
396 return "requires 'value' to be a string constant";
397 return nullptr;
398 }
399
Chris Lattner9361fb32018-07-24 08:34:58 -0700400 if (isa<FunctionType>(type)) {
Chris Lattner4613d9e2018-08-19 21:17:22 -0700401 if (!isa<FunctionAttr>(value))
402 return "requires 'value' to be a function reference";
403 return nullptr;
Chris Lattner9361fb32018-07-24 08:34:58 -0700404 }
405
406 return "requires a result type that aligns with the 'value' attribute";
407}
408
Chris Lattner7ba98c62018-08-16 16:56:40 -0700409OperationState ConstantFloatOp::build(Builder *builder, double value,
410 FloatType *type) {
411 OperationState result(builder->getIdentifier("constant"));
412 result.attributes.push_back(
413 {builder->getIdentifier("value"), builder->getFloatAttr(value)});
414 result.types.push_back(type);
415 return result;
416}
417
418bool ConstantFloatOp::isClassFor(const Operation *op) {
419 return ConstantOp::isClassFor(op) &&
420 isa<FloatType>(op->getResult(0)->getType());
421}
422
Chris Lattner992a1272018-08-07 12:02:37 -0700423/// ConstantIntOp only matches values whose result type is an IntegerType.
Chris Lattner9361fb32018-07-24 08:34:58 -0700424bool ConstantIntOp::isClassFor(const Operation *op) {
425 return ConstantOp::isClassFor(op) &&
Chris Lattner992a1272018-08-07 12:02:37 -0700426 isa<IntegerType>(op->getResult(0)->getType());
427}
428
429OperationState ConstantIntOp::build(Builder *builder, int64_t value,
430 unsigned width) {
431 OperationState result(builder->getIdentifier("constant"));
432 result.attributes.push_back(
433 {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
434 result.types.push_back(builder->getIntegerType(width));
435 return result;
436}
437
438/// ConstantAffineIntOp only matches values whose result type is AffineInt.
439bool ConstantAffineIntOp::isClassFor(const Operation *op) {
440 return ConstantOp::isClassFor(op) &&
441 op->getResult(0)->getType()->isAffineInt();
442}
443
444OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
445 OperationState result(builder->getIdentifier("constant"));
446 result.attributes.push_back(
447 {builder->getIdentifier("value"), builder->getIntegerAttr(value)});
448 result.types.push_back(builder->getAffineIntType());
449 return result;
Chris Lattner9361fb32018-07-24 08:34:58 -0700450}
451
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700452//===----------------------------------------------------------------------===//
Uday Bondhugula67701712018-08-21 16:01:23 -0700453// AffineApplyOp
454//===----------------------------------------------------------------------===//
455
456OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
457 ArrayRef<SSAValue *> operands) {
458 SmallVector<Type *, 4> resultTypes(map->getNumResults(),
459 builder->getAffineIntType());
460
461 OperationState result(
462 builder->getIdentifier("affine_apply"), operands, resultTypes,
463 {{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
464
465 return result;
466}
467
468//===----------------------------------------------------------------------===//
MLIR Team1989cc12018-08-15 15:39:26 -0700469// DeallocOp
470//===----------------------------------------------------------------------===//
471
472void DeallocOp::print(OpAsmPrinter *p) const {
473 *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
474}
475
476bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
477 OpAsmParser::OperandType memrefInfo;
478 MemRefType *type;
479
480 return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
481 parser->resolveOperand(memrefInfo, type, result->operands);
482}
483
484const char *DeallocOp::verify() const {
485 if (!isa<MemRefType>(getMemRef()->getType()))
486 return "operand must be a memref";
487 return nullptr;
488}
489
490//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700491// DimOp
492//===----------------------------------------------------------------------===//
493
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700494void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700495 *p << "dim " << *getOperand() << ", " << getIndex();
496 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
497 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700498}
499
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700500bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700501 OpAsmParser::OperandType operandInfo;
502 IntegerAttr *indexAttr;
503 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700504
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700505 return parser->parseOperand(operandInfo) || parser->parseComma() ||
506 parser->parseAttribute(indexAttr, "index", result->attributes) ||
507 parser->parseOptionalAttributeDict(result->attributes) ||
508 parser->parseColonType(type) ||
509 parser->resolveOperand(operandInfo, type, result->operands) ||
510 parser->addTypeToList(parser->getBuilder().getAffineIntType(),
511 result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700512}
513
Chris Lattner21e67f62018-07-06 10:46:19 -0700514const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700515 // Check that we have an integer index operand.
516 auto indexAttr = getAttrOfType<IntegerAttr>("index");
517 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700518 return "requires an integer attribute named 'index'";
519 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700520
Chris Lattner9361fb32018-07-24 08:34:58 -0700521 auto *type = getOperand()->getType();
522 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
523 if (index >= tensorType->getRank())
524 return "index is out of range";
525 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
526 if (index >= memrefType->getRank())
527 return "index is out of range";
528
529 } else if (isa<UnrankedTensorType>(type)) {
530 // ok, assumed to be in-range.
531 } else {
532 return "requires an operand with tensor or memref type";
533 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700534
535 return nullptr;
536}
537
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700538//===----------------------------------------------------------------------===//
539// LoadOp
540//===----------------------------------------------------------------------===//
541
Chris Lattner85ee1512018-07-25 11:15:20 -0700542void LoadOp::print(OpAsmPrinter *p) const {
543 *p << "load " << *getMemRef() << '[';
544 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700545 *p << ']';
546 p->printOptionalAttrDict(getAttrs());
547 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700548}
549
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700550bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700551 OpAsmParser::OperandType memrefInfo;
552 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
553 MemRefType *type;
Chris Lattner85ee1512018-07-25 11:15:20 -0700554
555 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700556 return parser->parseOperand(memrefInfo) ||
557 parser->parseOperandList(indexInfo, -1,
558 OpAsmParser::Delimiter::Square) ||
559 parser->parseOptionalAttributeDict(result->attributes) ||
560 parser->parseColonType(type) ||
561 parser->resolveOperand(memrefInfo, type, result->operands) ||
562 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
563 parser->addTypeToList(type->getElementType(), result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700564}
565
566const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700567 if (getNumOperands() == 0)
568 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700569
Chris Lattner3164ae62018-07-28 09:36:25 -0700570 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
571 if (!memRefType)
572 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700573
Chris Lattner3164ae62018-07-28 09:36:25 -0700574 for (auto *idx : getIndices())
575 if (!idx->getType()->isAffineInt())
576 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700577
Chris Lattner3164ae62018-07-28 09:36:25 -0700578 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700579
Chris Lattner3164ae62018-07-28 09:36:25 -0700580 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
581 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700582 return nullptr;
583}
584
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700585//===----------------------------------------------------------------------===//
586// ReturnOp
587//===----------------------------------------------------------------------===//
588
589bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
590 SmallVector<OpAsmParser::OperandType, 2> opInfo;
591 SmallVector<Type *, 2> types;
Chris Lattner1aa46322018-08-21 17:55:22 -0700592 llvm::SMLoc loc;
593 return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700594 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
Chris Lattner1aa46322018-08-21 17:55:22 -0700595 parser->resolveOperands(opInfo, types, loc, result->operands);
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700596}
597
598void ReturnOp::print(OpAsmPrinter *p) const {
599 *p << "return";
600 if (getNumOperands() > 0) {
601 *p << " ";
602 p->printOperands(operand_begin(), operand_end());
603 *p << " : ";
604 interleave(operand_begin(), operand_end(),
605 [&](auto *e) { p->printType(e->getType()); },
606 [&]() { *p << ", "; });
607 }
608}
609
610const char *ReturnOp::verify() const {
611 // ReturnOp must be part of an ML function.
612 if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
Tatiana Shpeisman3abd6bd2018-08-16 20:19:44 -0700613 StmtBlock *block = stmt->getBlock();
614 if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700615 return "must be the last statement in the ML function";
616
617 // Return success. Checking that operand types match those in the function
618 // signature is performed in the ML function verifier.
619 return nullptr;
620 }
Tatiana Shpeismanb697ac12018-08-09 23:21:19 -0700621 return "cannot occur in a CFG function";
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700622}
623
624//===----------------------------------------------------------------------===//
625// StoreOp
626//===----------------------------------------------------------------------===//
627
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700628void StoreOp::print(OpAsmPrinter *p) const {
629 *p << "store " << *getValueToStore();
630 *p << ", " << *getMemRef() << '[';
631 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700632 *p << ']';
633 p->printOptionalAttrDict(getAttrs());
634 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700635}
636
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700637bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700638 OpAsmParser::OperandType storeValueInfo;
639 OpAsmParser::OperandType memrefInfo;
640 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700641 MemRefType *memrefType;
642
643 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700644 return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
645 parser->parseOperand(memrefInfo) ||
646 parser->parseOperandList(indexInfo, -1,
647 OpAsmParser::Delimiter::Square) ||
648 parser->parseOptionalAttributeDict(result->attributes) ||
649 parser->parseColonType(memrefType) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700650 parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
651 result->operands) ||
652 parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700653 parser->resolveOperands(indexInfo, affineIntTy, result->operands);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700654}
655
656const char *StoreOp::verify() const {
657 if (getNumOperands() < 2)
658 return "expected a value to store and a memref";
659
660 // Second operand is a memref type.
661 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
662 if (!memRefType)
663 return "second operand must be a memref";
664
665 // First operand must have same type as memref element type.
666 if (getValueToStore()->getType() != memRefType->getElementType())
667 return "first operand must have same type memref element type ";
668
669 if (getNumOperands() != 2 + memRefType->getRank())
670 return "store index operand count not equal to memref rank";
671
672 for (auto *idx : getIndices())
673 if (!idx->getType()->isAffineInt())
674 return "index to load must have 'affineint' type";
675
676 // TODO: Verify we have the right number of indices.
677
678 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
679 // result of an affine_apply.
680 return nullptr;
681}
682
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700683//===----------------------------------------------------------------------===//
684// Register operations.
685//===----------------------------------------------------------------------===//
686
Chris Lattnerff0d5902018-07-05 09:12:11 -0700687/// Install the standard operations in the specified operation set.
688void mlir::registerStandardOperations(OperationSet &opSet) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700689 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
690 ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700691 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700692}