Fix keras.utils.Sequence.on_epoch_end not being called in Model.fit.

PiperOrigin-RevId: 297709368
Change-Id: I53affbb973facb5ada351ecc46bd070a053f545f
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 3fc66d0..f3f8086 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -223,6 +223,10 @@
       total_sample -= (self.batch_size() - self.partial_batch_size())
     return total_sample
 
+  def on_epoch_end(self):
+    """A hook called after each epoch."""
+    pass
+
 
 class TensorLikeDataAdapter(DataAdapter):
   """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -891,6 +895,7 @@
 
     self._size = len(x)
     self._shuffle_sequence = shuffle
+    self._keras_sequence = x
     super(KerasSequenceAdapter, self).__init__(
         x,
         shuffle=False,  # Shuffle is handed in the _make_callable override.
@@ -932,6 +937,9 @@
   def should_recreate_iterator(self):
     return True
 
+  def on_epoch_end(self):
+    self._keras_sequence.on_epoch_end()
+
 
 ALL_ADAPTER_CLS = [
     ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
@@ -1084,8 +1092,8 @@
     self._epochs = epochs
     self._insufficient_data = False
 
-    train_adapter_cls = select_data_adapter(x, y)
-    self._train_adapter = train_adapter_cls(
+    adapter_cls = select_data_adapter(x, y)
+    self._adapter = adapter_cls(
         x,
         y,
         batch_size=batch_size,
@@ -1100,21 +1108,22 @@
         model=model)
 
     strategy = ds_context.get_strategy()
-    dataset = self._train_adapter.get_dataset()
+    dataset = self._adapter.get_dataset()
     if class_weight:
       dataset = dataset.map(_make_class_weight_map_fn(class_weight))
     self._steps_per_epoch = self._infer_steps(steps_per_epoch, dataset)
-    self._train_dataset = strategy.experimental_distribute_dataset(dataset)
+    self._dataset = strategy.experimental_distribute_dataset(dataset)
 
   def enumerate_epochs(self):
     """Yields `(epoch, tf.data.Iterator)`."""
-    data_iterator = iter(self._train_dataset)
+    data_iterator = iter(self._dataset)
     for epoch in range(self._initial_epoch, self._epochs):
       if self._insufficient_data:  # Set by `catch_stop_iteration`.
         break
-      if self._train_adapter.should_recreate_iterator():
-        data_iterator = iter(self._train_dataset)
+      if self._adapter.should_recreate_iterator():
+        data_iterator = iter(self._dataset)
       yield epoch, data_iterator
+      self._adapter.on_epoch_end()
 
   @contextlib.contextmanager
   def catch_stop_iteration(self):
@@ -1122,8 +1131,8 @@
     try:
       yield
     except (StopIteration, errors.OutOfRangeError):
-      if (self._train_adapter.get_size() is None and
-          self._steps_per_epoch is None and self._current_step > 0):
+      if (self._adapter.get_size() is None and self._steps_per_epoch is None and
+          self._current_step > 0):
         # The input passed by the user ran out of batches.
         # Now we know the cardinality of the input(dataset or generator).
         self._steps_per_epoch = self._current_step
@@ -1154,7 +1163,7 @@
     if steps is not None:
       return steps
 
-    adapter_steps = self._train_adapter.get_size()
+    adapter_steps = self._adapter.get_size()
     if adapter_steps is not None:
       return adapter_steps
 
@@ -1175,11 +1184,11 @@
 
   @property
   def _samples(self):
-    return self._train_adapter.get_samples()
+    return self._adapter.get_samples()
 
   @property
   def _steps(self):
-    return self._train_adapter.get_size()
+    return self._adapter.get_size()
 
 
 def _make_class_weight_map_fn(class_weight):
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 75ddf0f..346dc32 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -786,7 +786,7 @@
     # User can choose to only partially consume `Dataset`.
     data_handler = data_adapter.DataHandler(
         data, initial_epoch=0, epochs=2, steps_per_epoch=2)
-    self.assertFalse(data_handler._train_adapter.should_recreate_iterator())
+    self.assertFalse(data_handler._adapter.should_recreate_iterator())
     returned_data = []
     for _, iterator in data_handler.enumerate_epochs():
       epoch_data = []
@@ -812,7 +812,7 @@
     # create a new iterator each epoch.
     data_handler = data_adapter.DataHandler(
         data, initial_epoch=0, epochs=2, steps_per_epoch=4)
-    self.assertTrue(data_handler._train_adapter.should_recreate_iterator())
+    self.assertTrue(data_handler._adapter.should_recreate_iterator())
     returned_data = []
     for _, iterator in data_handler.enumerate_epochs():
       epoch_data = []
@@ -842,7 +842,7 @@
     # User can choose to only partially consume `Dataset`.
     data_handler = data_adapter.DataHandler(
         filtered_ds, initial_epoch=0, epochs=2, steps_per_epoch=2)
-    self.assertFalse(data_handler._train_adapter.should_recreate_iterator())
+    self.assertFalse(data_handler._adapter.should_recreate_iterator())
     returned_data = []
     for _, iterator in data_handler.enumerate_epochs():
       epoch_data = []
@@ -860,7 +860,7 @@
 
     data_handler = data_adapter.DataHandler(
         filtered_ds, initial_epoch=0, epochs=2)
-    self.assertTrue(data_handler._train_adapter.should_recreate_iterator())
+    self.assertTrue(data_handler._adapter.should_recreate_iterator())
     returned_data = []
     for _, iterator in data_handler.enumerate_epochs():
       epoch_data = []
diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py
index f5abc0e..dd25fb7 100644
--- a/tensorflow/python/keras/engine/training_generator_test.py
+++ b/tensorflow/python/keras/engine/training_generator_test.py
@@ -454,6 +454,31 @@
     model.evaluate(CustomSequenceChangingBatchSize())
     model.predict(CustomSequenceChangingBatchSize())
 
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_sequence_on_epoch_end(self):
+
+    class MySequence(data_utils.Sequence):
+
+      def __init__(self):
+        self.epochs = 0
+
+      def __getitem__(self, idx):
+        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+      def __len__(self):
+        return 2
+
+      def on_epoch_end(self):
+        self.epochs += 1
+
+    inputs = keras.Input(10)
+    outputs = keras.layers.Dense(1)(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse')
+    my_seq = MySequence()
+    model.fit(my_seq, epochs=2)
+    self.assertEqual(my_seq.epochs, 2)
+
 
 @tf_test_util.run_all_in_graph_and_eager_modes
 class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):