Fix keras.utils.Sequence.on_epoch_end not being called in Model.fit.
PiperOrigin-RevId: 297709368
Change-Id: I53affbb973facb5ada351ecc46bd070a053f545f
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 3fc66d0..f3f8086 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -223,6 +223,10 @@
total_sample -= (self.batch_size() - self.partial_batch_size())
return total_sample
+ def on_epoch_end(self):
+ """A hook called after each epoch; a no-op unless a subclass overrides it."""
+ pass
+
class TensorLikeDataAdapter(DataAdapter):
"""Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -891,6 +895,7 @@
self._size = len(x)
self._shuffle_sequence = shuffle
+ self._keras_sequence = x
super(KerasSequenceAdapter, self).__init__(
x,
shuffle=False, # Shuffle is handed in the _make_callable override.
@@ -932,6 +937,9 @@
def should_recreate_iterator(self):
return True
+ def on_epoch_end(self):
+ self._keras_sequence.on_epoch_end()
+
ALL_ADAPTER_CLS = [
ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
@@ -1084,8 +1092,8 @@
self._epochs = epochs
self._insufficient_data = False
- train_adapter_cls = select_data_adapter(x, y)
- self._train_adapter = train_adapter_cls(
+ adapter_cls = select_data_adapter(x, y)
+ self._adapter = adapter_cls(
x,
y,
batch_size=batch_size,
@@ -1100,21 +1108,22 @@
model=model)
strategy = ds_context.get_strategy()
- dataset = self._train_adapter.get_dataset()
+ dataset = self._adapter.get_dataset()
if class_weight:
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
self._steps_per_epoch = self._infer_steps(steps_per_epoch, dataset)
- self._train_dataset = strategy.experimental_distribute_dataset(dataset)
+ self._dataset = strategy.experimental_distribute_dataset(dataset)
def enumerate_epochs(self):
"""Yields `(epoch, tf.data.Iterator)`."""
- data_iterator = iter(self._train_dataset)
+ data_iterator = iter(self._dataset)
for epoch in range(self._initial_epoch, self._epochs):
if self._insufficient_data: # Set by `catch_stop_iteration`.
break
- if self._train_adapter.should_recreate_iterator():
- data_iterator = iter(self._train_dataset)
+ if self._adapter.should_recreate_iterator():
+ data_iterator = iter(self._dataset)
yield epoch, data_iterator
+ self._adapter.on_epoch_end()
@contextlib.contextmanager
def catch_stop_iteration(self):
@@ -1122,8 +1131,8 @@
try:
yield
except (StopIteration, errors.OutOfRangeError):
- if (self._train_adapter.get_size() is None and
- self._steps_per_epoch is None and self._current_step > 0):
+ if (self._adapter.get_size() is None and self._steps_per_epoch is None and
+ self._current_step > 0):
# The input passed by the user ran out of batches.
# Now we know the cardinality of the input(dataset or generator).
self._steps_per_epoch = self._current_step
@@ -1154,7 +1163,7 @@
if steps is not None:
return steps
- adapter_steps = self._train_adapter.get_size()
+ adapter_steps = self._adapter.get_size()
if adapter_steps is not None:
return adapter_steps
@@ -1175,11 +1184,11 @@
@property
def _samples(self):
- return self._train_adapter.get_samples()
+ return self._adapter.get_samples()
@property
def _steps(self):
- return self._train_adapter.get_size()
+ return self._adapter.get_size()
def _make_class_weight_map_fn(class_weight):
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 75ddf0f..346dc32 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -786,7 +786,7 @@
# User can choose to only partially consume `Dataset`.
data_handler = data_adapter.DataHandler(
data, initial_epoch=0, epochs=2, steps_per_epoch=2)
- self.assertFalse(data_handler._train_adapter.should_recreate_iterator())
+ self.assertFalse(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
@@ -812,7 +812,7 @@
# create a new iterator each epoch.
data_handler = data_adapter.DataHandler(
data, initial_epoch=0, epochs=2, steps_per_epoch=4)
- self.assertTrue(data_handler._train_adapter.should_recreate_iterator())
+ self.assertTrue(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
@@ -842,7 +842,7 @@
# User can choose to only partially consume `Dataset`.
data_handler = data_adapter.DataHandler(
filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
- self.assertFalse(data_handler._train_adapter.should_recreate_iterator())
+ self.assertFalse(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
@@ -860,7 +860,7 @@
data_handler = data_adapter.DataHandler(
filtered_ds, initial_epoch=0, epochs=2)
- self.assertTrue(data_handler._train_adapter.should_recreate_iterator())
+ self.assertTrue(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py
index f5abc0e..dd25fb7 100644
--- a/tensorflow/python/keras/engine/training_generator_test.py
+++ b/tensorflow/python/keras/engine/training_generator_test.py
@@ -454,6 +454,31 @@
model.evaluate(CustomSequenceChangingBatchSize())
model.predict(CustomSequenceChangingBatchSize())
+ @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+ def test_sequence_on_epoch_end(self):
+
+ class MySequence(data_utils.Sequence):
+
+ def __init__(self):
+ self.epochs = 0
+
+ def __getitem__(self, idx):
+ return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ def __len__(self):
+ return 2
+
+ def on_epoch_end(self):
+ self.epochs += 1
+
+ inputs = keras.Input(10)
+ outputs = keras.layers.Dense(1)(inputs)
+ model = keras.Model(inputs, outputs)
+ model.compile('sgd', 'mse')
+ my_seq = MySequence()
+ model.fit(my_seq, epochs=2)
+ self.assertEqual(my_seq.epochs, 2)
+
@tf_test_util.run_all_in_graph_and_eager_modes
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):