[mlir] Add basic block arguments

This patch adds support for basic block arguments including parsing and printing.

In doing so noticed that `ssa-id-and-type` is undefined in the MLIR spec; suggested an implementation in the spec doc.

PiperOrigin-RevId: 205593369
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index 5b42a68..be64e4f 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -22,6 +22,7 @@
 #include <memory>
 
 namespace mlir {
+class BBArgument;
 
 /// Each basic block in a CFG function contains a list of basic block arguments,
 /// normal instructions, and a terminator instruction.
@@ -39,12 +40,34 @@
     return function;
   }
 
-  // TODO: bb arguments
-
   /// Unlink this BasicBlock from its CFGFunction and delete it.
   void eraseFromFunction();
 
   //===--------------------------------------------------------------------===//
+  // Block arguments management
+  //===--------------------------------------------------------------------===//
+
+  // This is the list of arguments to the block.
+  typedef ArrayRef<BBArgument *> BBArgListType;
+  BBArgListType getArguments() const { return arguments; }
+
+  using args_iterator = BBArgListType::iterator;
+  using reverse_args_iterator = BBArgListType::reverse_iterator;
+  args_iterator args_begin() const { return getArguments().begin(); }
+  args_iterator args_end() const { return getArguments().end(); }
+  reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
+  reverse_args_iterator args_rend() const { return getArguments().rend(); }
+
+  bool args_empty() const { return arguments.empty(); }
+  BBArgument *addArgument(Type *type);
+  llvm::iterator_range<BBArgListType::iterator>
+  addArguments(ArrayRef<Type *> types);
+
+  unsigned getNumArguments() const { return arguments.size(); }
+  BBArgument *getArgument(unsigned i) { return arguments[i]; }
+  const BBArgument *getArgument(unsigned i) const { return arguments[i]; }
+
+  //===--------------------------------------------------------------------===//
   // Operation list management
   //===--------------------------------------------------------------------===//
 
@@ -105,6 +128,9 @@
   /// This is the list of operations in the block.
   OperationListType operations;
 
+  /// This is the list of arguments to the block.
+  std::vector<BBArgument *> arguments;
+
   /// This is the owning reference to the terminator of the block.
   TerminatorInst *terminator = nullptr;
 
