Disable auto_shard for MirroredStrategy by default.
We will re-enable it when it is more robust.

PiperOrigin-RevId: 214956066
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f9..2e02576 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@
 important to shuffle your dataset in your `input_fn`.
 
 `MirroredStrategy` will insert a `tf.data.Dataset.shard` call in your
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
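
As a minimal illustration of what that inserted shard call amounts to (a sketch
using the public `tf.data.Dataset.shard` API, not the internal
`input_ops.auto_shard_dataset` helper, and assuming two workers):

```python
import tensorflow as tf

# With two workers, shard(2, i) keeps every second element, offset by the
# worker index, so each worker reads half of the input.
full = tf.data.Dataset.range(8)
worker_0 = full.shard(num_shards=2, index=0)  # yields 0, 2, 4, 6
worker_1 = full.shard(num_shards=2, index=1)  # yields 1, 3, 5, 7
```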
 
 ### Performance Tips
 
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 504f45a..93d42e0 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@
       set, the `configure` method will try to find the best one.
     prefetch_on_device: optional boolean to specify whether to prefetch input
       data to devices.
+    auto_shard_dataset: whether to auto-shard the dataset when there are
+      multiple workers.
   """
 
   def __init__(self,
@@ -354,11 +356,13 @@
                num_gpus=None,
                num_gpus_per_worker=None,
                cross_tower_ops=None,
-               prefetch_on_device=None):
+               prefetch_on_device=None,
+               auto_shard_dataset=False):
     super(MirroredStrategy, self).__init__()
 
     self._cross_tower_ops = cross_tower_ops
     self._prefetch_on_device = prefetch_on_device
+    self._auto_shard_dataset = auto_shard_dataset
     # Remember num GPUs which might be needed by the `configure` method.
     if num_gpus is not None and num_gpus_per_worker is not None:
       raise ValueError(
@@ -477,7 +481,7 @@
     if self._cluster_spec:
       return values.MultiWorkerDataset(
           partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
-          self._prefetch_on_device)
+          self._prefetch_on_device, self._auto_shard_dataset)
     else:
       return values.PerDeviceDataset(
           self._call_dataset_fn(dataset_fn),
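
A brief usage sketch of the new default, assuming the `tf.contrib.distribute`
import path and a hypothetical multi-worker setup with two GPUs per worker:
auto-sharding is now off unless the caller asks for it.

```python
import tensorflow as tf

def input_fn():
  # With auto_shard_dataset=False (the new default), every worker sees the
  # full dataset, so shuffling in the input_fn is important.
  return tf.data.Dataset.range(1000).shuffle(1000).batch(32)

# Opt back into the previous sharding behavior explicitly.
strategy = tf.contrib.distribute.MirroredStrategy(
    num_gpus_per_worker=2,
    auto_shard_dataset=True)
```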
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index cce41e7..327775a 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -814,7 +814,8 @@
   eager mode.
   """
 
-  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+               auto_shard=False):
     """Initialize the MultiWorkerDataset object.
 
     Args:
@@ -822,6 +823,7 @@
       worker_device_map: a dict mapping from each worker to a list of devices
         that belong to this worker.
       prefetch_on_device: whether to prefetch to devices.
+      auto_shard: whether to auto-shard the dataset.
     """
     self._worker_device_map = worker_device_map
     self._datasets = {}
@@ -831,8 +833,9 @@
         six.iteritems(worker_device_map)):
       with ops.device(worker):
         worker_input = dataset_fn()
-        worker_input = input_ops.auto_shard_dataset(
-            worker_input, len(worker_device_map), i)
+        if auto_shard:
+          worker_input = input_ops.auto_shard_dataset(
+              worker_input, len(worker_device_map), i)
         self._datasets[worker] = PerDeviceDataset(
             worker_input,
             worker_devices,
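
For reference, a hypothetical standalone version of the guarded loop above,
using the public `Dataset.shard` in place of the internal
`input_ops.auto_shard_dataset` helper; `make_worker_datasets` and the device
strings are illustrative only.

```python
import tensorflow as tf

def make_worker_datasets(dataset_fn, worker_device_map, auto_shard=False):
  """Builds one dataset per worker, sharding only when asked to."""
  datasets = {}
  num_workers = len(worker_device_map)
  for i, worker in enumerate(sorted(worker_device_map)):
    with tf.device(worker):
      worker_input = dataset_fn()
      if auto_shard:
        # Same guard as in MultiWorkerDataset.__init__: only split the input
        # across workers when auto-sharding is explicitly enabled.
        worker_input = worker_input.shard(num_workers, i)
      datasets[worker] = worker_input
  return datasets

datasets = make_worker_datasets(
    lambda: tf.data.Dataset.range(100).batch(10),
    {"/job:worker/task:0": ["/device:GPU:0"],
     "/job:worker/task:1": ["/device:GPU:0"]},
    auto_shard=True)
```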