blob: 66d3d05c885559b414246b9dda1e4bfe46f3469e [file] [log] [blame]
//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#include "mlir/IR/StandardOps.h"
MLIR Team3fa00ab2018-07-24 10:13:31 -070019#include "mlir/IR/AffineMap.h"
Chris Lattner85ee1512018-07-25 11:15:20 -070020#include "mlir/IR/Builders.h"
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070021#include "mlir/IR/OpImplementation.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070022#include "mlir/IR/OperationSet.h"
Chris Lattner9361fb32018-07-24 08:34:58 -070023#include "mlir/IR/SSAValue.h"
24#include "mlir/IR/Types.h"
Chris Lattnerff0d5902018-07-05 09:12:11 -070025#include "llvm/Support/raw_ostream.h"
26using namespace mlir;
27
MLIR Team554a8ad2018-07-30 13:08:05 -070028static void printDimAndSymbolList(Operation::const_operand_iterator begin,
29 Operation::const_operand_iterator end,
30 unsigned numDims, OpAsmPrinter *p) {
31 *p << '(';
32 p->printOperands(begin, begin + numDims);
33 *p << ')';
34
35 if (begin + numDims != end) {
36 *p << '[';
37 p->printOperands(begin + numDims, end);
38 *p << ']';
39 }
40}
41
42// Parses dimension and symbol list, and sets 'numDims' to the number of
43// dimension operands parsed.
44// Returns 'false' on success and 'true' on error.
45static bool
46parseDimAndSymbolList(OpAsmParser *parser,
47 SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
48 SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
49 if (parser->parseOperandList(opInfos, -1,
50 OpAsmParser::Delimeter::ParenDelimeter))
51 return true;
52 // Store number of dimensions for validation by caller.
53 numDims = opInfos.size();
54
55 // Parse the optional symbol operands.
56 auto *affineIntTy = parser->getBuilder().getAffineIntType();
57 if (parser->parseOperandList(
58 opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
59 parser->resolveOperands(opInfos, affineIntTy, operands))
60 return true;
61 return false;
62}
63
// TODO: Have verify functions return std::string to enable more descriptive
// error messages.
Chris Lattner85ee1512018-07-25 11:15:20 -070066OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
67 SmallVector<OpAsmParser::OperandType, 2> ops;
68 Type *type;
69 SSAValue *lhs, *rhs;
70 if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
71 parser->resolveOperand(ops[0], type, lhs) ||
72 parser->resolveOperand(ops[1], type, rhs))
73 return {};
74
75 return OpAsmParserResult({lhs, rhs}, type);
76}
77
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -070078void AddFOp::print(OpAsmPrinter *p) const {
79 *p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
80 << *getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -070081}
82
Chris Lattner21e67f62018-07-06 10:46:19 -070083// Return an error message on failure.
84const char *AddFOp::verify() const {
85 // TODO: Check that the types of the LHS and RHS match.
86 // TODO: This should be a refinement of TwoOperands.
87 // TODO: There should also be a OneResultWhoseTypeMatchesFirstOperand.
88 return nullptr;
89}
90
Chris Lattner3164ae62018-07-28 09:36:25 -070091OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
92 SmallVector<OpAsmParser::OperandType, 2> opInfos;
93 SmallVector<SSAValue *, 4> operands;
94
95 auto &builder = parser->getBuilder();
96 auto *affineIntTy = builder.getAffineIntType();
97
98 AffineMapAttr *mapAttr;
MLIR Team554a8ad2018-07-30 13:08:05 -070099 if (parser->parseAttribute(mapAttr))
Chris Lattner3164ae62018-07-28 09:36:25 -0700100 return {};
101
MLIR Team554a8ad2018-07-30 13:08:05 -0700102 unsigned numDims;
103 if (parseDimAndSymbolList(parser, opInfos, operands, numDims))
104 return {};
Chris Lattner3164ae62018-07-28 09:36:25 -0700105 auto *map = mapAttr->getValue();
MLIR Team554a8ad2018-07-30 13:08:05 -0700106
Chris Lattner3164ae62018-07-28 09:36:25 -0700107 if (map->getNumDims() != numDims ||
108 numDims + map->getNumSymbols() != opInfos.size()) {
109 parser->emitError(parser->getNameLoc(),
110 "dimension or symbol index mismatch");
111 return {};
112 }
113
114 SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
115 return OpAsmParserResult(
116 operands, resultTypes,
117 NamedAttribute(builder.getIdentifier("map"), mapAttr));
118}
119
120void AffineApplyOp::print(OpAsmPrinter *p) const {
121 auto *map = getAffineMap();
122 *p << "affine_apply " << *map;
MLIR Team554a8ad2018-07-30 13:08:05 -0700123 printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
Chris Lattner3164ae62018-07-28 09:36:25 -0700124}
125
126const char *AffineApplyOp::verify() const {
127 // Check that affine map attribute was specified.
128 auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
129 if (!affineMapAttr)
130 return "requires an affine map.";
131
132 // Check input and output dimensions match.
133 auto *map = affineMapAttr->getValue();
134
135 // Verify that operand count matches affine map dimension and symbol count.
136 if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
137 return "operand count and affine map dimension and symbol count must match";
138
139 // Verify that result count matches affine map result count.
140 if (getNumResults() != map->getNumResults())
141 return "result count and affine map result count must match";
142
143 return nullptr;
144}
145
MLIR Team554a8ad2018-07-30 13:08:05 -0700146void AllocOp::print(OpAsmPrinter *p) const {
147 MemRefType *type = cast<MemRefType>(getMemRef()->getType());
148 *p << "alloc";
149 // Print dynamic dimension operands.
150 printDimAndSymbolList(operand_begin(), operand_end(),
151 type->getNumDynamicDims(), p);
152 // Print memref type.
153 *p << " : " << *type;
154}
155
156OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
157 MemRefType *type;
158 SmallVector<SSAValue *, 4> operands;
159 SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
160
161 // Parse the dimension operands and optional symbol operands.
162 unsigned numDimOperands;
163 if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands))
164 return {};
165
166 // Parse memref type.
167 if (parser->parseColonType(type))
168 return {};
169
170 // Check numDynamicDims against number of question marks in memref type.
171 if (numDimOperands != type->getNumDynamicDims()) {
172 parser->emitError(parser->getNameLoc(),
173 "Dynamic dimensions count mismatch: dimension operand "
174 "count does not equal memref dynamic dimension count.");
175 return {};
176 }
177
178 // Check that the number of symbol operands matches the number of symbols in
179 // the first affinemap of the memref's affine map composition.
180 // Note that a memref must specify at least one affine map in the composition.
181 if ((operandsInfo.size() - numDimOperands) !=
182 type->getAffineMaps()[0]->getNumSymbols()) {
183 parser->emitError(parser->getNameLoc(),
184 "AffineMap symbol count mismatch: symbol operand "
185 "count does not equal memref affine map symbol count.");
186 return {};
187 }
188
189 return OpAsmParserResult(operands, type);
190}
191
192const char *AllocOp::verify() const {
193 // TODO(andydavis): Verify alloc.
194 return nullptr;
195}
196
Chris Lattner9361fb32018-07-24 08:34:58 -0700197/// The constant op requires an attribute, and furthermore requires that it
198/// matches the return type.
199const char *ConstantOp::verify() const {
200 auto *value = getValue();
201 if (!value)
202 return "requires a 'value' attribute";
203
204 auto *type = this->getType();
Chris Lattner1ec70572018-07-24 10:41:30 -0700205 if (isa<IntegerType>(type) || type->isAffineInt()) {
Chris Lattner9361fb32018-07-24 08:34:58 -0700206 if (!isa<IntegerAttr>(value))
207 return "requires 'value' to be an integer for an integer result type";
208 return nullptr;
209 }
210
211 if (isa<FunctionType>(type)) {
212 // TODO: Verify a function attr.
213 }
214
215 return "requires a result type that aligns with the 'value' attribute";
216}
217
Chris Lattner3da86ad2018-07-26 09:58:23 -0700218/// ConstantIntOp only matches values whose result type is an IntegerType or
219/// AffineInt.
Chris Lattner9361fb32018-07-24 08:34:58 -0700220bool ConstantIntOp::isClassFor(const Operation *op) {
221 return ConstantOp::isClassFor(op) &&
Chris Lattner3da86ad2018-07-26 09:58:23 -0700222 (isa<IntegerType>(op->getResult(0)->getType()) ||
223 op->getResult(0)->getType()->isAffineInt());
Chris Lattner9361fb32018-07-24 08:34:58 -0700224}
225
Chris Lattnerdd0c2ca2018-07-24 16:07:22 -0700226void DimOp::print(OpAsmPrinter *p) const {
227 *p << "dim " << *getOperand() << ", " << getIndex() << " : "
228 << *getOperand()->getType();
Chris Lattnerff0d5902018-07-05 09:12:11 -0700229}
230
Chris Lattner85ee1512018-07-25 11:15:20 -0700231OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
232 OpAsmParser::OperandType operandInfo;
233 IntegerAttr *indexAttr;
234 Type *type;
235 SSAValue *operand;
236 if (parser->parseOperand(operandInfo) || parser->parseComma() ||
237 parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
238 parser->resolveOperand(operandInfo, type, operand))
239 return {};
240
241 auto &builder = parser->getBuilder();
242 return OpAsmParserResult(
243 operand, builder.getAffineIntType(),
244 NamedAttribute(builder.getIdentifier("index"), indexAttr));
245}
246
/// Verifies that the op has an integer "index" attribute and that the index
/// is within the rank of a ranked tensor or memref operand. Unranked tensor
/// operands are accepted without a range check. Returns an error message on
/// failure and nullptr on success.
const char *DimOp::verify() const {
  // The index is carried as an integer attribute.
  auto indexAttr = getAttrOfType<IntegerAttr>("index");
  if (!indexAttr)
    return "requires an integer attribute named 'index'";
  const uint64_t index = static_cast<uint64_t>(indexAttr->getValue());

  auto *type = getOperand()->getType();

  if (auto *tensorType = dyn_cast<RankedTensorType>(type))
    return index < tensorType->getRank() ? nullptr : "index is out of range";

  if (auto *memrefType = dyn_cast<MemRefType>(type))
    return index < memrefType->getRank() ? nullptr : "index is out of range";

  // Unranked tensors are assumed to be in range.
  if (isa<UnrankedTensorType>(type))
    return nullptr;

  return "requires an operand with tensor or memref type";
}
270
Chris Lattner85ee1512018-07-25 11:15:20 -0700271void LoadOp::print(OpAsmPrinter *p) const {
272 *p << "load " << *getMemRef() << '[';
273 p->printOperands(getIndices());
274 *p << "] : " << *getMemRef()->getType();
275}
276
277OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
278 OpAsmParser::OperandType memrefInfo;
279 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
280 MemRefType *type;
281 SmallVector<SSAValue *, 4> operands;
282
283 auto affineIntTy = parser->getBuilder().getAffineIntType();
284 if (parser->parseOperand(memrefInfo) ||
285 parser->parseOperandList(indexInfo, -1,
286 OpAsmParser::Delimeter::SquareDelimeter) ||
287 parser->parseColonType(type) ||
288 parser->resolveOperands(memrefInfo, type, operands) ||
289 parser->resolveOperands(indexInfo, affineIntTy, operands))
290 return {};
291
292 return OpAsmParserResult(operands, type->getElementType());
293}
294
295const char *LoadOp::verify() const {
Chris Lattner3164ae62018-07-28 09:36:25 -0700296 if (getNumOperands() == 0)
297 return "expected a memref to load from";
Chris Lattner85ee1512018-07-25 11:15:20 -0700298
Chris Lattner3164ae62018-07-28 09:36:25 -0700299 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
300 if (!memRefType)
301 return "first operand must be a memref";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700302
Chris Lattner3164ae62018-07-28 09:36:25 -0700303 for (auto *idx : getIndices())
304 if (!idx->getType()->isAffineInt())
305 return "index to load must have 'affineint' type";
MLIR Team3fa00ab2018-07-24 10:13:31 -0700306
Chris Lattner3164ae62018-07-28 09:36:25 -0700307 // TODO: Verify we have the right number of indices.
MLIR Team39a3a602018-07-24 17:43:56 -0700308
Chris Lattner3164ae62018-07-28 09:36:25 -0700309 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
310 // result of an affine_apply.
MLIR Team3fa00ab2018-07-24 10:13:31 -0700311 return nullptr;
312}
313
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700314void StoreOp::print(OpAsmPrinter *p) const {
315 *p << "store " << *getValueToStore();
316 *p << ", " << *getMemRef() << '[';
317 p->printOperands(getIndices());
318 *p << "] : " << *getMemRef()->getType();
319}
320
321OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
322 OpAsmParser::OperandType storeValueInfo;
323 OpAsmParser::OperandType memrefInfo;
324 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
325 SmallVector<SSAValue *, 4> operands;
326 MemRefType *memrefType;
327
328 auto affineIntTy = parser->getBuilder().getAffineIntType();
329 if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
330 parser->parseOperand(memrefInfo) ||
331 parser->parseOperandList(indexInfo, -1,
332 OpAsmParser::Delimeter::SquareDelimeter) ||
333 parser->parseColonType(memrefType) ||
334 parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
335 operands) ||
336 parser->resolveOperands(memrefInfo, memrefType, operands) ||
337 parser->resolveOperands(indexInfo, affineIntTy, operands))
338 return {};
339
340 return OpAsmParserResult(operands, {});
341}
342
343const char *StoreOp::verify() const {
344 if (getNumOperands() < 2)
345 return "expected a value to store and a memref";
346
347 // Second operand is a memref type.
348 auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
349 if (!memRefType)
350 return "second operand must be a memref";
351
352 // First operand must have same type as memref element type.
353 if (getValueToStore()->getType() != memRefType->getElementType())
354 return "first operand must have same type memref element type ";
355
356 if (getNumOperands() != 2 + memRefType->getRank())
357 return "store index operand count not equal to memref rank";
358
359 for (auto *idx : getIndices())
360 if (!idx->getType()->isAffineInt())
361 return "index to load must have 'affineint' type";
362
363 // TODO: Verify we have the right number of indices.
364
365 // TODO: in MLFunction verify that the indices are parameters, IV's, or the
366 // result of an affine_apply.
367 return nullptr;
368}
369
Chris Lattnerff0d5902018-07-05 09:12:11 -0700370/// Install the standard operations in the specified operation set.
371void mlir::registerStandardOperations(OperationSet &opSet) {
MLIR Teamc5efe7e2018-07-31 14:11:38 -0700372 opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
373 StoreOp>(
374 /*prefix=*/"");
Chris Lattnerff0d5902018-07-05 09:12:11 -0700375}