diff --git a/include/mlir/IR/CFGValue.h b/include/mlir/IR/CFGValue.h
index d18d395..5dd070b 100644
--- a/include/mlir/IR/CFGValue.h
+++ b/include/mlir/IR/CFGValue.h
@@ -26,6 +26,7 @@
 #include "mlir/IR/SSAValue.h"
 
 namespace mlir {
+class BasicBlock;
 class CFGValue;
 class Instruction;
 
@@ -33,7 +34,7 @@
 /// function.  This should be kept as a proper subtype of SSAValueKind,
 /// including having all of the values of the enumerators align.
 enum class CFGValueKind {
-  // TODO: BBArg,
+  BBArgument = (int)SSAValueKind::BBArgument,
   InstResult = (int)SSAValueKind::InstResult,
 };
 
@@ -45,6 +46,7 @@
 public:
   static bool classof(const SSAValue *value) {
     switch (value->getKind()) {
+    case SSAValueKind::BBArgument:
     case SSAValueKind::InstResult:
       return true;
     }
@@ -54,6 +56,27 @@
   CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
 };
 
+/// Basic block arguments are CFG Values.
+class BBArgument : public CFGValue {
+public:
+  static bool classof(const SSAValue *value) {
+    return value->getKind() == SSAValueKind::BBArgument;
+  }
+
+  BasicBlock *getOwner() { return owner; }
+  const BasicBlock *getOwner() const { return owner; }
+
+private:
+  friend class BasicBlock; // For access to private constructor.
+  BBArgument(Type *type, BasicBlock *owner)
+      : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
+
+  /// The owner of this operand.
+  /// TODO: can encode this more efficiently to avoid the space hit of this
+  /// through bitpacking shenanigans.
+  BasicBlock *const owner;
+};
+
 /// Instruction results are CFG Values.
 class InstResult : public CFGValue {
 public:
diff --git a/include/mlir/IR/SSAValue.h b/include/mlir/IR/SSAValue.h
index 09fd45b..0b7648c 100644
--- a/include/mlir/IR/SSAValue.h
+++ b/include/mlir/IR/SSAValue.h
@@ -34,7 +34,7 @@
 
 /// This enumerates all of the SSA value kinds in the MLIR system.
 enum class SSAValueKind {
-  // TODO: BBArg,
+  BBArgument,
   InstResult,
 
   // FnArg
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index d811223..0d5610f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -665,7 +665,9 @@
 
 /// Number all of the SSA values in the specified basic block.
 void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
-  // TODO: basic block arguments.
+  for (auto *arg : block->getArguments()) {
+    numberValueID(arg);
+  }
   for (auto &op : *block) {
     // We number instruction that have results, and we only number the first
     // result.
@@ -686,16 +688,26 @@
 }
 
 void CFGFunctionPrinter::print(const BasicBlock *block) {
-  os << "bb" << getBBID(block) << ":\n";
+  os << "bb" << getBBID(block);
 
-  // TODO Print arguments.
+  if (!block->args_empty()) {
+    os << '(';
+    interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
+      printValueID(arg);
+      os << ": ";
+      ModulePrinter::print(arg->getType());
+    });
+    os << ')';
+  }
+  os << ":\n";
+
   for (auto &inst : block->getOperations()) {
     print(&inst);
-    os << "\n";
+    os << '\n';
   }
 
   print(block->getTerminator());
-  os << "\n";
+  os << '\n';
 }
 
 void CFGFunctionPrinter::print(const Instruction *inst) {
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index c2c865c..1aad430 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -19,12 +19,14 @@
 #include "mlir/IR/CFGFunction.h"
 using namespace mlir;
 
-BasicBlock::BasicBlock() {
-}
+BasicBlock::BasicBlock() {}
 
 BasicBlock::~BasicBlock() {
   if (terminator)
     terminator->eraseFromBlock();
+  for (BBArgument *arg : arguments)
+    delete arg;
+  arguments.clear();
 }
 
 /// Unlink this BasicBlock from its CFGFunction and delete it.
@@ -84,3 +86,17 @@
   for (; first != last; ++first)
     first->function = curParent;
 }
+
+BBArgument *BasicBlock::addArgument(Type *type) {
+  arguments.push_back(new BBArgument(type, this));
+  return arguments.back();
+}
+
+llvm::iterator_range<BasicBlock::BBArgListType::iterator>
+BasicBlock::addArguments(ArrayRef<Type *> types) {
+  auto initial_size = arguments.size();
+  for (auto *type : types) {
+    addArgument(type);
+  }
+  return {arguments.data() + initial_size, arguments.data() + arguments.size()};
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 3c3bb16..d1bb2ac 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -106,6 +106,26 @@
 bool CFGFuncVerifier::verify() {
   // TODO: Lots to be done here, including verifying dominance information when
   // we have uses and defs.
+  // TODO: Verify the first block has no predecessors.
+
+  if (fn.empty())
+    return failure("cfgfunc must have at least one basic block", fn);
+
+  // Verify that the argument list of the function and the arg list of the first
+  // block line up.
+  auto *firstBB = &fn.front();
+  auto fnInputTypes = fn.getType()->getInputs();
+  if (fnInputTypes.size() != firstBB->getNumArguments())
+    return failure("first block of cfgfunc must have " +
+                       Twine(fnInputTypes.size()) +
+                       " arguments to match function signature",
+                   fn);
+  for (unsigned i = 0, e = firstBB->getNumArguments(); i != e; ++i)
+    if (fnInputTypes[i] != firstBB->getArgument(i)->getType())
+      return failure(
+          "type of argument #" + Twine(i) +
+              " must match corresponding argument in function signature",
+          fn);
 
   for (auto &block : fn) {
     if (verifyBlock(block))
@@ -121,6 +141,11 @@
   if (verifyTerminator(*block.getTerminator()))
     return true;
 
+  for (auto *arg : block.getArguments()) {
+    if (arg->getOwner() != &block)
+      return failure("basic block argument not owned by block", block);
+  }
+
   for (auto &inst : block) {
     if (verifyOperation(inst))
       return true;
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 62a9fe7..bc26bda 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1219,7 +1219,17 @@
   // SSA parsing productions.
   ParseResult parseSSAUse(SSAUseInfo &result);
   ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
-  SSAValue *parseSSAUseAndType();
+
+  template <typename ResultType>
+  ResultType parseSSADefOrUseAndType(
+      const std::function<ResultType(SSAUseInfo, Type *)> &action);
+
+  SSAValue *parseSSAUseAndType() {
+    return parseSSADefOrUseAndType<SSAValue *>(
+        [&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
+          return resolveSSAUse(useInfo, type);
+        });
+  }
 
   template <typename ValueTy>
   ParseResult
@@ -1355,8 +1365,7 @@
 
 /// Parse a SSA operand for an instruction or statement.
 ///
-///   ssa-use ::= ssa-id | ssa-constant
-/// TODO: SSA Constants.
+///   ssa-use ::= ssa-id
 ///
 ParseResult FunctionParser::parseSSAUse(SSAUseInfo &result) {
   result.name = getTokenSpelling();
@@ -1398,7 +1407,9 @@
 /// Parse an SSA use with an associated type.
 ///
 ///   ssa-use-and-type ::= ssa-use `:` type
-SSAValue *FunctionParser::parseSSAUseAndType() {
+template <typename ResultType>
+ResultType FunctionParser::parseSSADefOrUseAndType(
+    const std::function<ResultType(SSAUseInfo, Type *)> &action) {
   SSAUseInfo useInfo;
   if (parseSSAUse(useInfo))
     return nullptr;
@@ -1410,7 +1421,7 @@
   if (!type)
     return nullptr;
 
-  return resolveSSAUse(useInfo, type);
+  return action(useInfo, type);
 }
 
 /// Parse a (possibly empty) list of SSA operands with types.
@@ -1570,12 +1581,39 @@
     return blockAndLoc.first;
   }
 
+  ParseResult
+  parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
+                                 BasicBlock *owner);
+
   ParseResult parseBasicBlock();
   OperationInst *parseCFGOperation();
   TerminatorInst *parseTerminator();
 };
 } // end anonymous namespace
 
+/// Parse a (possibly empty) list of SSA operands with types as basic block
+/// arguments. Unlike parseOptionalSsaUseAndTypeList the SSA IDs are treated as
+/// defs, not uses.
+///
+///   ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
+///
+ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
+    SmallVectorImpl<BBArgument *> &results, BasicBlock *owner) {
+  if (getToken().is(Token::r_brace))
+    return ParseSuccess;
+
+  return parseCommaSeparatedList([&]() -> ParseResult {
+    auto type = parseSSADefOrUseAndType<Type *>(
+        [&](SSAUseInfo useInfo, Type *type) -> Type * {
+          BBArgument *arg = owner->addArgument(type);
+          if (addDefinition(useInfo, arg) == ParseFailure)
+            return nullptr;
+          return type;
+        });
+    return type ? ParseSuccess : ParseFailure;
+  });
+}
+
 ParseResult CFGFunctionParser::parseFunctionBody() {
   auto braceLoc = getToken().getLoc();
   if (!consumeIf(Token::l_brace))
@@ -1625,20 +1663,18 @@
   if (block->getFunction())
     return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
 
-  // Add the block to the function.
-  function->push_back(block);
-
   // If an argument list is present, parse it.
   if (consumeIf(Token::l_paren)) {
-    SmallVector<SSAUseInfo, 8> bbArgs;
-    if (parseOptionalSSAUseList(bbArgs))
+    SmallVector<BBArgument *, 8> bbArgs;
+    if (parseOptionalBasicBlockArgList(bbArgs, block))
       return ParseFailure;
     if (!consumeIf(Token::r_paren))
       return emitError("expected ')' to end argument list");
-
-    // TODO: attach it.
   }
 
+  // Add the block to the function.
+  function->push_back(block);
+
   if (!consumeIf(Token::colon))
     return emitError("expected ':' after basic block name");
 
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 6177136..c9086fe 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -91,6 +91,27 @@
 
 // -----
 
+cfgfunc @block_no_rparen() {
+bb42 (%bb42 : i32: // expected-error {{expected ')' to end argument list}}
+  return
+}
+
+// -----
+
+cfgfunc @block_arg_no_ssaid() {
+bb42 (i32): // expected-error {{expected SSA operand}}
+  return
+}
+
+// -----
+
+cfgfunc @block_arg_no_type() {
+bb42 (%0): // expected-error {{expected ':' and type for SSA operand}}
+  return
+}
+
+// -----
+
 mlfunc @foo()
 mlfunc @bar() // expected-error {{expected '{' in ML function}}
 
@@ -208,3 +229,17 @@
 }
 
 // -----
+
+cfgfunc @argError() {  
+bb1(%a: i64):  // expected-error {{previously defined here}}
+  br bb2
+bb2(%a: i64):  // expected-error{{redefinition of SSA value '%a'}}
+  return
+}
+
+// -----
+
+cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc must have 2 arguments to match function signature}}
+bb42(%0: f32):
+  return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 98844dc..a4d45c6 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -68,17 +68,28 @@
 
 // CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
 cfgfunc @simpleCFG(i32, f32) {
-// CHECK: bb0:
-bb42: // (%0: i32, %f: f32):    TODO(clattner): implement bbargs.
-  // CHECK: %0 = "foo"() : () -> i64
+// CHECK: bb0(%0: i32, %1: f32):
+bb42 (%0: i32, %f: f32):
+  // CHECK: %2 = "foo"() : () -> i64
   %1 = "foo"() : ()->i64
-  // CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
+  // CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
   %2 = "bar"(%1) : (i64) -> (i1,i1,i1)
   // CHECK: return
   return
 // CHECK: }
 }
 
+// CHECK-LABEL: cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
+cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
+// CHECK: bb0(%0: i32, %1: i64):
+bb42 (%0: i32, %f: i64):
+  // CHECK: "bar"(%1) : (i64) -> (i1, i1, i1)
+  %2 = "bar"(%f) : (i64) -> (i1,i1,i1)
+  // CHECK: return
+  return
+// CHECK: }
+}
+
 // CHECK-LABEL: cfgfunc @multiblock() {
 cfgfunc @multiblock() {
 bb0:         // CHECK: bb0:
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 6561d24..80baee1 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -118,7 +118,7 @@
                                  SMLoc());
 
     // Extracing the expected errors.
-    llvm::Regex expected("expected-error(@[+-][0-9]+)? {{(.*)}}");
+    llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
     SmallVector<ExpectedError, 2> expectedErrors;
     SmallVector<StringRef, 100> lines;
     subbuffer.split(lines, '\n');