Don't crash an XRT server if a client leaks a compilation reference.

PiperOrigin-RevId: 216608167
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index f590fbf..9fc01e6 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -437,6 +437,27 @@
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 }
 
+TEST(RawApiTest, LeakCompilationReference) {
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
+      {xla::ShapeUtil::MakeShape(xla::F32, {2})});
+  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
+  // Feed the serialized proto from a CPU constant into XRTCompile on `root`.
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  TF_ASSERT_OK(root.status());
+  // Run only the compile op; this test never releases the returned handle.
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({c_handle}, &outputs));
+}
+
 }  // namespace
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
index 4844c7f..31bb476 100644
--- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
@@ -46,12 +46,17 @@
 
 XRTCompilationCache::~XRTCompilationCache() {
   VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
+  // A buggy client may be holding onto a reference, or a client might have
+  // crashed while holding onto a reference. In either case, discard all
+  // outstanding client references to avoid leaking storage.
+  for (const auto& entry : entries_by_uid_) {
+    while (!entry.second->RefCountIsOne()) {
+      entry.second->Unref();
+    }
+  }
   while (!entries_by_last_use_.empty()) {
     MarkOldestEntryForEviction();
   }
-  // By the time the cache is deleted all reference holders should have already
-  // been deleted, since they were holding references to the cache. So all
-  // entries should be gone at this point.
   CHECK_EQ(cache_.size(), 0);
   CHECK_EQ(entries_by_uid_.size(), 0);
   CHECK_EQ(cache_entries_, 0);