blob: 7ea9c156ea41bde7611bd16536a8a7f8b71a6623 [file] [log] [blame]
Nicolas Vasilachecca53e82019-07-15 02:50:09 -07001//===- Ops.cpp - Loop MLIR Operations -------------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/Dialect/LoopOps/LoopOps.h"
River Riddleba0fa922019-08-19 11:00:47 -070019#include "mlir/Dialect/StandardOps/Ops.h"
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070020#include "mlir/IR/AffineExpr.h"
21#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/Function.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/Module.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/IR/StandardTypes.h"
29#include "mlir/IR/Value.h"
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070030#include "mlir/Support/MathExtras.h"
31#include "mlir/Support/STLExtras.h"
32
33using namespace mlir;
34using namespace mlir::loop;
35
36//===----------------------------------------------------------------------===//
37// LoopOpsDialect
38//===----------------------------------------------------------------------===//
39
40LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
41 : Dialect(getDialectNamespace(), context) {
42 addOperations<
43#define GET_OP_LIST
44#include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"
45 >();
46}
47
48//===----------------------------------------------------------------------===//
49// ForOp
50//===----------------------------------------------------------------------===//
51
River Riddle729727e2019-09-20 19:47:05 -070052void ForOp::build(Builder *builder, OperationState &result, Value *lb,
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070053 Value *ub, Value *step) {
River Riddle729727e2019-09-20 19:47:05 -070054 result.addOperands({lb, ub, step});
55 Region *bodyRegion = result.addRegion();
56 ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070057 bodyRegion->front().addArgument(builder->getIndexType());
58}
59
60LogicalResult verify(ForOp op) {
61 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp()))
62 if (cst.getValue() <= 0)
63 return op.emitOpError("constant step operand must be nonnegative");
64
65 // Check that the body defines as single block argument for the induction
66 // variable.
Nicolas Vasilache0002e292019-07-16 12:20:15 -070067 auto *body = op.getBody();
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070068 if (body->getNumArguments() != 1 ||
69 !body->getArgument(0)->getType().isIndex())
70 return op.emitOpError("expected body to have a single index argument for "
71 "the induction variable");
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070072 return success();
73}
74
River Riddle3a643de2019-09-20 20:43:02 -070075static void print(OpAsmPrinter &p, ForOp op) {
76 p << op.getOperationName() << " " << *op.getInductionVar() << " = "
77 << *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step();
78 p.printRegion(op.region(),
79 /*printEntryBlockArgs=*/false,
80 /*printBlockTerminators=*/false);
81 p.printOptionalAttrDict(op.getAttrs());
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070082}
83
River Riddle729727e2019-09-20 19:47:05 -070084static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
River Riddle27975172019-09-20 11:36:49 -070085 auto &builder = parser.getBuilder();
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070086 OpAsmParser::OperandType inductionVariable, lb, ub, step;
87 // Parse the induction variable followed by '='.
River Riddle27975172019-09-20 11:36:49 -070088 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070089 return failure();
90
91 // Parse loop bounds.
92 Type indexType = builder.getIndexType();
River Riddle27975172019-09-20 11:36:49 -070093 if (parser.parseOperand(lb) ||
River Riddle729727e2019-09-20 19:47:05 -070094 parser.resolveOperand(lb, indexType, result.operands) ||
River Riddle27975172019-09-20 11:36:49 -070095 parser.parseKeyword("to") || parser.parseOperand(ub) ||
River Riddle729727e2019-09-20 19:47:05 -070096 parser.resolveOperand(ub, indexType, result.operands) ||
River Riddle27975172019-09-20 11:36:49 -070097 parser.parseKeyword("step") || parser.parseOperand(step) ||
River Riddle729727e2019-09-20 19:47:05 -070098 parser.resolveOperand(step, indexType, result.operands))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070099 return failure();
100
101 // Parse the body region.
River Riddle729727e2019-09-20 19:47:05 -0700102 Region *body = result.addRegion();
River Riddle27975172019-09-20 11:36:49 -0700103 if (parser.parseRegion(*body, inductionVariable, indexType))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700104 return failure();
105
River Riddle729727e2019-09-20 19:47:05 -0700106 ForOp::ensureTerminator(*body, builder, result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700107
108 // Parse the optional attribute list.
River Riddle729727e2019-09-20 19:47:05 -0700109 if (parser.parseOptionalAttributeDict(result.attributes))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700110 return failure();
111
112 return success();
113}
114
115ForOp mlir::loop::getForInductionVarOwner(Value *val) {
116 auto *ivArg = dyn_cast<BlockArgument>(val);
117 if (!ivArg)
118 return ForOp();
119 assert(ivArg->getOwner() && "unlinked block argument");
River Riddle1e429542019-08-09 20:07:25 -0700120 auto *containingInst = ivArg->getOwner()->getParentOp();
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700121 return dyn_cast_or_null<ForOp>(containingInst);
122}
123
124//===----------------------------------------------------------------------===//
125// IfOp
126//===----------------------------------------------------------------------===//
127
River Riddle729727e2019-09-20 19:47:05 -0700128void IfOp::build(Builder *builder, OperationState &result, Value *cond,
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700129 bool withElseRegion) {
River Riddle729727e2019-09-20 19:47:05 -0700130 result.addOperands(cond);
131 Region *thenRegion = result.addRegion();
132 Region *elseRegion = result.addRegion();
133 IfOp::ensureTerminator(*thenRegion, *builder, result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700134 if (withElseRegion)
River Riddle729727e2019-09-20 19:47:05 -0700135 IfOp::ensureTerminator(*elseRegion, *builder, result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700136}
137
138static LogicalResult verify(IfOp op) {
139 // Verify that the entry of each child region does not have arguments.
140 for (auto &region : op.getOperation()->getRegions()) {
141 if (region.empty())
142 continue;
143
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700144 for (auto &b : region)
145 if (b.getNumArguments() != 0)
146 return op.emitOpError(
147 "requires that child entry blocks have no arguments");
148 }
149 return success();
150}
151
River Riddle729727e2019-09-20 19:47:05 -0700152static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700153 // Create the regions for 'then'.
River Riddle729727e2019-09-20 19:47:05 -0700154 result.regions.reserve(2);
155 Region *thenRegion = result.addRegion();
156 Region *elseRegion = result.addRegion();
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700157
River Riddle27975172019-09-20 11:36:49 -0700158 auto &builder = parser.getBuilder();
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700159 OpAsmParser::OperandType cond;
160 Type i1Type = builder.getIntegerType(1);
River Riddle27975172019-09-20 11:36:49 -0700161 if (parser.parseOperand(cond) ||
River Riddle729727e2019-09-20 19:47:05 -0700162 parser.resolveOperand(cond, i1Type, result.operands))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700163 return failure();
164
165 // Parse the 'then' region.
River Riddle27975172019-09-20 11:36:49 -0700166 if (parser.parseRegion(*thenRegion, {}, {}))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700167 return failure();
River Riddle729727e2019-09-20 19:47:05 -0700168 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700169
170 // If we find an 'else' keyword then parse the 'else' region.
River Riddle27975172019-09-20 11:36:49 -0700171 if (!parser.parseOptionalKeyword("else")) {
172 if (parser.parseRegion(*elseRegion, {}, {}))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700173 return failure();
River Riddle729727e2019-09-20 19:47:05 -0700174 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700175 }
176
177 // Parse the optional attribute list.
River Riddle729727e2019-09-20 19:47:05 -0700178 if (parser.parseOptionalAttributeDict(result.attributes))
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700179 return failure();
180
181 return success();
182}
183
River Riddle3a643de2019-09-20 20:43:02 -0700184static void print(OpAsmPrinter &p, IfOp op) {
185 p << IfOp::getOperationName() << " " << *op.condition();
186 p.printRegion(op.thenRegion(),
187 /*printEntryBlockArgs=*/false,
188 /*printBlockTerminators=*/false);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700189
190 // Print the 'else' regions if it exists and has a block.
191 auto &elseRegion = op.elseRegion();
192 if (!elseRegion.empty()) {
River Riddle3a643de2019-09-20 20:43:02 -0700193 p << " else";
194 p.printRegion(elseRegion,
195 /*printEntryBlockArgs=*/false,
196 /*printBlockTerminators=*/false);
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700197 }
198
River Riddle3a643de2019-09-20 20:43:02 -0700199 p.printOptionalAttrDict(op.getAttrs());
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700200}
201
202//===----------------------------------------------------------------------===//
203// TableGen'd op method definitions
204//===----------------------------------------------------------------------===//
205
206#define GET_OP_CLASSES
207#include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"