[XLA:Python] Remove PyLocalClient::(De)SerializeExecutable virtual methods.

PiperOrigin-RevId: 290915533
Change-Id: I566bea2e1067d3971b2b1ab934e2f77d0a8be903
diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc
index 45bcf48..c233852c 100644
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -268,20 +268,6 @@
   }
 }
 
-StatusOr<std::string> PyLocalClient::SerializeExecutable(
-    const PyLocalExecutable& executable) const {
-  return Unimplemented("Cannot serialize executables on platform '%s'",
-                       platform_name());
-}
-
-StatusOr<std::unique_ptr<PyLocalExecutable>>
-PyLocalClient::DeserializeExecutable(
-    const std::string& serialized,
-    std::shared_ptr<PyLocalClient> this_shared) const {
-  return Unimplemented("Cannot deserialize executables on platform '%s'",
-                       platform_name());
-}
-
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
                                        std::shared_ptr<Device> device) {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index 001cf18..be29392 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -39,8 +39,6 @@
 
 namespace xla {
 
-class PyLocalExecutable;
-
 class Device {
  public:
   explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
@@ -172,19 +170,6 @@
   // function specifies which one the platform expects.
   virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
 
-  // Returns a platform-specific serialization of `executable`. This is meant
-  // for transferring executables and not for storage, and the serialization is
-  // not guaranteed to be stable over time.
-  virtual StatusOr<std::string> SerializeExecutable(
-      const PyLocalExecutable& executable) const;
-
-  // Deserializes a serialized executable as produced by
-  // SerializeExecutable(). `serialized` must have been produced by client of
-  // the same platform. `this_shared` should point to this PyLocalClient.
-  virtual StatusOr<std::unique_ptr<PyLocalExecutable>> DeserializeExecutable(
-      const std::string& serialized,
-      std::shared_ptr<PyLocalClient> this_shared) const;
-
  protected:
   std::string platform_name_;
   LocalClient* client_;
@@ -353,7 +338,6 @@
 
   void Delete() { executable_ = nullptr; }
 
-  LocalExecutable* executable() const { return executable_.get(); }
   const string& name() const;
 
  private:
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index c1d7893..a98372c 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -462,14 +462,6 @@
              }
              return LiteralToPython(std::move(literal_shared));
            })
-      .def("SerializeExecutable",
-           [](PyLocalClient* client,
-              PyLocalExecutable* executable) -> StatusOr<py::bytes> {
-             TF_ASSIGN_OR_RETURN(std::string serialized,
-                                 client->SerializeExecutable(*executable));
-             return py::bytes(serialized);
-           })
-      .def("DeserializeExecutable", &PyLocalClient::DeserializeExecutable)
       .def("CreateChannelHandle",
            [](PyLocalClient* client) {
              return client->client()->CreateChannelHandle();
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index e3f51d6..7c535ef 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -161,12 +161,6 @@
       # TODO(skye): delete this case after all callers can handle 2D output
       return self.client.GetDefaultDeviceAssignment(num_replicas)
 
-  def serialize(self, executable):
-    return self.client.SerializeExecutable(executable)
-
-  def deserialize(self, serialized_executable):
-    return self.client.DeserializeExecutable(serialized_executable, self.client)
-
 
 xla_platform_names = {
     'cpu': 'Host',