blob: 3d74e58b64cfe2520d6e59ad13d6cefe2100aaf9 [file] [log] [blame]
Chris Lattnerff0d5902018-07-05 09:12:11 -07001//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team554a8ad2018-07-30 13:08:05 -070028static void printDimAndSymbolList(Operation::const_operand_iterator begin,
29 Operation::const_operand_iterator end,
30 unsigned numDims, OpAsmPrinter *p) {
31 *p << '(';
32 p->printOperands(begin, begin + numDims);
33 *p << ')';
34
35 if (begin + numDims != end) {
36 *p << '[';
37 p->printOperands(begin + numDims, end);
38 *p << ']';
39 }
40}
41
42// Parses dimension and symbol list, and sets 'numDims' to the number of
43// dimension operands parsed.
44// Returns 'false' on success and 'true' on error.
45static bool
46parseDimAndSymbolList(OpAsmParser *parser,
47 SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
48 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
49 if (parser->parseOperandList(opInfos, -1,
50 OpAsmParser::Delimeter::ParenDelimeter))
51 return true;
52 // Store number of dimensions for validation by caller.
53 numDims = opInfos.size();
54
55 // Parse the optional symbol operands.
56 auto *affineIntTy = parser->getBuilder().getAffineIntType();
57 if (parser->parseOperandList(
58 opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
59 parser->resolveOperands(opInfos, affineIntTy, operands))
60 return true;
61 return false;
62}
63
// TODO: Have the op verify() hooks return std::string instead of const char *
// to enable more descriptive error messages.
Chris Lattner85ee1512018-07-25 11:15:20 -070066OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
67 SmallVector<OpAsmParser::OperandType, 2> ops;
68 Type *type;
69 SSAValue *lhs, *rhs;
70 if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
71 parser->resolveOperand(ops[0], type, lhs) ||
72 parser->resolveOperand(ops[1], type, rhs))
73 return {};
74
75 return OpAsmParserResult({lhs, rhs}, type);
76}
77
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070078void AddFOp::print(OpAsmPrinter *p) const {
79 *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
80 << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070081}
82
Chris Lattner21e67f62018-07-06 10:46:19 -070083// Return an error message on failure.
84const char *AddFOp::verify() const {
85 // TODO: Check that the types of the LHS and RHS match.
86 // TODO: This should be a refinement of TwoOperands.
87 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
88 return nullptr;
89}
90
Chris Lattner3164ae62018-07-28 09:36:25 -070091OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
92 SmallVector<OpAsmParser::OperandType, 2> opInfos;
93 SmallVector<SSAValue *, 4> operands;
94
95 auto &builder = parser->getBuilder();
96 auto *affineIntTy = builder.getAffineIntType();
97
98 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -070099 unsigned numDims;
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700100 if (parser->parseAttribute(mapAttr) ||
101 parseDimAndSymbolList(parser, opInfos, operands, numDims))
MLIR Team554a8ad2018-07-30 13:08:05 -0700102 return {};
Chris Lattner3164ae62018-07-28 09:36:25 -0700103 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700104
Chris Lattner3164ae62018-07-28 09:36:25 -0700105 if (map->getNumDims() != numDims ||
106 numDims + map->getNumSymbols() != opInfos.size()) {
107 parser->emitError(parser->getNameLoc(),
108 "dimension or symbol index mismatch");
109 return {};
110 }
111
112 SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
113 return OpAsmParserResult(
114 operands, resultTypes,
115 NamedAttribute(builder.getIdentifier("map"), mapAttr));
116}
117
118void AffineApplyOp::print(OpAsmPrinter *p) const {
119 auto *map = getAffineMap();
120 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700121 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner3164ae62018-07-28 09:36:25 -0700122}
123
124const char *AffineApplyOp::verify() const {
125 // Check that affine map attribute was specified.
126 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
127 if (!affineMapAttr)
128 return "requires an affine map.";
129
130 // Check input and output dimensions match.
131 auto *map = affineMapAttr->getValue();
132
133 // Verify that operand count matches affine map dimension and symbol count.
134 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
135 return "operand count and affine map dimension and symbol count must match";
136
137 // Verify that result count matches affine map result count.
138 if (getNumResults() != map->getNumResults())
139 return "result count and affine map result count must match";
140
141 return nullptr;
142}
143
MLIR Team554a8ad2018-07-30 13:08:05 -0700144void AllocOp::print(OpAsmPrinter *p) const {
145 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
146 *p << "alloc";
147 // Print dynamic dimension operands.
148 printDimAndSymbolList(operand_begin(), operand_end(),
149 type->getNumDynamicDims(), p);
150 // Print memref type.
151 *p << " : " << *type;
152}
153
154OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
155 MemRefType *type;
156 SmallVector<SSAValue *, 4> operands;
157 SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
158
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700159 // Parse the dimension operands and optional symbol operands, followed by a
160 // memref type.
MLIR Team554a8ad2018-07-30 13:08:05 -0700161 unsigned numDimOperands;
Chris Lattner7d3b77c2018-07-31 16:21:36 -0700162 if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
163 parser->parseColonType(type))
MLIR Team554a8ad2018-07-30 13:08:05 -0700164 return {};
165
166 // Check numDynamicDims against number of question marks in memref type.
167 if (numDimOperands != type->getNumDynamicDims()) {
168 parser->emitError(parser->getNameLoc(),
169 "Dynamic dimensions count mismatch: dimension operand "
170 "count does not equal memref dynamic dimension count.");
171 return {};
172 }
173
174 // Check that the number of symbol operands matches the number of symbols in
175 // the first affinemap of the memref's affine map composition.
176 // Note that a memref must specify at least one affine map in the composition.
177 if ((operandsInfo.size() - numDimOperands) !=
178 type->getAffineMaps()[0]->getNumSymbols()) {
179 parser->emitError(parser->getNameLoc(),
180 "AffineMap symbol count mismatch: symbol operand "
181 "count does not equal memref affine map symbol count.");
182 return {};
183 }
184
185 return OpAsmParserResult(operands, type);
186}
187
188const char *AllocOp::verify() const {
189 // TODO(andydavis): Verify alloc.
190 return nullptr;
191}
192
Chris Lattner9361fb32018-07-24 08:34:58 -0700193/// The constant op requires an attribute, and furthermore requires that it
194/// matches the return type.
195const char *ConstantOp::verify() const {
196 auto *value = getValue();
197 if (!value)
198 return "requires a 'value' attribute";
199
200 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700201 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700202 if (!isa<IntegerAttr>(value))
203 return "requires 'value' to be an integer for an integer result type";
204 return nullptr;
205 }
206
207 if (isa<FunctionType>(type)) {
208 // TODO: Verify a function attr.
209 }
210
211 return "requires a result type that aligns with the 'value' attribute";
212}
213
Chris Lattner3da86ad2018-07-26 09:58:23 -0700214/// ConstantIntOp only matches values whose result type is an IntegerType or
215/// AffineInt.
Chris Lattner9361fb32018-07-24 08:34:58 -0700216bool ConstantIntOp::isClassFor(const Operation *op) {
217 return ConstantOp::isClassFor(op) &&
Chris Lattner3da86ad2018-07-26 09:58:23 -0700218 (isa<IntegerType>(op->getResult(0)->getType()) ||
219 op->getResult(0)->getType()->isAffineInt());
Chris Lattner9361fb32018-07-24 08:34:58 -0700220}
221
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700222void DimOp::print(OpAsmPrinter *p) const {
223 *p << "dim " << *getOperand() << ", " << getIndex() << " : "
224 << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700225}
226
Chris Lattner85ee1512018-07-25 11:15:20 -0700227OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
228 OpAsmParser::OperandType operandInfo;
229 IntegerAttr *indexAttr;
230 Type *type;
231 SSAValue *operand;
232 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
233 parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
234 parser->resolveOperand(operandInfo, type, operand))
235 return {};
236
237 auto &builder = parser->getBuilder();
238 return OpAsmParserResult(
239 operand, builder.getAffineIntType(),
240 NamedAttribute(builder.getIdentifier("index"), indexAttr));
241}
242
Chris Lattner21e67f62018-07-06 10:46:19 -0700243const char *DimOp::verify() const {
Chris Lattner21e67f62018-07-06 10:46:19 -0700244 // Check that we have an integer index operand.
245 auto indexAttr = getAttrOfType<IntegerAttr>("index");
246 if (!indexAttr)
Chris Lattner9361fb32018-07-24 08:34:58 -0700247 return "requires an integer attribute named 'index'";
248 uint64_t index = (uint64_t)indexAttr->getValue();
Chris Lattner21e67f62018-07-06 10:46:19 -0700249
Chris Lattner9361fb32018-07-24 08:34:58 -0700250 auto *type = getOperand()->getType();
251 if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
252 if (index >= tensorType->getRank())
253 return "index is out of range";
254 } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
255 if (index >= memrefType->getRank())
256 return "index is out of range";
257
258 } else if (isa<UnrankedTensorType>(type)) {
259 // ok, assumed to be in-range.
260 } else {
261 return "requires an operand with tensor or memref type";
262 }
Chris Lattner21e67f62018-07-06 10:46:19 -0700263
264 return nullptr;
265}
266
Chris Lattner85ee1512018-07-25 11:15:20 -0700267void LoadOp::print(OpAsmPrinter *p) const {
268 *p << "load " << *getMemRef() << '[';
269 p->printOperands(getIndices());
270 *p << "] : " << *getMemRef()->getType();
271}
272
273OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
274 OpAsmParser::OperandType memrefInfo;
275 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
276 MemRefType *type;
277 SmallVector<SSAValue *, 4> operands;
278
279 auto affineIntTy = parser->getBuilder().getAffineIntType();
280 if (parser->parseOperand(memrefInfo) ||
281 parser->parseOperandList(indexInfo, -1,
282 OpAsmParser::Delimeter::SquareDelimeter) ||
283 parser->parseColonType(type) ||
284 parser->resolveOperands(memrefInfo, type, operands) ||
285 parser->resolveOperands(indexInfo, affineIntTy, operands))
286 return {};
287
288 return OpAsmParserResult(operands, type->getElementType());
289}
290
291const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700292 if (getNumOperands() == 0)
293 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700294
Chris Lattner3164ae62018-07-28 09:36:25 -0700295 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
296 if (!memRefType)
297 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700298
Chris Lattner3164ae62018-07-28 09:36:25 -0700299 for (auto *idx : getIndices())
300 if (!idx->getType()->isAffineInt())
301 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700302
Chris Lattner3164ae62018-07-28 09:36:25 -0700303 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700304
Chris Lattner3164ae62018-07-28 09:36:25 -0700305 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
306 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700307 return nullptr;
308}
309
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700310void StoreOp::print(OpAsmPrinter *p) const {
311 *p << "store " << *getValueToStore();
312 *p << ", " << *getMemRef() << '[';
313 p->printOperands(getIndices());
314 *p << "] : " << *getMemRef()->getType();
315}
316
317OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
318 OpAsmParser::OperandType storeValueInfo;
319 OpAsmParser::OperandType memrefInfo;
320 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
321 SmallVector<SSAValue *, 4> operands;
322 MemRefType *memrefType;
323
324 auto affineIntTy = parser->getBuilder().getAffineIntType();
325 if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
326 parser->parseOperand(memrefInfo) ||
327 parser->parseOperandList(indexInfo, -1,
328 OpAsmParser::Delimeter::SquareDelimeter) ||
329 parser->parseColonType(memrefType) ||
330 parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
331 operands) ||
332 parser->resolveOperands(memrefInfo, memrefType, operands) ||
333 parser->resolveOperands(indexInfo, affineIntTy, operands))
334 return {};
335
336 return OpAsmParserResult(operands, {});
337}
338
339const char *StoreOp::verify() const {
340 if (getNumOperands() < 2)
341 return "expected a value to store and a memref";
342
343 // Second operand is a memref type.
344 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
345 if (!memRefType)
346 return "second operand must be a memref";
347
348 // First operand must have same type as memref element type.
349 if (getValueToStore()->getType() != memRefType->getElementType())
350 return "first operand must have same type memref element type ";
351
352 if (getNumOperands() != 2 + memRefType->getRank())
353 return "store index operand count not equal to memref rank";
354
355 for (auto *idx : getIndices())
356 if (!idx->getType()->isAffineInt())
357 return "index to load must have 'affineint' type";
358
359 // TODO: Verify we have the right number of indices.
360
361 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
362 // result of an affine_apply.
363 return nullptr;
364}
365
Chris Lattnerff0d5902018-07-05 09:12:11 -0700366/// Install the standard operations in the specified operation set.
367void mlir::registerStandardOperations(OperationSet &opSet) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700368 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
369 StoreOp>(
370 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700371}