[mlir] Add loop.parallel, loop.reduce and loop.reduce.return operations.

Summary:
These operations can be used to specify a loop nest with a body that can
contain reductions. The iteration space can be iterated in any order.

RFC: https://groups.google.com/a/tensorflow.org/d/topic/mlir/pwtSgiKFPis/discussion

Differential Revision: https://reviews.llvm.org/D72394
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index 5452b3d..4824421 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -185,13 +185,13 @@
     return failure();
 
   // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, {}, {}))
+  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
 
   // If we find an 'else' keyword then parse the 'else' region.
   if (!parser.parseOptionalKeyword("else")) {
-    if (parser.parseRegion(*elseRegion, {}, {}))
+    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
       return failure();
     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
   }
@@ -222,6 +222,199 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the structural invariants of a ParallelOp:
+///  * there is at least one step operand, and every step that is defined by a
+///    ConstantIndexOp is strictly positive;
+///  * the first block of the body has exactly one argument per step, each of
+///    index type;
+///  * the op has as many results as that block contains ReduceOps, and each
+///    result type equals the operand type of the ReduceOp at the same
+///    position (a mismatch is reported on the offending ReduceOp).
+static LogicalResult verify(ParallelOp op) {
+  // Testing step for emptiness also covers lowerBound and upperBound, because
+  // it is ensured already that all three hold the same number of elements.
+  Operation::operand_range steps = op.step();
+  if (steps.empty())
+    return op.emitOpError(
+        "needs at least one tuple element for lowerBound, upperBound and step");
+
+  // Only steps defined by a constant can be checked here; all others are
+  // skipped.
+  for (Value step : steps) {
+    auto constStep = dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
+    if (constStep && constStep.getValue() <= 0)
+      return op.emitOpError("constant step operand must be positive");
+  }
+
+  // The induction variables are the arguments of the body's first block: one
+  // per step, all of index type.
+  Block &entry = op.body().front();
+  if (entry.getNumArguments() != steps.size())
+    return op.emitOpError(
+        "expects the same number of induction variables as bound and step "
+        "values");
+  const bool allIndexTyped =
+      llvm::all_of(entry.getArguments(), [](const BlockArgument &arg) {
+        return arg.getType().isIndex();
+      });
+  if (!allIndexTyped)
+    return op.emitOpError(
+        "expects arguments for the induction variable to be of index type");
+
+  // Results are matched positionally against the ReduceOps of the block.
+  SmallVector<ReduceOp, 4> reduceOps(entry.getOps<ReduceOp>());
+  if (op.results().size() != reduceOps.size())
+    return op.emitOpError(
+        "expects number of results to be the same as number of reductions");
+
+  for (auto resultAndReduce : llvm::zip(op.results(), reduceOps)) {
+    ReduceOp reduce = std::get<1>(resultAndReduce);
+    auto resultType = std::get<0>(resultAndReduce).getType();
+    if (resultType != reduce.operand().getType())
+      return reduce.emitOpError()
+             << "expects type of reduce to be the same as result type: "
+             << resultType;
+  }
+  return success();
+}
+
+/// Parses the custom form of a ParallelOp. After the operation name the
+/// expected syntax is
+///   (<ivs>) = (<lower>) to (<upper>) step (<steps>) <region>
+///   [<attr-dict>] [: <result-types>]
+/// The lower bounds, upper bounds and steps are resolved as index-typed
+/// operands and appended to `result.operands` in that order. The induction
+/// variables become index-typed arguments of the region. Returns failure() as
+/// soon as any parsing or resolution step fails.
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by induction variables followed by `)`.
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse the lower bounds: `=` followed by a parenthesized operand list.
+  // ivs.size() is passed as the expected operand count for every list below.
+  SmallVector<OpAsmParser::OperandType, 4> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse the upper bounds: `to` followed by a parenthesized operand list.
+  SmallVector<OpAsmParser::OperandType, 4> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse the step values: `step` followed by a parenthesized operand list.
+  SmallVector<OpAsmParser::OperandType, 4> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Now parse the body, binding every induction variable to an index-typed
+  // region argument.
+  Region *body = result.addRegion();
+  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+  if (parser.parseRegion(*body, ivs, types))
+    return failure();
+
+  // Parse attributes and optional results (in case there is a reduce).
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOptionalColonTypeList(result.types))
+    return failure();
+
+  // Add a terminator if none was parsed. Go through ParallelOp itself rather
+  // than ForOp, so that the terminator always follows this op's definition.
+  ParallelOp::ensureTerminator(*body, builder, result.location);
+
+  return success();
+}
+
+/// Prints a ParallelOp as
+///   <op-name> (<ivs>) = (<lower>) to (<upper>) step (<steps>) <region>
+/// followed by the optional attribute dictionary and, only when the op has
+/// results, ` : ` and the result types.
+static void print(OpAsmPrinter &p, ParallelOp op) {
+  // The induction variables are the arguments of the body's first block.
+  auto inductionVars = op.body().front().getArguments();
+  p << op.getOperationName() << " (";
+  p.printOperands(inductionVars);
+  p << ") = (" << op.lowerBound() << ")";
+  p << " to (" << op.upperBound() << ")";
+  p << " step (" << op.step() << ")";
+
+  // The block arguments were printed above already, so the region is printed
+  // without them.
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op.getAttrs());
+  if (op.results().empty())
+    return;
+  p << " : " << op.getResultTypes();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a ReduceOp: the first block of its region must not be empty, must
+/// take exactly two arguments whose types equal the type of the op's operand,
+/// and must be terminated by a ReduceReturnOp.
+static LogicalResult verify(ReduceOp op) {
+  auto operandType = op.operand().getType();
+  Block &reduction = op.reductionOperator().front();
+  if (reduction.empty())
+    return op.emitOpError("the block inside reduce should not be empty");
+
+  // The block combines two values of the operand's type, so it needs exactly
+  // two arguments of that type.
+  const bool argsMatch =
+      reduction.getNumArguments() == 2 &&
+      llvm::all_of(reduction.getArguments(), [&](const BlockArgument &arg) {
+        return arg.getType() == operandType;
+      });
+  if (!argsMatch)
+    return op.emitOpError() << "expects two arguments to reduce block of type "
+                            << operandType;
+
+  // Anything but a ReduceReturnOp as terminator is an error.
+  if (isa<ReduceReturnOp>(reduction.getTerminator()))
+    return success();
+  return op.emitOpError("the block inside reduce should be terminated with a "
+                        "'loop.reduce.return' op");
+}
+
+/// Parses the custom form of a ReduceOp. After the operation name the expected
+/// syntax is `(<operand>) <region> : <type>`; <type> is the type the operand
+/// is resolved with. Returns failure() as soon as any step fails.
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+  // The reduced value is enclosed in parentheses.
+  OpAsmParser::OperandType reducedValue;
+  if (parser.parseLParen() || parser.parseOperand(reducedValue) ||
+      parser.parseRParen())
+    return failure();
+
+  // The region is parsed without passing any arguments for its entry block.
+  Region *reduction = result.addRegion();
+  if (parser.parseRegion(*reduction, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // The trailing type is the type of the operand (and also what reduce
+  // computes on). The operand is resolved only after this type is parsed.
+  Type operandType;
+  if (parser.parseColonType(operandType))
+    return failure();
+  if (parser.resolveOperand(reducedValue, operandType, result.operands))
+    return failure();
+
+  return success();
+}
+
+/// Prints a ReduceOp as `<op-name>(<operand>) <region> : <operand-type>`.
+static void print(OpAsmPrinter &p, ReduceOp op) {
+  auto reduced = op.operand();
+  p << op.getOperationName() << "(" << reduced << ") ";
+  p.printRegion(op.reductionOperator());
+  p << " : " << reduced.getType();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceReturnOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the value returned by a ReduceReturnOp has the same type as
+/// the operand of the ReduceOp that contains it.
+static LogicalResult verify(ReduceReturnOp op) {
+  // NOTE(review): cast<> assumes the parent operation is a ReduceOp;
+  // presumably this is guaranteed elsewhere (not visible here) - confirm.
+  auto enclosingReduce = cast<ReduceOp>(op.getParentOp());
+  Type expectedType = enclosingReduce.operand().getType();
+  if (op.result().getType() == expectedType)
+    return success();
+  return op.emitOpError() << "needs to have type " << expectedType
+                          << " (the type of the enclosing ReduceOp)";
+}
+
+/// Parses the custom form of a ReduceReturnOp. After the operation name the
+/// expected syntax is `<operand> : <type>`; the operand is resolved with the
+/// parsed type and appended to `result.operands`.
+static ParseResult parseReduceReturnOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType returned;
+  if (parser.parseOperand(returned))
+    return failure();
+
+  Type returnedType;
+  if (parser.parseColonType(returnedType))
+    return failure();
+
+  if (parser.resolveOperand(returned, returnedType, result.operands))
+    return failure();
+  return success();
+}
+
+/// Prints a ReduceReturnOp as `<op-name> <value> : <value-type>`.
+static void print(OpAsmPrinter &p, ReduceReturnOp op) {
+  auto returned = op.result();
+  p << op.getOperationName() << " " << returned;
+  p << " : " << returned.getType();
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//