Fix segfaults when printing unlinked statements, instructions and blocks. Fancy printing requires a pointer to the function since SSA values get function-specific names. This CL adds checks to ensure that we don't dereference null pointers in unlinked objects. Unlinked statements, instructions and blocks are printed as <<UNLINKED STATEMENT>> etc.

PiperOrigin-RevId: 207293992
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c53c333..d71fb92 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -582,6 +582,7 @@
   }
 
   void printOperand(const SSAValue *value) { printValueID(value); }
+
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<const char *> elidedAttrs = {}) override;
 
@@ -1182,6 +1183,13 @@
 void SSAValue::dump() const { print(llvm::errs()); }
 
 void Instruction::print(raw_ostream &os) const {
-  ModuleState state(getFunction()->getContext());
+  // SSA values get function-specific names, so an instruction can only be
+  // printed in full when it is linked into a CFGFunction.
+  auto *function = getFunction();
+  if (!function) {
+    os << "<<UNLINKED INSTRUCTION>>\n";
+    return;
+  }
+  ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
-  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+  CFGFunctionPrinter(function, modulePrinter).print(this);
@@ -1193,6 +1198,13 @@
 }
 
 void BasicBlock::print(raw_ostream &os) const {
-  ModuleState state(getFunction()->getContext());
+  // SSA values get function-specific names, so a block can only be printed
+  // in full when it is linked into a CFGFunction.
+  auto *function = getFunction();
+  if (!function) {
+    os << "<<UNLINKED BLOCK>>\n";
+    return;
+  }
+  ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
-  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+  CFGFunctionPrinter(function, modulePrinter).print(this);
@@ -1202,6 +1211,12 @@
 
 void Statement::print(raw_ostream &os) const {
+  // SSA values get function-specific names, so a statement can only be
+  // printed in full when it is nested inside an MLFunction.
   MLFunction *function = findFunction();
+  if (!function) {
+    os << "<<UNLINKED STATEMENT>>\n";
+    return;
+  }
   ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
   MLFunctionPrinter(function, modulePrinter).print(this);
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 847907b..8f64066 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -71,11 +71,16 @@
 
-/// Return the context this operation is associated with.
+/// Return the context this operation is associated with, or null if it is
+/// not linked into a CFGFunction.
 MLIRContext *Instruction::getContext() const {
-  return getFunction()->getContext();
+  auto *fn = getFunction();
+  return fn ? fn->getContext() : nullptr;
 }
 
+/// Return the function containing this instruction, or null if the
+/// instruction is not in a block or its block is not in a function.
 CFGFunction *Instruction::getFunction() const {
-  return getBlock()->getFunction();
+  auto *block = getBlock();
+  return block ? block->getFunction() : nullptr;
 }
 
 unsigned Instruction::getNumOperands() const {
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 4e2a7b0..978137b 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -57,10 +57,16 @@
   }
 }
 
-Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
+/// Return the statement enclosing this statement's block, or null if that
+/// block has no parent statement or this statement is not in a block.
+Statement *Statement::getParentStmt() const {
+  return block ? block->getParentStmt() : nullptr;
+}
 
+/// Return the MLFunction this statement is nested in, or null if it is not
+/// (transitively) linked into one.
 MLFunction *Statement::findFunction() const {
-  return this->getBlock()->findFunction();
+  return block ? block->findFunction() : nullptr;
 }
 
 bool Statement::isInnermost() const {
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 21b870f..2769bb9 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -38,7 +38,13 @@
 MLFunction *StmtBlock::findFunction() const {
   StmtBlock *block = const_cast<StmtBlock *>(this);
 
-  while (block->getParentStmt() != nullptr)
-    block = block->getParentStmt()->getBlock();
-  return static_cast<MLFunction *>(block);
+  // Walk up through the enclosing statements. If any of them is not inserted
+  // into a block, this block is unlinked and has no function.
+  while (auto *parentStmt = block->getParentStmt()) {
+    block = parentStmt->getBlock();
+    if (!block)
+      return nullptr;
+  }
+  // dyn_cast yields null if the outermost block is not an MLFunction.
+  return dyn_cast<MLFunction>(block);
 }