blob: 175d41144341d86b415cee4c233820dc68d777b1 [file] [log] [blame]
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070025#include "mlir/Support/STLExtras.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070026#include "llvm/Support/raw_ostream.h"
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -070027
Chris Lattnerff0d5902018-07-05 09:12:11 -070028using namespace mlir;
29
MLIR Team554a8ad2018-07-30 13:08:05 -070030static void printDimAndSymbolList(Operation::const_operand_iterator begin,
31 Operation::const_operand_iterator end,
32 unsigned numDims, OpAsmPrinter *p) {
33 *p << '(';
34 p->printOperands(begin, begin + numDims);
35 *p << ')';
36
37 if (begin + numDims != end) {
38 *p << '[';
39 p->printOperands(begin + numDims, end);
40 *p << ']';
41 }
42}
43
44// Parses dimension and symbol list, and sets 'numDims' to the number of
45// dimension operands parsed.
46// Returns 'false' on success and 'true' on error.
47static bool
48parseDimAndSymbolList(OpAsmParser *parser,
MLIR Team554a8ad2018-07-30 13:08:05 -070049 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -070050 SmallVector<OpAsmParser::OperandType, 8> opInfos;
Chris Lattner85cf26d2018-08-02 16:54:36 -070051 if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
MLIR Team554a8ad2018-07-30 13:08:05 -070052 return true;
53 // Store number of dimensions for validation by caller.
54 numDims = opInfos.size();
55
56 // Parse the optional symbol operands.
57 auto *affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner85cf26d2018-08-02 16:54:36 -070058 if (parser->parseOperandList(opInfos, -1,
59 OpAsmParser::Delimiter::OptionalSquare) ||
MLIR Team554a8ad2018-07-30 13:08:05 -070060 parser->resolveOperands(opInfos, affineIntTy, operands))
61 return true;
62 return false;
63}
64
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
68
Chris Lattner1eb77482018-08-22 19:25:49 -070069void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
70 SSAValue *rhs) {
71 assert(lhs->getType() == rhs->getType());
72 result->addOperands({lhs, rhs});
73 result->types.push_back(lhs->getType());
74}
75
Chris Lattnereed6c4d2018-08-07 09:12:35 -070076bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -070077 SmallVector<OpAsmParser::OperandType, 2> ops;
78 Type *type;
Chris Lattner8bdbebf2018-08-08 11:02:58 -070079 return parser->parseOperandList(ops, 2) ||
80 parser->parseOptionalAttributeDict(result->attributes) ||
81 parser->parseColonType(type) ||
82 parser->resolveOperands(ops, type, result->operands) ||
83 parser->addTypeToList(type, result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -070084}
85
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070086void AddFOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -070087 *p << "addf " << *getOperand(0) << ", " << *getOperand(1);
88 p->printOptionalAttrDict(getAttrs());
89 *p << " : " << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070090}
91
Chris Lattner015100b2018-09-09 20:40:23 -070092bool AddFOp::verify() const {
93 // TODO: check that the single type is a float type.
94 return false;
Chris Lattner21e67f62018-07-06 10:46:19 -070095}
96
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
100
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700101bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner3164ae62018-07-28 09:36:25 -0700102 auto &builder = parser->getBuilder();
103 auto *affineIntTy = builder.getAffineIntType();
104
105 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -0700106 unsigned numDims;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700107 if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
108 parseDimAndSymbolList(parser, result->operands, numDims) ||
109 parser->parseOptionalAttributeDict(result->attributes))
110 return true;
Chris Lattner3164ae62018-07-28 09:36:25 -0700111 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700112
Chris Lattner3164ae62018-07-28 09:36:25 -0700113 if (map->getNumDims() != numDims ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700114 numDims + map->getNumSymbols() != result->operands.size()) {
115 return parser->emitError(parser->getNameLoc(),
116 "dimension or symbol index mismatch");
Chris Lattner3164ae62018-07-28 09:36:25 -0700117 }
118
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700119 result->types.append(map->getNumResults(), affineIntTy);
120 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700121}
122
123void AffineApplyOp::print(OpAsmPrinter *p) const {
124 auto *map = getAffineMap();
125 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700126 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700127 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700128}
129
Chris Lattner015100b2018-09-09 20:40:23 -0700130bool AffineApplyOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700131 // Check that affine map attribute was specified.
132 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
133 if (!affineMapAttr)
Chris Lattner015100b2018-09-09 20:40:23 -0700134 return emitOpError("requires an affine map");
Chris Lattner3164ae62018-07-28 09:36:25 -0700135
136 // Check input and output dimensions match.
137 auto *map = affineMapAttr->getValue();
138
139 // Verify that operand count matches affine map dimension and symbol count.
140 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
Chris Lattner015100b2018-09-09 20:40:23 -0700141 return emitOpError(
142 "operand count and affine map dimension and symbol count must match");
Chris Lattner3164ae62018-07-28 09:36:25 -0700143
144 // Verify that result count matches affine map result count.
145 if (getNumResults() != map->getNumResults())
Chris Lattner015100b2018-09-09 20:40:23 -0700146 return emitOpError("result count and affine map result count must match");
Chris Lattner3164ae62018-07-28 09:36:25 -0700147
Chris Lattner015100b2018-09-09 20:40:23 -0700148 return false;
Chris Lattner3164ae62018-07-28 09:36:25 -0700149}
150
Tatiana Shpeismande8829f2018-08-24 23:38:14 -0700151// The result of the affine apply operation can be used as a dimension id if it
152// is a CFG value or if it is an MLValue, and all the operands are valid
153// dimension ids.
154bool AffineApplyOp::isValidDim() const {
155 for (auto *op : getOperands()) {
156 if (auto *v = dyn_cast<MLValue>(op))
157 if (!v->isValidDim())
158 return false;
159 }
160 return true;
161}
162
163// The result of the affine apply operation can be used as a symbol if it is
164// a CFG value or if it is an MLValue, and all the operands are symbols.
165bool AffineApplyOp::isValidSymbol() const {
166 for (auto *op : getOperands()) {
167 if (auto *v = dyn_cast<MLValue>(op))
168 if (!v->isValidSymbol())
169 return false;
170 }
171 return true;
172}
173
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
177
MLIR Team3802ebd2018-08-31 14:49:38 -0700178void AllocOp::build(Builder *builder, OperationState *result,
179 MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
180 result->addOperands(operands);
181 result->types.push_back(memrefType);
182}
183
MLIR Team554a8ad2018-07-30 13:08:05 -0700184void AllocOp::print(OpAsmPrinter *p) const {
185 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
186 *p << "alloc";
187 // Print dynamic dimension operands.
188 printDimAndSymbolList(operand_begin(), operand_end(),
189 type->getNumDynamicDims(), p);
Chris Lattner85cf26d2018-08-02 16:54:36 -0700190 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
MLIR Team554a8ad2018-07-30 13:08:05 -0700191 *p << " : " << *type;
192}
193
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700194bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Team554a8ad2018-07-30 13:08:05 -0700195 MemRefType *type;
MLIR Team554a8ad2018-07-30 13:08:05 -0700196
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700197 // Parse the dimension operands and optional symbol operands, followed by a
198 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700199 unsigned numDimOperands;
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700200 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
201 parser->parseOptionalAttributeDict(result->attributes) ||
202 parser->parseColonType(type))
203 return true;
MLIR Team554a8ad2018-07-30 13:08:05 -0700204
205 // Check numDynamicDims against number of question marks in memref type.
206 if (numDimOperands != type->getNumDynamicDims()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700207 return parser->emitError(parser->getNameLoc(),
208 "dimension operand count does not equal memref "
209 "dynamic dimension count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700210 }
211
212 // Check that the number of symbol operands matches the number of symbols in
213 // the first affinemap of the memref's affine map composition.
214 // Note that a memref must specify at least one affine map in the composition.
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700215 if (result->operands.size() - numDimOperands !=
MLIR Team554a8ad2018-07-30 13:08:05 -0700216 type->getAffineMaps()[0]->getNumSymbols()) {
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700217 return parser->emitError(
218 parser->getNameLoc(),
219 "affine map symbol operand count does not equal memref affine map "
220 "symbol count");
MLIR Team554a8ad2018-07-30 13:08:05 -0700221 }
222
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700223 result->types.push_back(type);
224 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700225}
226
Chris Lattner015100b2018-09-09 20:40:23 -0700227bool AllocOp::verify() const {
MLIR Team554a8ad2018-07-30 13:08:05 -0700228 // TODO(andydavis): Verify alloc.
Chris Lattner015100b2018-09-09 20:40:23 -0700229 return false;
MLIR Team554a8ad2018-07-30 13:08:05 -0700230}
231
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
235
Chris Lattner1eb77482018-08-22 19:25:49 -0700236void CallOp::build(Builder *builder, OperationState *result, Function *callee,
237 ArrayRef<SSAValue *> operands) {
238 result->addOperands(operands);
239 result->addAttribute("callee", builder->getFunctionAttr(callee));
240 result->addTypes(callee->getType()->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700241}
242
243bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
244 StringRef calleeName;
245 llvm::SMLoc calleeLoc;
246 FunctionType *calleeType = nullptr;
247 SmallVector<OpAsmParser::OperandType, 4> operands;
248 Function *callee = nullptr;
249 if (parser->parseFunctionName(calleeName, calleeLoc) ||
250 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
251 OpAsmParser::Delimiter::Paren) ||
252 parser->parseOptionalAttributeDict(result->attributes) ||
253 parser->parseColonType(calleeType) ||
254 parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
255 parser->addTypesToList(calleeType->getResults(), result->types) ||
256 parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
257 result->operands))
258 return true;
259
Chris Lattner1eb77482018-08-22 19:25:49 -0700260 result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
Chris Lattner1aa46322018-08-21 17:55:22 -0700261 return false;
262}
263
264void CallOp::print(OpAsmPrinter *p) const {
265 *p << "call ";
266 p->printFunctionReference(getCallee());
267 *p << '(';
268 p->printOperands(getOperands());
269 *p << ')';
270 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
271 *p << " : " << *getCallee()->getType();
272}
273
Chris Lattner015100b2018-09-09 20:40:23 -0700274bool CallOp::verify() const {
Chris Lattner1aa46322018-08-21 17:55:22 -0700275 // Check that the callee attribute was specified.
276 auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
277 if (!fnAttr)
Chris Lattner015100b2018-09-09 20:40:23 -0700278 return emitOpError("requires a 'callee' function attribute");
Chris Lattner1aa46322018-08-21 17:55:22 -0700279
280 // Verify that the operand and result types match the callee.
281 auto *fnType = fnAttr->getValue()->getType();
282 if (fnType->getNumInputs() != getNumOperands())
Chris Lattner015100b2018-09-09 20:40:23 -0700283 return emitOpError("incorrect number of operands for callee");
Chris Lattner1aa46322018-08-21 17:55:22 -0700284
285 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
286 if (getOperand(i)->getType() != fnType->getInput(i))
Chris Lattner015100b2018-09-09 20:40:23 -0700287 return emitOpError("operand type mismatch");
Chris Lattner1aa46322018-08-21 17:55:22 -0700288 }
289
290 if (fnType->getNumResults() != getNumResults())
Chris Lattner015100b2018-09-09 20:40:23 -0700291 return emitOpError("incorrect number of results for callee");
Chris Lattner1aa46322018-08-21 17:55:22 -0700292
293 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
294 if (getResult(i)->getType() != fnType->getResult(i))
Chris Lattner015100b2018-09-09 20:40:23 -0700295 return emitOpError("result type mismatch");
Chris Lattner1aa46322018-08-21 17:55:22 -0700296 }
297
Chris Lattner015100b2018-09-09 20:40:23 -0700298 return false;
Chris Lattner1aa46322018-08-21 17:55:22 -0700299}
300
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
304
Chris Lattner1eb77482018-08-22 19:25:49 -0700305void CallIndirectOp::build(Builder *builder, OperationState *result,
306 SSAValue *callee, ArrayRef<SSAValue *> operands) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700307 auto *fnType = cast<FunctionType>(callee->getType());
Chris Lattner1eb77482018-08-22 19:25:49 -0700308 result->operands.push_back(callee);
309 result->addOperands(operands);
310 result->addTypes(fnType->getResults());
Chris Lattner1aa46322018-08-21 17:55:22 -0700311}
312
313bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
314 FunctionType *calleeType = nullptr;
315 OpAsmParser::OperandType callee;
316 llvm::SMLoc operandsLoc;
317 SmallVector<OpAsmParser::OperandType, 4> operands;
318 return parser->parseOperand(callee) ||
319 parser->getCurrentLocation(&operandsLoc) ||
320 parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
321 OpAsmParser::Delimiter::Paren) ||
322 parser->parseOptionalAttributeDict(result->attributes) ||
323 parser->parseColonType(calleeType) ||
324 parser->resolveOperand(callee, calleeType, result->operands) ||
325 parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
326 result->operands) ||
327 parser->addTypesToList(calleeType->getResults(), result->types);
328}
329
330void CallIndirectOp::print(OpAsmPrinter *p) const {
331 *p << "call_indirect ";
332 p->printOperand(getCallee());
333 *p << '(';
334 auto operandRange = getOperands();
335 p->printOperands(++operandRange.begin(), operandRange.end());
336 *p << ')';
337 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
338 *p << " : " << *getCallee()->getType();
339}
340
Chris Lattner015100b2018-09-09 20:40:23 -0700341bool CallIndirectOp::verify() const {
Chris Lattner1aa46322018-08-21 17:55:22 -0700342 // The callee must be a function.
343 auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
344 if (!fnType)
Chris Lattner015100b2018-09-09 20:40:23 -0700345 return emitOpError("callee must have function type");
Chris Lattner1aa46322018-08-21 17:55:22 -0700346
347 // Verify that the operand and result types match the callee.
348 if (fnType->getNumInputs() != getNumOperands() - 1)
Chris Lattner015100b2018-09-09 20:40:23 -0700349 return emitOpError("incorrect number of operands for callee");
Chris Lattner1aa46322018-08-21 17:55:22 -0700350
351 for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
352 if (getOperand(i + 1)->getType() != fnType->getInput(i))
Chris Lattner015100b2018-09-09 20:40:23 -0700353 return emitOpError("operand type mismatch");
Chris Lattner1aa46322018-08-21 17:55:22 -0700354 }
355
356 if (fnType->getNumResults() != getNumResults())
Chris Lattner015100b2018-09-09 20:40:23 -0700357 return emitOpError("incorrect number of results for callee");
Chris Lattner1aa46322018-08-21 17:55:22 -0700358
359 for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
360 if (getResult(i)->getType() != fnType->getResult(i))
Chris Lattner015100b2018-09-09 20:40:23 -0700361 return emitOpError("result type mismatch");
Chris Lattner1aa46322018-08-21 17:55:22 -0700362 }
363
Chris Lattner015100b2018-09-09 20:40:23 -0700364 return false;
Chris Lattner1aa46322018-08-21 17:55:22 -0700365}
366
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
370
Chris Lattnerd4964212018-08-01 10:43:18 -0700371void ConstantOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700372 *p << "constant " << *getValue();
373 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
Chris Lattner4613d9e2018-08-19 21:17:22 -0700374
375 if (!isa<FunctionAttr>(getValue()))
376 *p << " : " << *getType();
Chris Lattnerd4964212018-08-01 10:43:18 -0700377}
378
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700379bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattnerd4964212018-08-01 10:43:18 -0700380 Attribute *valueAttr;
381 Type *type;
Chris Lattnerd4964212018-08-01 10:43:18 -0700382
Chris Lattner4613d9e2018-08-19 21:17:22 -0700383 if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
384 parser->parseOptionalAttributeDict(result->attributes))
385 return true;
386
387 // 'constant' taking a function reference doesn't get a redundant type
388 // specifier. The attribute itself carries it.
389 if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
390 return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
391
392 return parser->parseColonType(type) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700393 parser->addTypeToList(type, result->types);
Chris Lattnerd4964212018-08-01 10:43:18 -0700394}
395
Chris Lattner9361fb32018-07-24 08:34:58 -0700396/// The constant op requires an attribute, and furthermore requires that it
397/// matches the return type.
Chris Lattner015100b2018-09-09 20:40:23 -0700398bool ConstantOp::verify() const {
Chris Lattner9361fb32018-07-24 08:34:58 -0700399 auto *value = getValue();
400 if (!value)
Chris Lattner015100b2018-09-09 20:40:23 -0700401 return emitOpError("requires a 'value' attribute");
Chris Lattner9361fb32018-07-24 08:34:58 -0700402
403 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700404 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700405 if (!isa<IntegerAttr>(value))
Chris Lattner015100b2018-09-09 20:40:23 -0700406 return emitOpError(
407 "requires 'value' to be an integer for an integer result type");
408 return false;
Chris Lattner9361fb32018-07-24 08:34:58 -0700409 }
410
Chris Lattner7ba98c62018-08-16 16:56:40 -0700411 if (isa<FloatType>(type)) {
412 if (!isa<FloatAttr>(value))
Chris Lattner015100b2018-09-09 20:40:23 -0700413 return emitOpError("requires 'value' to be a floating point constant");
414 return false;
Chris Lattner7ba98c62018-08-16 16:56:40 -0700415 }
416
417 if (type->isTFString()) {
418 if (!isa<StringAttr>(value))
Chris Lattner015100b2018-09-09 20:40:23 -0700419 return emitOpError("requires 'value' to be a string constant");
420 return false;
Chris Lattner7ba98c62018-08-16 16:56:40 -0700421 }
422
Chris Lattner9361fb32018-07-24 08:34:58 -0700423 if (isa<FunctionType>(type)) {
Chris Lattner4613d9e2018-08-19 21:17:22 -0700424 if (!isa<FunctionAttr>(value))
Chris Lattner015100b2018-09-09 20:40:23 -0700425 return emitOpError("requires 'value' to be a function reference");
426 return false;
Chris Lattner9361fb32018-07-24 08:34:58 -0700427 }
428
Chris Lattner015100b2018-09-09 20:40:23 -0700429 return emitOpError(
430 "requires a result type that aligns with the 'value' attribute");
Chris Lattner9361fb32018-07-24 08:34:58 -0700431}
432
Chris Lattner1eb77482018-08-22 19:25:49 -0700433void ConstantFloatOp::build(Builder *builder, OperationState *result,
434 double value, FloatType *type) {
435 result->addAttribute("value", builder->getFloatAttr(value));
436 result->types.push_back(type);
Chris Lattner7ba98c62018-08-16 16:56:40 -0700437}
438
439bool ConstantFloatOp::isClassFor(const Operation *op) {
440 return ConstantOp::isClassFor(op) &&
441 isa<FloatType>(op->getResult(0)->getType());
442}
443
Chris Lattner992a1272018-08-07 12:02:37 -0700444/// ConstantIntOp only matches values whose result type is an IntegerType.
Chris Lattner9361fb32018-07-24 08:34:58 -0700445bool ConstantIntOp::isClassFor(const Operation *op) {
446 return ConstantOp::isClassFor(op) &&
Chris Lattner992a1272018-08-07 12:02:37 -0700447 isa<IntegerType>(op->getResult(0)->getType());
448}
449
Chris Lattner1eb77482018-08-22 19:25:49 -0700450void ConstantIntOp::build(Builder *builder, OperationState *result,
451 int64_t value, unsigned width) {
452 result->addAttribute("value", builder->getIntegerAttr(value));
453 result->types.push_back(builder->getIntegerType(width));
Chris Lattner992a1272018-08-07 12:02:37 -0700454}
455
456/// ConstantAffineIntOp only matches values whose result type is AffineInt.
457bool ConstantAffineIntOp::isClassFor(const Operation *op) {
458 return ConstantOp::isClassFor(op) &&
459 op->getResult(0)->getType()->isAffineInt();
460}
461
Chris Lattner1eb77482018-08-22 19:25:49 -0700462void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
463 int64_t value) {
464 result->addAttribute("value", builder->getIntegerAttr(value));
465 result->types.push_back(builder->getAffineIntType());
Chris Lattner9361fb32018-07-24 08:34:58 -0700466}
467
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
471
Chris Lattner1eb77482018-08-22 19:25:49 -0700472void AffineApplyOp::build(Builder *builder, OperationState *result,
473 AffineMap *map, ArrayRef<SSAValue *> operands) {
474 result->addOperands(operands);
475 result->types.append(map->getNumResults(), builder->getAffineIntType());
476 result->addAttribute("map", builder->getAffineMapAttr(map));
Uday Bondhugula67701712018-08-21 16:01:23 -0700477}
478
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
482
MLIR Team3802ebd2018-08-31 14:49:38 -0700483void DeallocOp::build(Builder *builder, OperationState *result,
484 SSAValue *memref) {
485 result->addOperands(memref);
486}
487
MLIR Team1989cc12018-08-15 15:39:26 -0700488void DeallocOp::print(OpAsmPrinter *p) const {
489 *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
490}
491
492bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
493 OpAsmParser::OperandType memrefInfo;
494 MemRefType *type;
495
496 return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
497 parser->resolveOperand(memrefInfo, type, result->operands);
498}
499
Chris Lattner015100b2018-09-09 20:40:23 -0700500bool DeallocOp::verify() const {
MLIR Team1989cc12018-08-15 15:39:26 -0700501 if (!isa<MemRefType>(getMemRef()->getType()))
Chris Lattner015100b2018-09-09 20:40:23 -0700502 return emitOpError("operand must be a memref");
503 return false;
MLIR Team1989cc12018-08-15 15:39:26 -0700504}
505
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
509
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700510void DimOp::print(OpAsmPrinter *p) const {
Chris Lattner85cf26d2018-08-02 16:54:36 -0700511 *p << "dim " << *getOperand() << ", " << getIndex();
512 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
513 *p << " : " << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700514}
515
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700516bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700517 OpAsmParser::OperandType operandInfo;
518 IntegerAttr *indexAttr;
519 Type *type;
Chris Lattner85cf26d2018-08-02 16:54:36 -0700520
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700521 return parser->parseOperand(operandInfo) || parser->parseComma() ||
522 parser->parseAttribute(indexAttr, "index", result->attributes) ||
523 parser->parseOptionalAttributeDict(result->attributes) ||
524 parser->parseColonType(type) ||
525 parser->resolveOperand(operandInfo, type, result->operands) ||
526 parser->addTypeToList(parser->getBuilder().getAffineIntType(),
527 result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700528}
529
Chris Lattner015100b2018-09-09 20:40:23 -0700530bool DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700531 // Check that we have an integer index operand.
532 auto indexAttr = getAttrOfType<IntegerAttr>("index");
533 if (!indexAttr)
Chris Lattner015100b2018-09-09 20:40:23 -0700534 return emitOpError("requires an integer attribute named 'index'");
Chris Lattner9361fb32018-07-24 08:34:58 -0700535 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700536
Chris Lattner9361fb32018-07-24 08:34:58 -0700537 auto *type = getOperand()->getType();
538 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
539 if (index >= tensorType->getRank())
Chris Lattner015100b2018-09-09 20:40:23 -0700540 return emitOpError("index is out of range");
Chris Lattner9361fb32018-07-24 08:34:58 -0700541 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
542 if (index >= memrefType->getRank())
Chris Lattner015100b2018-09-09 20:40:23 -0700543 return emitOpError("index is out of range");
Chris Lattner9361fb32018-07-24 08:34:58 -0700544
545 } else if (isa<UnrankedTensorType>(type)) {
546 // ok, assumed to be in-range.
547 } else {
Chris Lattner015100b2018-09-09 20:40:23 -0700548 return emitOpError("requires an operand with tensor or memref type");
Chris Lattner9361fb32018-07-24 08:34:58 -0700549 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700550
Chris Lattner015100b2018-09-09 20:40:23 -0700551 return false;
Chris Lattner21e67f62018-07-06 10:46:19 -0700552}
553
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
557
558void ExtractElementOp::build(Builder *builder, OperationState *result,
559 SSAValue *aggregate,
560 ArrayRef<SSAValue *> indices) {
561 auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
562 result->addOperands(aggregate);
563 result->addOperands(indices);
564 result->types.push_back(aggregateType->getElementType());
565}
566
567void ExtractElementOp::print(OpAsmPrinter *p) const {
568 *p << "extract_element " << *getAggregate() << '[';
569 p->printOperands(getIndices());
570 *p << ']';
571 p->printOptionalAttrDict(getAttrs());
572 *p << " : " << *getAggregate()->getType();
573}
574
575bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
576 OpAsmParser::OperandType aggregateInfo;
577 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
578 VectorOrTensorType *type;
579
580 auto affineIntTy = parser->getBuilder().getAffineIntType();
581 return parser->parseOperand(aggregateInfo) ||
582 parser->parseOperandList(indexInfo, -1,
583 OpAsmParser::Delimiter::Square) ||
584 parser->parseOptionalAttributeDict(result->attributes) ||
585 parser->parseColonType(type) ||
586 parser->resolveOperand(aggregateInfo, type, result->operands) ||
587 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
588 parser->addTypeToList(type->getElementType(), result->types);
589}
590
Chris Lattner015100b2018-09-09 20:40:23 -0700591bool ExtractElementOp::verify() const {
Chris Lattner8c7feba2018-08-23 09:58:23 -0700592 if (getNumOperands() == 0)
Chris Lattner015100b2018-09-09 20:40:23 -0700593 return emitOpError("expected an aggregate to index into");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700594
595 auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
596 if (!aggregateType)
Chris Lattner015100b2018-09-09 20:40:23 -0700597 return emitOpError("first operand must be a vector or tensor");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700598
599 if (getResult()->getType() != aggregateType->getElementType())
Chris Lattner015100b2018-09-09 20:40:23 -0700600 return emitOpError("result type must match element type of aggregate");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700601
602 for (auto *idx : getIndices())
603 if (!idx->getType()->isAffineInt())
Chris Lattner015100b2018-09-09 20:40:23 -0700604 return emitOpError("index to extract_element must have 'affineint' type");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700605
606 // Verify the # indices match if we have a ranked type.
607 auto aggregateRank = aggregateType->getRankIfPresent();
608 if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
Chris Lattner015100b2018-09-09 20:40:23 -0700609 return emitOpError("incorrect number of indices for extract_element");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700610
Chris Lattner015100b2018-09-09 20:40:23 -0700611 return false;
Chris Lattner8c7feba2018-08-23 09:58:23 -0700612}
613
614//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700615// LoadOp
616//===----------------------------------------------------------------------===//
617
Chris Lattner8c7feba2018-08-23 09:58:23 -0700618void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
619 ArrayRef<SSAValue *> indices) {
620 auto *memrefType = cast<MemRefType>(memref->getType());
621 result->addOperands(memref);
622 result->addOperands(indices);
623 result->types.push_back(memrefType->getElementType());
624}
625
Chris Lattner85ee1512018-07-25 11:15:20 -0700626void LoadOp::print(OpAsmPrinter *p) const {
627 *p << "load " << *getMemRef() << '[';
628 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700629 *p << ']';
630 p->printOptionalAttrDict(getAttrs());
631 *p << " : " << *getMemRef()->getType();
Chris Lattner85ee1512018-07-25 11:15:20 -0700632}
633
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700634bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
Chris Lattner85ee1512018-07-25 11:15:20 -0700635 OpAsmParser::OperandType memrefInfo;
636 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
637 MemRefType *type;
Chris Lattner85ee1512018-07-25 11:15:20 -0700638
639 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700640 return parser->parseOperand(memrefInfo) ||
641 parser->parseOperandList(indexInfo, -1,
642 OpAsmParser::Delimiter::Square) ||
643 parser->parseOptionalAttributeDict(result->attributes) ||
644 parser->parseColonType(type) ||
645 parser->resolveOperand(memrefInfo, type, result->operands) ||
646 parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
647 parser->addTypeToList(type->getElementType(), result->types);
Chris Lattner85ee1512018-07-25 11:15:20 -0700648}
649
Chris Lattner015100b2018-09-09 20:40:23 -0700650bool LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700651 if (getNumOperands() == 0)
Chris Lattner015100b2018-09-09 20:40:23 -0700652 return emitOpError("expected a memref to load from");
Chris Lattner85ee1512018-07-25 11:15:20 -0700653
Chris Lattner3164ae62018-07-28 09:36:25 -0700654 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
655 if (!memRefType)
Chris Lattner015100b2018-09-09 20:40:23 -0700656 return emitOpError("first operand must be a memref");
MLIR Team3fa00ab2018-07-24 10:13:31 -0700657
Chris Lattner8c7feba2018-08-23 09:58:23 -0700658 if (getResult()->getType() != memRefType->getElementType())
Chris Lattner015100b2018-09-09 20:40:23 -0700659 return emitOpError("result type must match element type of memref");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700660
661 if (memRefType->getRank() != getNumOperands() - 1)
Chris Lattner015100b2018-09-09 20:40:23 -0700662 return emitOpError("incorrect number of indices for load");
Chris Lattner8c7feba2018-08-23 09:58:23 -0700663
Chris Lattner3164ae62018-07-28 09:36:25 -0700664 for (auto *idx : getIndices())
665 if (!idx->getType()->isAffineInt())
Chris Lattner015100b2018-09-09 20:40:23 -0700666 return emitOpError("index to load must have 'affineint' type");
MLIR Team3fa00ab2018-07-24 10:13:31 -0700667
Chris Lattner3164ae62018-07-28 09:36:25 -0700668 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700669
Chris Lattner3164ae62018-07-28 09:36:25 -0700670 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
671 // result of an affine_apply.
Chris Lattner015100b2018-09-09 20:40:23 -0700672 return false;
MLIR Team3fa00ab2018-07-24 10:13:31 -0700673}
674
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700675//===----------------------------------------------------------------------===//
676// ReturnOp
677//===----------------------------------------------------------------------===//
678
679bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
680 SmallVector<OpAsmParser::OperandType, 2> opInfo;
681 SmallVector<Type *, 2> types;
Chris Lattner1aa46322018-08-21 17:55:22 -0700682 llvm::SMLoc loc;
683 return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700684 (!opInfo.empty() && parser->parseColonTypeList(types)) ||
Chris Lattner1aa46322018-08-21 17:55:22 -0700685 parser->resolveOperands(opInfo, types, loc, result->operands);
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700686}
687
688void ReturnOp::print(OpAsmPrinter *p) const {
689 *p << "return";
690 if (getNumOperands() > 0) {
691 *p << " ";
692 p->printOperands(operand_begin(), operand_end());
693 *p << " : ";
694 interleave(operand_begin(), operand_end(),
MLIR Team6a220a62018-09-07 12:34:19 -0700695 [&](const SSAValue *e) { p->printType(e->getType()); },
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700696 [&]() { *p << ", "; });
697 }
698}
699
Chris Lattner015100b2018-09-09 20:40:23 -0700700bool ReturnOp::verify() const {
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700701 // ReturnOp must be part of an ML function.
702 if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
Tatiana Shpeisman3abd6bd2018-08-16 20:19:44 -0700703 StmtBlock *block = stmt->getBlock();
704 if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
Chris Lattner015100b2018-09-09 20:40:23 -0700705 return emitOpError("must be the last statement in the ML function");
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700706
707 // Return success. Checking that operand types match those in the function
708 // signature is performed in the ML function verifier.
Chris Lattner015100b2018-09-09 20:40:23 -0700709 return false;
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700710 }
Chris Lattner015100b2018-09-09 20:40:23 -0700711 return emitOpError("cannot occur in a CFG function");
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700712}
713
714//===----------------------------------------------------------------------===//
Chris Lattner0473e4b2018-09-06 17:31:21 -0700715// ShapeCastOp
716//===----------------------------------------------------------------------===//
717
718void ShapeCastOp::build(Builder *builder, OperationState *result,
719 SSAValue *input, Type *resultType) {
720 result->addOperands(input);
721 result->addTypes(resultType);
722}
723
Chris Lattner015100b2018-09-09 20:40:23 -0700724bool ShapeCastOp::verify() const {
Chris Lattner0473e4b2018-09-06 17:31:21 -0700725 auto *opType = dyn_cast<TensorType>(getOperand()->getType());
726 auto *resType = dyn_cast<TensorType>(getResult()->getType());
727 if (!opType || !resType)
Chris Lattner015100b2018-09-09 20:40:23 -0700728 return emitOpError("requires input and result types to be tensors");
Chris Lattner0473e4b2018-09-06 17:31:21 -0700729
730 if (opType == resType)
Chris Lattner015100b2018-09-09 20:40:23 -0700731 return emitOpError("requires the input and result type to be different");
Chris Lattner0473e4b2018-09-06 17:31:21 -0700732
733 if (opType->getElementType() != resType->getElementType())
Chris Lattner015100b2018-09-09 20:40:23 -0700734 return emitOpError(
735 "requires input and result element types to be the same");
Chris Lattner0473e4b2018-09-06 17:31:21 -0700736
737 // If the source or destination are unranked, then the cast is valid.
738 auto *opRType = dyn_cast<RankedTensorType>(opType);
739 auto *resRType = dyn_cast<RankedTensorType>(resType);
740 if (!opRType || !resRType)
Chris Lattner015100b2018-09-09 20:40:23 -0700741 return false;
Chris Lattner0473e4b2018-09-06 17:31:21 -0700742
743 // If they are both ranked, they have to have the same rank, and any specified
744 // dimensions must match.
745 if (opRType->getRank() != resRType->getRank())
Chris Lattner015100b2018-09-09 20:40:23 -0700746 return emitOpError("requires input and result ranks to match");
Chris Lattner0473e4b2018-09-06 17:31:21 -0700747
748 for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
749 int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
750 if (opDim != -1 && resultDim != -1 && opDim != resultDim)
Chris Lattner015100b2018-09-09 20:40:23 -0700751 return emitOpError("requires static dimensions to match");
Chris Lattner0473e4b2018-09-06 17:31:21 -0700752 }
753
Chris Lattner015100b2018-09-09 20:40:23 -0700754 return false;
Chris Lattner0473e4b2018-09-06 17:31:21 -0700755}
756
757//===----------------------------------------------------------------------===//
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700758// StoreOp
759//===----------------------------------------------------------------------===//
760
MLIR Team3802ebd2018-08-31 14:49:38 -0700761void StoreOp::build(Builder *builder, OperationState *result,
762 SSAValue *valueToStore, SSAValue *memref,
763 ArrayRef<SSAValue *> indices) {
764 result->addOperands(valueToStore);
765 result->addOperands(memref);
766 result->addOperands(indices);
767}
768
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700769void StoreOp::print(OpAsmPrinter *p) const {
770 *p << "store " << *getValueToStore();
771 *p << ", " << *getMemRef() << '[';
772 p->printOperands(getIndices());
Chris Lattner85cf26d2018-08-02 16:54:36 -0700773 *p << ']';
774 p->printOptionalAttrDict(getAttrs());
775 *p << " : " << *getMemRef()->getType();
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700776}
777
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700778bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700779 OpAsmParser::OperandType storeValueInfo;
780 OpAsmParser::OperandType memrefInfo;
781 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700782 MemRefType *memrefType;
783
784 auto affineIntTy = parser->getBuilder().getAffineIntType();
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700785 return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
786 parser->parseOperand(memrefInfo) ||
787 parser->parseOperandList(indexInfo, -1,
788 OpAsmParser::Delimiter::Square) ||
789 parser->parseOptionalAttributeDict(result->attributes) ||
790 parser->parseColonType(memrefType) ||
Chris Lattner8bdbebf2018-08-08 11:02:58 -0700791 parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
792 result->operands) ||
793 parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
Chris Lattnereed6c4d2018-08-07 09:12:35 -0700794 parser->resolveOperands(indexInfo, affineIntTy, result->operands);
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700795}
796
Chris Lattner015100b2018-09-09 20:40:23 -0700797bool StoreOp::verify() const {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700798 if (getNumOperands() < 2)
Chris Lattner015100b2018-09-09 20:40:23 -0700799 return emitOpError("expected a value to store and a memref");
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700800
801 // Second operand is a memref type.
802 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
803 if (!memRefType)
Chris Lattner015100b2018-09-09 20:40:23 -0700804 return emitOpError("second operand must be a memref");
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700805
806 // First operand must have same type as memref element type.
807 if (getValueToStore()->getType() != memRefType->getElementType())
Chris Lattner015100b2018-09-09 20:40:23 -0700808 return emitOpError("first operand must have same type memref element type");
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700809
810 if (getNumOperands() != 2 + memRefType->getRank())
Chris Lattner015100b2018-09-09 20:40:23 -0700811 return emitOpError("store index operand count not equal to memref rank");
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700812
813 for (auto *idx : getIndices())
814 if (!idx->getType()->isAffineInt())
Chris Lattner015100b2018-09-09 20:40:23 -0700815 return emitOpError("index to load must have 'affineint' type");
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700816
817 // TODO: Verify we have the right number of indices.
818
819 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
820 // result of an affine_apply.
Chris Lattner015100b2018-09-09 20:40:23 -0700821 return false;
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700822}
823
Tatiana Shpeismand9b1d862018-08-09 12:28:58 -0700824//===----------------------------------------------------------------------===//
825// Register operations.
826//===----------------------------------------------------------------------===//
827
Chris Lattnerff0d5902018-07-05 09:12:11 -0700828/// Install the standard operations in the specified operation set.
829void mlir::registerStandardOperations(OperationSet &opSet) {
Chris Lattner1aa46322018-08-21 17:55:22 -0700830 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
Chris Lattner8c7feba2018-08-23 09:58:23 -0700831 ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
Chris Lattner0473e4b2018-09-06 17:31:21 -0700832 ReturnOp, ShapeCastOp, StoreOp>(
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700833 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700834}