Export load_partial as an internal symbol.
PiperOrigin-RevId: 342157049
Change-Id: I748039361b5180b16f44088909898fc90c6f5301
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 03a1048..1d513b4 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -668,6 +668,7 @@
return instance.__call__(*args, **kwargs)
+@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
"""Partially load a SavedModel (saved from V2).
@@ -771,47 +772,45 @@
Signatures associated with the SavedModel are available as functions:
- >>> class Adder(tf.Module):
- ... @tf.function(
- ... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
- ... def add(self, x):
- ... return x + x
- >>> model = Adder()
- >>> model.add(tf.constant(1.))
- 2.0
- >>> tf.saved_model.save(model, "/tmp/adder")
- >>> imported = tf.saved_model.load("/tmp/adder")
- >>> f = imported.signatures["serving_default"]
- >>> f(x=tf.constant(1.))
- {'output_0': <tf.Tensor: shape=(), dtype=float32, numpy=2.0>}
+ ```python
+ imported = tf.saved_model.load(path)
+ f = imported.signatures["serving_default"]
+ print(f(x=tf.constant([[1.]])))
+ ```
- Any trackable attributes on the exported object will be restored on load:
+ Objects exported with `tf.saved_model.save` additionally have trackable
+ objects and functions assigned to attributes:
- >>> exported = tf.train.Checkpoint(v=tf.Variable(3.))
- >>> exported.multiply = tf.function(
- ... lambda x: exported.v * x,
- ... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
- >>> tf.saved_model.save(exported, "/tmp/exported")
- >>> imported = tf.saved_model.load("/tmp/exported")
- >>> imported.v.numpy()
- 3.0
- >>> imported.multiply(x=tf.constant(2.)).numpy()
- 6.0
+ ```python
+ exported = tf.train.Checkpoint(v=tf.Variable(3.))
+ exported.f = tf.function(
+ lambda x: exported.v * x,
+ input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
+ tf.saved_model.save(exported, path)
+ imported = tf.saved_model.load(path)
+ assert 3. == imported.v.numpy()
+ assert 6. == imported.f(x=tf.constant(2.)).numpy()
+ ```
_Loading Keras models_
- Keras models are trackable, so they can be saved and loaded via SavedModel.
- The object returned by `tf.saved_model.load` is not a Keras object, however
- (i.e. it doesn't have `.fit`, `.predict`, etc. methods). A few attributes and
- functions are still available: `.variables`, `.trainable_variables` and
- `.__call__`.
+ Keras models are trackable, so they can be saved to SavedModel. The object
+ returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
+ `.fit`, `.predict`, etc. methods). A few attributes and functions are still
+ available: `.variables`, `.trainable_variables` and `.__call__`.
- To restore a full Keras model along with all its attributes and functions,
- use `tf.keras.models.load_model` instead.
+ ```python
+ model = tf.keras.Model(...)
+ tf.saved_model.save(model, path)
+ imported = tf.saved_model.load(path)
+ outputs = imported(inputs)
+ ```
+
+ Use `tf.keras.models.load_model` to restore the Keras model.
_Importing SavedModels from TensorFlow 1.x_
- SavedModels from `tf.estimator.Estimator` and 1.x SavedModel APIs have a flat
+ SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
graph instead of `tf.function` objects. These SavedModels will be loaded with
the following attributes:
@@ -832,16 +831,14 @@
* `.restore(save_path)`: A function that restores variables from a checkpoint
saved from `tf.compat.v1.Saver`.
- _Making sure a SavedModel is ready to be loaded_
+ _Consuming SavedModels asynchronously_
- When exporting a SavedModel, TensorFlow first creates `export_dir` and then
- writes a number of additional files. Calling `tf.saved_model.load` on a
- directory in a partially-written state will fail.
-
- If you would like to make sure a SavedModel is fully written and ready for
- loading, check for the presence of `"saved_model_dir/saved_model.pb"` rather
- than `export_dir`. This file is written atomically as the last step in
- saving.
+ When consuming SavedModels asynchronously (the producer is a separate
+ process), the SavedModel directory will appear before all files have been
+ written, and `tf.saved_model.load` will fail if pointed at an incomplete
+ SavedModel. Rather than checking for the directory, check for
+ "saved_model_dir/saved_model.pb". This file is written atomically as the last
+ `tf.saved_model.save` file operation.
Args:
export_dir: The SavedModel directory to load from.
@@ -852,10 +849,10 @@
loading.
Returns:
- A trackable object with a `signatures` attribute mapping signature keys to
- functions. If the SavedModel was exported by `tf.saved_model.save`, it will
- also have attributes pointing to any trackable objects attached to the
- originally exported object.
+ A trackable object with a `signatures` attribute mapping from signature
+ keys to functions. If the SavedModel was exported by `tf.saved_model.load`,
+ it also points to trackable objects, functions, debug info which it has been
+ saved.
Raises:
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 359a6b1..20d5564 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -17,6 +17,7 @@
"__internal__/test/__init__.py",
"__internal__/test/combinations/__init__.py",
"__internal__/tf2/__init__.py",
+ "__internal__/saved_model/__init__.py",
"__internal__/tracking/__init__.py",
"__operators__/__init__.py",
"audio/__init__.py",
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
index 22eccb3..6558b88 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
@@ -41,6 +41,10 @@
mtype: "<type \'module\'>"
}
member {
+ name: "saved_model"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "test"
mtype: "<type \'module\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.pbtxt
new file mode 100644
index 0000000..0117df4
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.__internal__.saved_model"
+tf_module {
+ member_method {
+ name: "load_partial"
+ argspec: "args=[\'export_dir\', \'filters\', \'tags\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+}