Rename PerDevice -> PerReplica in distribution strategies.
PiperOrigin-RevId: 220838587
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index d9339f8..efa99d1 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -205,7 +205,7 @@
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
# TODO(yuefengz): shard the dataset.
- return values.PerDeviceDataset(
+ return values.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._devices, True)
def configure(self,
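
A minimal graph-mode sketch of how the renamed `PerReplicaDataset` is consumed
(not part of this commit; it assumes the TF 1.12-era contrib layout above and a
hypothetical /cpu:0 + /gpu:0 device pair with soft placement):

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import values

    devices = ["/cpu:0", "/gpu:0"]  # hypothetical local devices
    dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

    # PerReplicaDataset batches by len(devices), one element per device per step.
    per_replica_dataset = values.PerReplicaDataset(dataset, devices)
    iterator = per_replica_dataset.make_initializable_iterator()

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
      sess.run(iterator.initializer)
      batch = iterator.get_next()  # a PerReplica mapping each device to a tensor
      print(sess.run(values.select_device(devices[0], batch)))
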
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index bae0f47..6b2fe0a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -62,26 +62,26 @@
raise ValueError("destinations can not be empty")
-def _make_tensor_into_per_device(input_tensor):
- """Converts a single tensor into a PerDevice object."""
+def _make_tensor_into_per_replica(input_tensor):
+ """Converts a single tensor into a PerReplica object."""
if isinstance(input_tensor, (tuple, list)):
- raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, "
+ raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
"got %r but expected a object that is not a tuple or list."
% (input_tensor,))
- if isinstance(input_tensor, value_lib.PerDevice):
+ if isinstance(input_tensor, value_lib.PerReplica):
return input_tensor
try:
device = input_tensor.device
except AttributeError:
- raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object "
+ raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
"because it doesn't have device set.")
- return value_lib.PerDevice({device: input_tensor})
+ return value_lib.PerReplica({device: input_tensor})
def _normalize_value_destination_pairs(value_destination_pairs):
- """Converts each tensor into a PerDevice object in the input list."""
+ """Converts each tensor into a PerReplica object in the input list."""
result = []
if not isinstance(value_destination_pairs, (list, tuple)):
raise ValueError("`value_destination_pairs` should be a list or tuple")
@@ -93,8 +93,8 @@
raise ValueError("Each element of `value_destination_pairs` should be a "
"tuple of size 2.")
- per_device = _make_tensor_into_per_device(pair[0])
- result.append((per_device, pair[1]))
+ per_replica = _make_tensor_into_per_replica(pair[0])
+ result.append((per_replica, pair[1]))
return result
@@ -105,7 +105,7 @@
if not isinstance(value_destination_pairs, (list, tuple)): return False
if not all([isinstance(pair, tuple) for pair in value_destination_pairs]):
return False
- if not all([isinstance(v[0], value_lib.PerDevice)
+ if not all([isinstance(v[0], value_lib.PerReplica)
for v in value_destination_pairs]):
return False
return True
@@ -149,12 +149,12 @@
return value_lib.Mirrored(index)
-def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
+def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
aggregation):
# pylint: disable=g-missing-docstring
all_values = []
count = 0
- for v in per_device_value._index.values(): # pylint: disable=protected-access
+ for v in per_replica_value._index.values(): # pylint: disable=protected-access
if isinstance(v, value_lib.MapOutput):
v_list = v.get()
if not v_list:
@@ -168,7 +168,7 @@
count += 1
all_values.append(v)
if not all_values:
- raise ValueError("`per_device_value` must be non-empty")
+ raise ValueError("`per_replica_value` must be non-empty")
with ops.device(reduce_to_device):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
@@ -189,8 +189,8 @@
def __init__(self):
pass
- def reduce(self, aggregation, per_device_value, destinations):
- """Reduce `per_device_value` to `destinations`.
+ def reduce(self, aggregation, per_replica_value, destinations):
+ """Reduce `per_replica_value` to `destinations`.
It runs the reduction operation defined by `aggregation` and puts the
result on `destinations`.
@@ -198,23 +198,23 @@
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
- per_device_value: a PerDevice object or a tensor with device set.
+ per_replica_value: a PerReplica object or a tensor with device set.
destinations: the reduction destinations.
Returns:
a Mirrored object.
Raises:
- ValueError: if per_device_value is not a PerDevice object.
+ ValueError: if per_replica_value is not a PerReplica object.
"""
- if not isinstance(per_device_value, value_lib.PerDevice):
- per_device_value = _make_tensor_into_per_device(per_device_value)
+ if not isinstance(per_replica_value, value_lib.PerReplica):
+ per_replica_value = _make_tensor_into_per_replica(per_replica_value)
validate_destinations(destinations)
- return self._reduce(aggregation, per_device_value, destinations)
+ return self._reduce(aggregation, per_replica_value, destinations)
def batch_reduce(self, aggregation, value_destination_pairs):
- """Reduce PerDevice objects in a batch.
+ """Reduce PerReplica objects in a batch.
Reduce the first element of each pair in `value_destination_pairs` to the
destinations indicated by its second element.
@@ -222,7 +222,7 @@
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
- value_destination_pairs: a list or a tuple of tuples of PerDevice objects
+ value_destination_pairs: a list or a tuple of tuples of PerReplica objects
(or tensors with device set if there is one device) and destinations.
Returns:
@@ -230,11 +230,11 @@
Raises:
ValueError: if `value_destination_pairs` is not a list or a tuple of
- tuples of PerDevice objects and destinations
+ tuples of PerReplica objects and destinations
"""
if not _validate_value_destination_pairs(value_destination_pairs):
# If the first element of each pair is a tensor, we try to turn it into a
- # PerDevice object.
+ # PerReplica object.
value_destination_pairs = _normalize_value_destination_pairs(
value_destination_pairs)
@@ -256,7 +256,7 @@
validate_destinations(destinations)
return self._broadcast(tensor, destinations)
- def _reduce(self, aggregation, per_device_value, destinations):
+ def _reduce(self, aggregation, per_replica_value, destinations):
raise NotImplementedError(
"_reduce method must be implemented in descendants.")
@@ -286,13 +286,13 @@
self.accumulation_fn = accumulation_fn
super(ReductionToOneDeviceCrossDeviceOps, self).__init__()
- def _reduce(self, aggregation, per_device_value, destinations):
+ def _reduce(self, aggregation, per_replica_value, destinations):
if check_destinations(destinations):
devices = get_devices_from(destinations)
else:
- devices = get_devices_from(per_device_value)
+ devices = get_devices_from(per_replica_value)
reduce_to_device = self.reduce_to_device or devices[0]
- reduced = _simple_reduce(per_device_value, reduce_to_device,
+ reduced = _simple_reduce(per_replica_value, reduce_to_device,
self.accumulation_fn, aggregation)
return self.broadcast(reduced, devices)
@@ -303,7 +303,7 @@
]
-def _group_value_by_device(per_device_values):
+def _group_value_by_device(per_replica_values):
"""Group values into sublists by their devices.
This grouping is needed to call the all-reduce library because it expects a
@@ -315,18 +315,18 @@
]
Args:
- per_device_values: a list of PerDevice obejcts.
+ per_replica_values: a list of PerReplica objects.
Returns:
a list of lists, each sublist holding the components for its corresponding device from the
- PerDevice objects, paired with a None.
+ PerReplica objects, paired with a None.
"""
- destinations = per_device_values[0].devices
+ destinations = per_replica_values[0].devices
grouped = [[] for _ in range(len(destinations))]
- for per_device_value in per_device_values:
+ for per_replica_value in per_replica_values:
# pylint: disable=protected-access
- for i, v in enumerate(per_device_value._index.values()):
- assert per_device_value.devices == destinations
+ for i, v in enumerate(per_replica_value._index.values()):
+ assert per_replica_value.devices == destinations
grouped[i].append((v, None))
return grouped
@@ -354,8 +354,8 @@
a list of Mirrored objects.
"""
index = [{} for _ in range(len(grouped_reduced[0]))]
- for d, per_device_reduced in enumerate(grouped_reduced):
- for i, (v, _) in enumerate(per_device_reduced):
+ for d, per_replica_reduced in enumerate(grouped_reduced):
+ for i, (v, _) in enumerate(per_replica_reduced):
if aggregation == vs.VariableAggregation.MEAN:
index[i][destinations[d]] = v / (
len(destinations) * num_between_graph_workers)
@@ -567,13 +567,13 @@
self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossDeviceOps, self).__init__()
- def _reduce(self, aggregation, per_device_value, destinations):
+ def _reduce(self, aggregation, per_replica_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
- per_device_value)
- if (_devices_match(per_device_value, destinations)
+ per_replica_value)
+ if (_devices_match(per_replica_value, destinations)
and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(aggregation, [per_device_value])[0]
+ return self._batch_all_reduce(aggregation, [per_replica_value])[0]
else:
if contains_indexed_slices:
logging.log_first_n(
@@ -583,9 +583,9 @@
if check_destinations(destinations):
devices = get_devices_from(destinations)
else:
- devices = get_devices_from(per_device_value)
+ devices = get_devices_from(per_replica_value)
reduce_to_device = devices[0]
- reduced = _simple_reduce(per_device_value, reduce_to_device,
+ reduced = _simple_reduce(per_replica_value, reduce_to_device,
math_ops.add_n, aggregation)
return self.broadcast(reduced, devices)
@@ -609,16 +609,16 @@
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, aggregation, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_replica_values):
"""All reduce algorithm in a batch."""
logging.log_first_n(
logging.INFO, "batch_all_reduce invoked for batches size = %d with "
"algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
"agg_small_grads_max_group = %d" %
- (len(per_device_values), self._all_reduce_alg, self._num_packs,
+ (len(per_replica_values), self._all_reduce_alg, self._num_packs,
self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
- destinations = per_device_values[0].devices
- grouped = _group_value_by_device(per_device_values)
+ destinations = per_replica_values[0].devices
+ grouped = _group_value_by_device(per_replica_values)
device_grad_packs, tensor_packer = _pack_tensors(
grouped, self._num_packs, self._agg_small_grads_max_bytes,
@@ -639,7 +639,7 @@
destinations, device_grad_packs))
reduced = _unpack_tensors(reduced, tensor_packer)
- return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
+ return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices,
aggregation)
@@ -723,18 +723,18 @@
validate_and_complete_spec(spec) for spec in all_reduce_spec
]
- def _batch_all_reduce(self, aggregation, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_replica_values):
"""All reduce algorithm in a batch."""
logging.log_first_n(
logging.INFO,
"distributed batch_all_reduce invoked for batches size = %d with "
"allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
"and agg_small_grads_max_group = %d" %
- (len(per_device_values), self._all_reduce_spec, self._num_packs,
+ (len(per_replica_values), self._all_reduce_spec, self._num_packs,
self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
- destinations = sorted(per_device_values[0].devices)
- device_grads = _group_value_by_device(per_device_values)
+ destinations = sorted(per_replica_values[0].devices)
+ device_grads = _group_value_by_device(per_replica_values)
# The all reduce library requires fully defined shapes.
# TODO(yuefengz): when tensor sharding is not needed, static shapes are not
@@ -805,16 +805,16 @@
super(CollectiveAllReduce, self).__init__()
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
- def _reduce(self, aggregation, per_device_value, destinations):
- if cross_tower_utils.contains_indexed_slices(per_device_value):
+ def _reduce(self, aggregation, per_replica_value, destinations):
+ if cross_tower_utils.contains_indexed_slices(per_replica_value):
raise ValueError(
"`IndexSlices` is not supported for Collective All-Reduce.")
if context.executing_eagerly():
raise ValueError(
"Eager execution is not supported for Collective All-Reduce")
- all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
- if _devices_match(per_device_value, destinations):
+ all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0]
+ if _devices_match(per_replica_value, destinations):
return all_reduced
else:
index = {}
@@ -852,7 +852,7 @@
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, aggregation, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_replica_values):
"""All-reduce across all workers in a batch."""
if context.executing_eagerly():
raise ValueError(
@@ -860,9 +860,9 @@
logging.log_first_n(
logging.INFO, "Collective All-reduce invoked with batches size = %d, "
- "num_workers = %d" % (len(per_device_values), self._num_workers), 10)
+ "num_workers = %d" % (len(per_replica_values), self._num_workers), 10)
- grouped_by_device = _group_value_by_device(per_device_values)
+ grouped_by_device = _group_value_by_device(per_replica_values)
grouped_by_var = list(zip(*grouped_by_device))
# grouped_by_var is grouped by variables and takes the following format:
@@ -892,7 +892,7 @@
new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
return _ungroup_and_make_mirrored(
new_device_grads,
- per_device_values[0].devices,
+ per_replica_values[0].devices,
aggregation,
num_between_graph_workers=self._num_workers)
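
A hedged sketch of the renamed `reduce()` entry point (not part of this
commit; it assumes the TF 1.12-era contrib modules this file belongs to and a
hypothetical /cpu:0 + /gpu:0 pair, graph mode):

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
    from tensorflow.contrib.distribute.python import values as value_lib

    devices = ["/cpu:0", "/gpu:0"]
    index = {}
    for i, d in enumerate(devices):
      with tf.device(d):
        index[d] = tf.constant(float(i))

    # One unsynchronized value per device, now spelled PerReplica.
    per_replica = value_lib.PerReplica(index)

    # reduce() also accepts a bare tensor with a device set; such a tensor is
    # first routed through _make_tensor_into_per_replica(). Result: a Mirrored.
    reducer = cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()
    mirrored_sum = reducer.reduce(
        tf.VariableAggregation.SUM, per_replica, destinations=devices)

    # batch_reduce() handles several (PerReplica, destinations) pairs at once.
    mirrored_means = reducer.batch_reduce(
        tf.VariableAggregation.MEAN, [(per_replica, devices)])
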
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 6a9e8e0..3e274ba 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -40,12 +40,12 @@
from tensorflow.python.training import device_util
-def _make_per_device(values, devices, regroup=False):
+def _make_per_replica(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
- # We simulate the result of regroup called on PerDevice which strips the
- # PerDevice wrapper if it has only one value.
+ # We simulate the result of regroup called on PerReplica which strips the
+ # PerReplica wrapper if it has only one value.
if len(values) == 1 and regroup:
with ops.device(devices[0]):
placed_v = array_ops.identity(values[0])
@@ -56,7 +56,7 @@
with ops.device(d):
placed_v = array_ops.identity(v)
index[d] = placed_v
- return value_lib.PerDevice(index)
+ return value_lib.PerReplica(index)
# pylint: disable=g-doc-args,g-doc-return-or-yield
@@ -122,11 +122,11 @@
devices = distribution.worker_devices
values = [constant_op.constant(float(d)) for d in range(len(devices))]
- per_device = _make_per_device(values, devices)
+ per_replica = _make_per_replica(values, devices)
mean = (len(devices) - 1.) / 2.
values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
- per_device_2 = _make_per_device(values_2, devices)
+ per_replica_2 = _make_per_replica(values_2, devices)
mean_2 = mean + 1.
destination_mirrored = _fake_mirrored(1., devices)
@@ -144,39 +144,41 @@
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
- per_device,
+ per_replica,
destinations=destinations),
_fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
- per_device_2,
+ per_replica_2,
destinations=destinations),
_fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
- vs.VariableAggregation.SUM, per_device,
+ vs.VariableAggregation.SUM, per_replica,
destinations=destinations),
_fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM,
- per_device_2,
+ per_replica_2,
destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
- cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
- [(per_device, d1), (per_device_2, d2)]),
+ cross_tower_ops.batch_reduce(
+ vs.VariableAggregation.MEAN,
+ [(per_replica, d1), (per_replica_2, d2)]),
[
_fake_mirrored(mean, d1),
_fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
- cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
- [(per_device, d1), (per_device_2, d2)]),
+ cross_tower_ops.batch_reduce(
+ vs.VariableAggregation.SUM,
+ [(per_replica, d1), (per_replica_2, d2)]),
[
_fake_mirrored(mean * len(devices), d1),
_fake_mirrored(mean_2 * len(devices), d2)
@@ -277,9 +279,9 @@
devices = ["/cpu:0", "/gpu:0"]
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
- per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+ per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
result = cross_tower_ops_lib._simple_reduce(
- per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
+ per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
# Test that the result is semantically equal to both the concatenated
# IndexedSlices with and without duplicate indices.
@@ -311,13 +313,14 @@
t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
t1 = _make_indexed_slices(
[[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
- per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+ per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
if batch_reduce:
- result = cross_tower_ops_instance.batch_reduce(aggregation,
- [(per_device, devices)])
+ result = cross_tower_ops_instance.batch_reduce(
+ aggregation, [(per_replica, devices)])
else:
- result = cross_tower_ops_instance.reduce(aggregation, per_device, devices)
+ result = cross_tower_ops_instance.reduce(
+ aggregation, per_replica, devices)
total_indices_with_dups = [1, 1, 3]
total_indices_without_dups = [1, 3]
@@ -478,11 +481,11 @@
# Collective ops don't support scalar tensors, so we have to construct
# 1-d tensors.
values = [constant_op.constant([float(d)]) for d in range(len(devices))]
- per_device = _make_per_device(values, devices, regroup=True)
+ per_replica = _make_per_replica(values, devices, regroup=True)
mean = np.array([(len(devices) - 1.) / 2.])
values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
- per_device_2 = _make_per_device(values_2, devices)
+ per_replica_2 = _make_per_replica(values_2, devices)
mean_2 = np.array([mean[0] + 1.])
destination_mirrored = _fake_mirrored(1., devices)
@@ -500,26 +503,26 @@
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
- per_device,
+ per_replica,
destinations=destinations),
_fake_mirrored(mean, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
- per_device_2,
+ per_replica_2,
destinations=destinations),
_fake_mirrored(mean_2, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
- per_device,
+ per_replica,
destinations=destinations),
_fake_mirrored(mean * len(devices) * num_workers, destinations),
sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
- per_device_2,
+ per_replica_2,
destinations=destinations),
_fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
sess)
@@ -528,16 +531,16 @@
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN,
- [(per_device, d1),
- (per_device_2, d2)]),
+ [(per_replica, d1),
+ (per_replica_2, d2)]),
[
_fake_mirrored(mean, d1),
_fake_mirrored(mean_2, d2)
], sess)
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
- [(per_device, d1),
- (per_device_2, d2)]),
+ [(per_replica, d1),
+ (per_replica_2, d2)]),
[
_fake_mirrored(mean * len(devices) * num_workers, d1),
_fake_mirrored(mean_2 * len(devices) * num_workers, d2)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
index d25964f..a991156 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
@@ -98,24 +98,24 @@
self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
@test_util.run_in_graph_and_eager_modes
- def testContainsIndexedSlices_PerDevice(self):
+ def testContainsIndexedSlices_PerReplica(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
- per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
- self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+ per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1})
+ self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
@test_util.run_in_graph_and_eager_modes
- def testContainsIndexedSlices_PerDeviceMapOutput(self):
+ def testContainsIndexedSlices_PerReplicaMapOutput(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
- per_device = value_lib.PerDevice({
+ per_replica = value_lib.PerReplica({
"/gpu:0": value_lib.MapOutput([t0]),
"/cpu:0": value_lib.MapOutput([t1])})
- self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+ self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index df1c8d1..ebd9648 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -192,7 +192,7 @@
raise ValueError("You are passing a `DistributedValue` to "
"`_reduce_non_distributed_value`, which is not allowed.")
- # If the same value is present on all replicas then the PerDevice value will
+ # If the same value is present on all replicas then the PerReplica value will
# be a single value. We also handle the case when `value` is a single value
# and equal to 0.
if value == 0:
@@ -402,7 +402,8 @@
# TODO(josh11b): Require at least 2 devices?
self._devices = [device_util.resolve(d) for d in devices]
self._canonical_device_set = set(self._devices)
- self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)})
+ self._device_index = values.PerReplica(
+ {d: i for i, d in enumerate(devices)})
def _initialize_multi_worker(self, num_gpus, cluster_spec):
"""Initializes the object for multi-worker training."""
@@ -446,7 +447,7 @@
# TODO(josh11b): Require at least 2 devices?
self._devices = [device_util.resolve(d) for d in devices]
self._canonical_device_set = set(self._devices)
- self._device_index = values.PerDevice(
+ self._device_index = values.PerReplica(
{d: i for i, d in enumerate(devices)})
def _create_variable(self, next_creator, *args, **kwargs):
@@ -493,7 +494,7 @@
partial(self._call_dataset_fn, dataset_fn), self._worker_devices,
self._prefetch_on_device, self._auto_shard_dataset)
else:
- return values.PerDeviceDataset(
+ return values.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._devices,
self._prefetch_on_device)
@@ -546,10 +547,10 @@
for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that have already been aggregated, wrap them in a Mirrored
- # container, else in a PerDevice container.
+ # container, else in a PerReplica container.
if aggregation is variables_lib.VariableAggregation.NONE:
last_step_tensor_outputs_dict[name] = values.regroup(
- {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
+ {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
else:
assert len(output) == 1
last_step_tensor_outputs_dict[name] = output[0]
@@ -577,8 +578,8 @@
**values.select_device_mirrored(d, kwargs)))
index[d] = l
# TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
- # in addition to PerDevice data.
- return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
+ # in addition to PerReplica data.
+ return values.PerReplica({k: values.MapOutput(v) for k, v in index.items()})
def configure(self,
session_config=None,
@@ -617,9 +618,10 @@
def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
- # This function handles reducing values that are not PerDevice or Mirrored
- # values. For example, the same value could be present on all replicas in
- # which case `value` would be a single value or value could be 0.
+ # This function handles reducing values that are not PerReplica or
+ # Mirrored values. For example, the same value could be present on all
+ # replicas in which case `value` would be a single value or value could
+ # be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA:
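
A small sketch of the `regroup()` call used above for unaggregated last-step
outputs (not part of this commit; devices and tensors are hypothetical,
TF 1.12-era contrib paths assumed):

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import values

    devices = ["/cpu:0", "/gpu:0"]
    output = [tf.constant(1.0), tf.constant(2.0)]  # one tensor per replica

    # With aggregation NONE, MirroredStrategy wraps raw outputs in PerReplica.
    wrapped = values.regroup(
        {d: t for d, t in zip(devices, output)}, values.PerReplica)
    first = wrapped.get(device=devices[0])  # recovers the /cpu:0 tensor
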
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index b8e7eda..b47c9b0 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -908,7 +908,7 @@
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignMirroredVarReplicaContextWithSum(self):
- # Test that we don't reduce a non-per-device value with the "sum"
+ # Test that we don't reduce a non-per-replica value with the "sum"
# aggregation type.
self._skip_eager_if_gpus_less_than(1)
def var_fn():
@@ -1320,11 +1320,11 @@
# call_for_each has one trace per device. To check that the expected set
# of variables was accessed on each trace, we first retrieve each
# device-specific graph function.
- per_device_graph_functions = dist.call_for_each_replica(
+ per_replica_graph_functions = dist.call_for_each_replica(
defun.get_concrete_function,
mock_model, *inputs, run_concurrently=False)
for device in devices:
- graph_function = per_device_graph_functions.get(device=device)
+ graph_function = per_replica_graph_functions.get(device=device)
self.assertEqual(set(mock_model.variables),
set(graph_function.graph.variables))
@@ -1398,16 +1398,16 @@
two_variables=True)
@test_util.run_in_graph_and_eager_modes()
- def testPassPerDevice(self):
+ def testPassPerReplica(self):
self._skip_eager_if_gpus_less_than(1)
@function.defun
def fn1(mock_model, factor):
return mock_model(factor)
- factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0})
- expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25,
- "GPU:0": 3.0 * 1.25})
+ factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0})
+ expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25,
+ "GPU:0": 3.0 * 1.25})
self._call_and_check(fn1, [factors], expected_result, [fn1])
@test_util.run_in_graph_and_eager_modes()
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
index 8156444..7ecc852 100644
--- a/tensorflow/contrib/distribute/python/moving_averages_test.py
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -93,7 +93,8 @@
var = variables.Variable([10.0, 11.0])
val = constant_op.constant([1.0, 2.0])
decay = 0.25
- # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ # NOTE(josh11b): We currently generate an error if val is a PerReplica
+ # value.
assign = moving_averages.assign_moving_average(
var, val, decay, zero_debias=False)
@@ -121,7 +122,8 @@
var = variables.Variable([0.0, 0.0])
val = array_ops.placeholder(dtypes.float32)
decay = 0.25
- # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ # NOTE(josh11b): We currently generate an error if val is a PerReplica
+ # value.
assign = moving_averages.assign_moving_average(var, val, decay)
variables.global_variables_initializer().run()
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 616508f..afff24d 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -62,7 +62,7 @@
return next_creator(*args, **kwargs)
def distribute_dataset(self, dataset_fn):
- return values.PerDeviceDataset(
+ return values.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), [self._device],
self._prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 00e847d..00c2311 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -223,7 +223,7 @@
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
- return values.PerDeviceDataset(
+ return values.PerReplicaDataset(
self._call_dataset_fn(dataset_fn), self._compute_devices, True)
def _broadcast(self, tensor, destinations):
@@ -339,9 +339,9 @@
"You cannot update variable with a Mirrored object with multiple "
"components %r when using ParameterServerStrategy. You must "
"specify a single value or a Mirrored with a single value." % x)
- elif isinstance(x, values.PerDevice):
+ elif isinstance(x, values.PerReplica):
raise ValueError(
- "You cannot update variable with a PerDevice object %r when using "
+ "You cannot update variable with a PerReplica object %r when using "
"ParameterServerStrategy. You must specify a single value or a "
"Mirrored with a single value" % x)
else:
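
To make the restriction above concrete, a hedged illustration (not from this
commit; it only assumes the contrib values module in this diff): an update must
carry a single value or a single-component Mirrored, never a PerReplica with
one component per device.

    from tensorflow.contrib.distribute.python import values as value_lib

    ok_update = 0.5  # a plain single value: accepted by the update path
    rejected = value_lib.PerReplica({"/cpu:0": 0.5, "/gpu:0": 0.7})
    # Passing `rejected` as an update triggers the ValueError raised above.
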
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 07cc2c5..8dccce0 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -141,7 +141,7 @@
# parallelism.
device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices)
if "device:TPU:" in d.name}
- self._device_index = values.PerDevice(device_map)
+ self._device_index = values.PerReplica(device_map)
self._host_device = self.get_host_cpu_device(0)
self._tpu_devices = sorted(device_map.keys())
# Only create variables for the number of replicas we're running.
@@ -308,7 +308,8 @@
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
- # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerReplica
+ # value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -445,7 +446,7 @@
return [val.get(device=d) for d in sorted(val.devices)]
elif isinstance(val, list):
# TODO(josh11b): We need to remove this case; per device values should
- # be represented using a PerDevice wrapper instead of a list with
+ # be represented using a PerReplica wrapper instead of a list with
# one entry per device.
return val
return [val]
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 16d9e1c..fe9bb1f 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -51,7 +51,7 @@
# TODO(josh11b): Should device values be strings or DeviceSpec objects?
# Not sure DeviceSpec objects are usable as a dict key.
class DistributedValues(object):
- """Holds a map from device to values. Either PerDevice or Mirrored."""
+ """Holds a map from device to values. Either PerReplica or Mirrored."""
def __init__(self, index):
self._index = {device_util.canonicalize(key): value
@@ -163,12 +163,12 @@
# TODO(josh11b): Even more operator overloads.
-class PerDevice(DistributedValues):
+class PerReplica(DistributedValues):
"""Holds a map from device to unsynchronized values."""
pass
-# Note that unlike PerDevice, Mirrored values inherit from
+# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
"""Holds a map from device to values which are kept in sync."""
@@ -870,7 +870,7 @@
"Replica-local variables may only be assigned in a replica context.")
-class ReplicaLocalVariable(DistributedVariable, PerDevice,
+class ReplicaLocalVariable(DistributedVariable, PerReplica,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
@@ -951,9 +951,9 @@
return device_util.canonicalize(d1) == device_util.canonicalize(d2)
-def regroup(per_device, wrap_class=PerDevice):
- """Makes device->nest map into a nest of PerDevice/Mirrored values."""
- items = list(per_device.items())
+def regroup(per_replica, wrap_class=PerReplica):
+ """Makes device->nest map into a nest of PerReplica/Mirrored values."""
+ items = list(per_replica.items())
assert items
v0 = items[0][1] # First value
@@ -1014,7 +1014,7 @@
# want to return the containing MirroredVariable, after a bunch of
# sanity checking. In particular, each component should have the
# same container, and the devices of the variables should match the
- # keys of the per-device dictionary.
+ # keys of the per-replica dictionary.
if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
@@ -1030,11 +1030,11 @@
return distributed_container
# pylint: enable=protected-access
- return wrap_class(per_device)
+ return wrap_class(per_replica)
def select_device(device, structured):
- """Specialize a nest of regular & per-device values for one device."""
+ """Specialize a nest of regular & per-replica values for one device."""
def _get(x):
return x.get(device) if isinstance(x, DistributedValues) else x
@@ -1079,8 +1079,8 @@
return nest.pack_sequence_as(regrouped, grouped_flat)
-class PerDeviceDataIterator(object):
- """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""
+class PerReplicaDataIterator(object):
+ """An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`."""
def __init__(self, iterator, devices, prefetch_on_device=None):
self._iterator = iterator
@@ -1123,8 +1123,8 @@
return self._iterator.output_types
-class PerDeviceDataset(object):
- """Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
+class PerReplicaDataset(object):
+ """Like `tf.data.Dataset` split devices, producing `PerReplica` data."""
def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
@@ -1145,20 +1145,20 @@
self._dataset = dataset.batch(len(devices), drop_remainder=True)
def make_one_shot_iterator(self):
- """Get a one time use iterator for the distributed PerDeviceDataset."""
+ """Get a one time use iterator for the distributed PerReplicaDataset."""
# Graph mode with one shot iterator is disabled.
if not context.executing_eagerly():
raise ValueError("Cannot create a one shot iterator. Please use "
"`make_initializable_iterator()` instead.")
# Eager mode prefetching would error out in constructor. Only remaining
# case is non-prefetching in eager mode. We delegate to
- # PerDeviceDataIterator to handle that case.
+ # PerReplicaDataIterator to handle that case.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
+ return PerReplicaDataIterator(
dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
- """Get an initializable iterator for the distributed PerDeviceDataset."""
+ """Get an initializable iterator for the distributed PerReplicaDataset."""
# Eager mode generates already initialized iterators. Hence we cannot create
# an initializable iterator.
if context.executing_eagerly():
@@ -1169,7 +1169,7 @@
self._dataset, self._devices)
else:
dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
+ return PerReplicaDataIterator(
dataset_iterator,
self._devices,
prefetch_on_device=self._prefetch_on_device)
@@ -1227,7 +1227,7 @@
with ops.device(worker):
data_per_worker = iterator.get_next(name=new_name)
- # Ungroup these per-device value so as to get a flat map from devices to
+ # Ungroup these per-replica values so as to get a flat map from devices to
# values.
for d in worker_devices:
v = select_device(d, data_per_worker)
@@ -1266,8 +1266,8 @@
if auto_shard:
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_pairs), i)
- dataset = PerDeviceDataset(worker_input, worker_devices,
- prefetch_on_device=prefetch_on_device)
+ dataset = PerReplicaDataset(
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
self._datasets.append((worker, dataset))
def make_one_shot_iterator(self):
@@ -1447,7 +1447,7 @@
current distribution strategy's `reduce` method. Hence, the type of
`output` must be what's supported by the corresponding `reduce` method.
For example, if using MirroredStrategy and aggregation is set, output
- must be a `PerDevice` value.
+ must be a `PerReplica` value.
The aggregation method is also recorded in a dictionary
`_last_step_outputs_aggregations` for later interpreting of the
outputs as already reduced or not.
@@ -1493,7 +1493,7 @@
def value_container(val):
- """Returns the container that this per-device `value` belongs to.
+ """Returns the container that this per-replica `value` belongs to.
Args:
val: A value returned by `call_for_each_replica()` or a variable
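
A hedged round-trip sketch for the renamed helpers in this file (not part of
this commit; the nested structure and devices are hypothetical, TF 1.12-era
contrib paths assumed):

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import values

    devices = ["/cpu:0", "/gpu:0"]
    per_device_nests = {
        d: {"loss": tf.constant(float(i)), "step": tf.constant(i)}
        for i, d in enumerate(devices)
    }

    # regroup() turns a device->nest map into a nest of PerReplica values ...
    regrouped = values.regroup(per_device_nests)
    assert isinstance(regrouped["loss"], values.PerReplica)

    # ... and select_device() specializes it back to one device's plain values.
    cpu_view = values.select_device(devices[0], regrouped)
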
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 25c6d56..268393e 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -189,10 +189,10 @@
class RegroupAndSelectDeviceTest(test.TestCase):
- def _is_per_device(self, result, expected, klass=values.PerDevice):
+ def _is_per_replica(self, result, expected, klass=values.PerReplica):
self.assertIsInstance(result, klass)
# We canonicalize the devices to match the device strings returned
- # by PerDevice, which also does device string canonicalization.
+ # by PerReplica, which also does device string canonicalization.
devices = [device_util.canonicalize(_device_str(i))
for i in range(len(expected))]
self.assertEqual(set(devices), set(result.devices))
@@ -205,18 +205,18 @@
_device_str(1): _nested_value("2")})
self.assertIsInstance(result, tuple)
self.assertEqual(3, len(result))
- self._is_per_device(result[0], ["a1", "a2"])
- self._is_per_device(result[2], ["h1", "h2"])
+ self._is_per_replica(result[0], ["a1", "a2"])
+ self._is_per_replica(result[2], ["h1", "h2"])
self.assertIsInstance(result[1], list)
self.assertEqual(3, len(result[1]))
- self._is_per_device(result[1][0], ["b1", "b2"])
- self._is_per_device(result[1][2], ["g1", "g2"])
+ self._is_per_replica(result[1][0], ["b1", "b2"])
+ self._is_per_replica(result[1][2], ["g1", "g2"])
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
- self._is_per_device(result[1][1]["c"], ["d1", "d2"])
- self._is_per_device(result[1][1]["e"], ["f1", "f2"])
+ self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
+ self._is_per_replica(result[1][1]["e"], ["f1", "f2"])
# Also test that we can undo the merge using select_device()
self.assertEqual(_nested_value("1"),
@@ -237,18 +237,18 @@
values.Mirrored)
self.assertIsInstance(result, tuple)
self.assertEqual(3, len(result))
- self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
- self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)
+ self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
+ self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)
self.assertIsInstance(result[1], list)
self.assertEqual(3, len(result[1]))
- self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored)
- self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored)
+ self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
+ self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
- self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
- self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
+ self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
+ self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
# Also test that we can undo the merge using select_device()
self.assertEqual(_nested_value("1"),
@@ -274,7 +274,7 @@
_device_str(1): ("b", foo)})
self.assertIsInstance(result, tuple)
self.assertEqual(2, len(result))
- self._is_per_device(result[0], ["a", "b"])
+ self._is_per_replica(result[0], ["a", "b"])
self.assertIs(foo, result[1])
# Test select_device(), should undo the merge done by regroup().
@@ -340,17 +340,17 @@
merged_estimator_spec))
-class PerDeviceDatasetTest(test.TestCase):
+class PerReplicaDatasetTest(test.TestCase):
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
def _test_iterator(self, devices, dataset, expected_values):
- per_device_dataset = values.PerDeviceDataset(dataset, devices)
+ per_replica_dataset = values.PerReplicaDataset(dataset, devices)
if context.executing_eagerly():
- iterator = per_device_dataset.make_one_shot_iterator()
+ iterator = per_replica_dataset.make_one_shot_iterator()
else:
- iterator = per_device_dataset.make_initializable_iterator()
+ iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate([iterator.initializer])
for expected_value in expected_values:
@@ -418,8 +418,8 @@
dataset = dataset_ops.Dataset.from_tensor_slices(
random_ops.random_uniform((10,)))
- per_device_dataset = values.PerDeviceDataset(dataset, devices)
- iterator = per_device_dataset.make_initializable_iterator()
+ per_replica_dataset = values.PerReplicaDataset(dataset, devices)
+ iterator = per_replica_dataset.make_initializable_iterator()
self.evaluate(iterator.initializer)
next_element = iterator.get_next()
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 7fa830c..a8e7f7c 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -291,11 +291,11 @@
* Wrapped values: In order to represent values parallel across devices
(either replicas or the devices associated with a particular value), we
- wrap them in a "PerDevice" or "Mirrored" object that contains a map
- from device to values. "PerDevice" is used when the value may be
- different across devices, and "Mirrored" when the value are the same.
+ wrap them in a "PerReplica" or "Mirrored" object that contains a map
+ from device to values. "PerReplica" is used when the value may be
+ different across replicas, and "Mirrored" when the values are the same.
* Unwrapping and merging: Consider calling a function `fn` on
- multiple devices, like `call_for_each_replica(fn, w)` with an
+ multiple replicas, like `call_for_each_replica(fn, w)` with an
argument `w` that is a wrapped value. This means `w` will have a
map taking replica device `d0` to `w0`, replica device `d1` to `w1`,
etc. `call_for_each_replica()` unwraps `w` before calling `fn`, so
@@ -338,7 +338,7 @@
called _locality_ that says what values are compatible with which
APIs:
- * T: different value for each replica (e.g. a PerDevice-wrapped value).
+ * T: different value for each replica (e.g. a PerReplica-wrapped value).
* M: value is "mirrored" across replicas, i.e. there are copies with the
same value on each replica (e.g. a Mirrored-wrapped value).
* V(`v`): value is "mirrored" across all the devices which have a
@@ -544,7 +544,7 @@
"DistributionStrategy.")
return result
- # TODO(josh11b): `PerDeviceDataset` currently only implements a few methods of
+ # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
# Dataset API such as make_one_shot_iterator and make_initializable_iterator.
# Extend to implement more functionality of datasets.
def distribute_dataset(self, dataset_fn):
@@ -567,7 +567,7 @@
dataset_fn: A function that returns a `tf.data.Dataset`.
Returns:
- A `PerDeviceDataset` that will produce data for each replica.
+ A `PerReplicaDataset` that will produce data for each replica.
"""
raise NotImplementedError("must be implemented in descendants")
@@ -784,8 +784,8 @@
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
`tf.VariableAggregation.ONLY_FIRST_REPLICA`.
- value: A per-device value with one value per replica.
- destinations: A mirrored variable, a per-device tensor, a device string,
+ value: A per-replica value with one value per replica.
+ destinations: A mirrored variable, a per-replica tensor, a device string,
or list of device strings. The return value will be copied to all
destination devices (or all the devices where the `destinations` value
resides). To perform an all-reduction, pass `value` to `destinations`.
@@ -852,7 +852,7 @@
Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
- Neither `*args` nor `**kwargs` may contain per-device values.
+ Neither `*args` nor `**kwargs` may contain per-replica values.
If they contain mirrored values, they will be unwrapped before
calling `fn`.
@@ -900,7 +900,7 @@
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
- """Returns the list of all per-device values contained in `value`.
+ """Returns the list of all per-replica values contained in `value`.
Args:
value: A value returned by `call_for_each_replica()` or a variable
@@ -913,7 +913,7 @@
return self._unwrap(value)
def value_container(self, value):
- """Returns the container that this per-device `value` belongs to.
+ """Returns the container that this per-replica `value` belongs to.
Args:
value: A value returned by `call_for_each_replica()` or a variable
@@ -1111,13 +1111,13 @@
Args:
merge_fn: function that joins arguments from threads that are given as
- PerDevice. It accepts `DistributionStrategy` object as the first
+ PerReplica. It accepts `DistributionStrategy` object as the first
argument.
*args: positional per-thread arguments for `merge_fn`
**kwargs: keyword per-thread arguments for `merge_fn`.
Returns:
- The return value of `merge_fn`, except for `PerDevice` values which are
+ The return value of `merge_fn`, except for `PerReplica` values which are
unpacked.
"""
require_replica_context(self)
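
Finally, a hedged end-to-end sketch of the renamed terminology at the
DistributionStrategy level (not part of this commit; assumes TF 1.12-era
tf.contrib.distribute and a hypothetical two-device setup):

    import tensorflow as tf

    strategy = tf.contrib.distribute.MirroredStrategy(
        devices=["/cpu:0", "/gpu:0"])

    def step_fn():
      # Each replica computes a different value, so the combined result has
      # per-replica locality "T" and comes back wrapped as a PerReplica.
      return tf.random_uniform([])

    with strategy.scope():
      per_replica_loss = strategy.call_for_each_replica(step_fn)
      # reduce() converts the PerReplica value into a Mirrored one on /cpu:0.
      mean_loss = strategy.reduce(
          tf.VariableAggregation.MEAN, per_replica_loss, destinations="/cpu:0")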