Fix segfaults when printing unlinked statements, instructions, and blocks. Fancy printing requires a pointer to the enclosing function, since SSA values get function-specific names. This CL adds checks to ensure that we don't dereference null pointers in unlinked objects. Unlinked statements, instructions, and blocks are printed as <<UNLINKED STATEMENT>>, <<UNLINKED INSTRUCTION>>, and <<UNLINKED BLOCK>>, respectively.
PiperOrigin-RevId: 207293992
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c53c333..d71fb92 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -582,6 +582,7 @@
}
void printOperand(const SSAValue *value) { printValueID(value); }
+
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) override;
@@ -1182,6 +1183,10 @@
void SSAValue::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
+ if (!getFunction()) {
+ os << "<<UNLINKED INSTRUCTION>>\n";
+ return;
+ }
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
@@ -1193,6 +1198,10 @@
}
void BasicBlock::print(raw_ostream &os) const {
+ if (!getFunction()) {
+ os << "<<UNLINKED BLOCK>>\n";
+ return;
+ }
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
@@ -1202,6 +1211,11 @@
void Statement::print(raw_ostream &os) const {
MLFunction *function = findFunction();
+ if (!function) {
+ os << "<<UNLINKED STATEMENT>>\n";
+ return;
+ }
+
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 847907b..8f64066 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -71,11 +71,13 @@
/// Return the context this operation is associated with.
MLIRContext *Instruction::getContext() const {
- return getFunction()->getContext();
+ auto *fn = getFunction();
+ return fn ? fn->getContext() : nullptr;
}
CFGFunction *Instruction::getFunction() const {
- return getBlock()->getFunction();
+ auto *block = getBlock();
+ return block ? block->getFunction() : nullptr;
}
unsigned Instruction::getNumOperands() const {
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 4e2a7b0..978137b 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -57,10 +57,12 @@
}
}
-Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
+Statement *Statement::getParentStmt() const {
+ return block ? block->getParentStmt() : nullptr;
+}
MLFunction *Statement::findFunction() const {
- return this->getBlock()->findFunction();
+ return block ? block->findFunction() : nullptr;
}
bool Statement::isInnermost() const {
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 21b870f..2769bb9 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -38,7 +38,10 @@
MLFunction *StmtBlock::findFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
- while (block->getParentStmt() != nullptr)
+ while (block->getParentStmt()) {
block = block->getParentStmt()->getBlock();
- return static_cast<MLFunction *>(block);
+ if (!block)
+ return nullptr;
+ }
+ return dyn_cast<MLFunction>(block);
}