Don't crash an XRT server if a client leaks a compilation reference.
PiperOrigin-RevId: 216608167
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index f590fbf..9fc01e6 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -437,6 +437,27 @@
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
+TEST(RawApiTest, LeakCompilationReference) {
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::F32, {2})});
+ StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({c_handle}, &outputs));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
index 4844c7f..31bb476 100644
--- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
@@ -46,12 +46,17 @@
XRTCompilationCache::~XRTCompilationCache() {
VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
+ // A buggy client may be holding onto a reference, or a client might have
+ // crashed while holding onto a reference. In either case, discard all
+ // outstanding client references to avoid leaking storage.
+ for (const auto& entry : entries_by_uid_) {
+ while (!entry.second->RefCountIsOne()) {
+ entry.second->Unref();
+ }
+ }
while (!entries_by_last_use_.empty()) {
MarkOldestEntryForEviction();
}
- // By the time the cache is deleted all reference holders should have already
- // been deleted, since they were holding references to the cache. So all
- // entries should be gone at this point.
CHECK_EQ(cache_.size(), 0);
CHECK_EQ(entries_by_uid_.size(), 0);
CHECK_EQ(cache_entries_, 0);