Add a new enum for reduction type in distribution strategy. Currently supports SUM, MEAN and ONLY_FIRST_REPLICA
Allow accepting the new enum in `reduce` and `batch_reduce` APIs. Change some callers to use the new enum. PiperOrigin-RevId: 221475510
This commit is contained in:
parent
00774b2d5e
commit
67593ee405
@ -85,6 +85,7 @@ py_library(
|
|||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/distribute:multi_worker_util",
|
"//tensorflow/python/distribute:multi_worker_util",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:tape",
|
"//tensorflow/python/eager:tape",
|
||||||
],
|
],
|
||||||
@ -105,6 +106,7 @@ py_library(
|
|||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/distribute:multi_worker_util",
|
"//tensorflow/python/distribute:multi_worker_util",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -151,6 +153,7 @@ py_library(
|
|||||||
"//tensorflow/python:distribute",
|
"//tensorflow/python:distribute",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
@ -343,6 +346,7 @@ py_library(
|
|||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -669,6 +673,7 @@ py_library(
|
|||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import cross_tower_utils
|
|||||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -128,7 +129,7 @@ class CollectiveAllReduceStrategyTestBase(
|
|||||||
with ops.control_dependencies([fetched]):
|
with ops.control_dependencies([fetched]):
|
||||||
# TODO(yuefengz): support non-Mirrored variable as destinations.
|
# TODO(yuefengz): support non-Mirrored variable as destinations.
|
||||||
g = d.reduce(
|
g = d.reduce(
|
||||||
variable_scope.VariableAggregation.SUM, g, destinations=v)
|
reduce_util.ReduceOp.SUM, g, destinations=v)
|
||||||
with ops.control_dependencies(
|
with ops.control_dependencies(
|
||||||
d.update(v, update, g, grouped=False)):
|
d.update(v, update, g, grouped=False)):
|
||||||
after_list.append(d.read_var(v))
|
after_list.append(d.read_var(v))
|
||||||
@ -224,7 +225,7 @@ class CollectiveAllReduceStrategyTestBase(
|
|||||||
x = distribution.call_for_each_replica(model_fn)
|
x = distribution.call_for_each_replica(model_fn)
|
||||||
reduced_x = distribution.unwrap(
|
reduced_x = distribution.unwrap(
|
||||||
distribution.reduce(
|
distribution.reduce(
|
||||||
variable_scope.VariableAggregation.MEAN, x,
|
reduce_util.ReduceOp.MEAN, x,
|
||||||
destinations='/cpu:0'))[0]
|
destinations='/cpu:0'))[0]
|
||||||
x = distribution.unwrap(x)[0]
|
x = distribution.unwrap(x)[0]
|
||||||
|
|
||||||
|
@ -24,12 +24,12 @@ import six
|
|||||||
from tensorflow.contrib.distribute.python import cross_tower_utils
|
from tensorflow.contrib.distribute.python import cross_tower_utils
|
||||||
from tensorflow.contrib.distribute.python import values as value_lib
|
from tensorflow.contrib.distribute.python import values as value_lib
|
||||||
from tensorflow.python.client import device_lib
|
from tensorflow.python.client import device_lib
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import device_util
|
from tensorflow.python.training import device_util
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ def _simple_broadcast(value, destinations):
|
|||||||
|
|
||||||
|
|
||||||
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
|
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
|
||||||
aggregation):
|
reduce_op):
|
||||||
# pylint: disable=g-missing-docstring
|
# pylint: disable=g-missing-docstring
|
||||||
all_values = []
|
all_values = []
|
||||||
count = 0
|
count = 0
|
||||||
@ -164,12 +164,11 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
|
|||||||
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
|
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
|
||||||
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
|
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
|
||||||
all_values, accumulation_fn)
|
all_values, accumulation_fn)
|
||||||
if aggregation == vs.VariableAggregation.MEAN:
|
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||||
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
|
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
|
||||||
reduced, count)
|
reduced, count)
|
||||||
elif aggregation != vs.VariableAggregation.SUM:
|
elif reduce_op != reduce_util.ReduceOp.SUM:
|
||||||
raise ValueError("`aggregation` must be VariableAggregation.SUM "
|
raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
|
||||||
"or VariableAggregation.MEAN.")
|
|
||||||
return reduced
|
return reduced
|
||||||
|
|
||||||
|
|
||||||
@ -179,15 +178,15 @@ class CrossDeviceOps(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def reduce(self, aggregation, per_replica_value, destinations):
|
def reduce(self, reduce_op, per_replica_value, destinations):
|
||||||
"""Reduce `per_replica_value` to `destinations`.
|
"""Reduce `per_replica_value` to `destinations`.
|
||||||
|
|
||||||
It runs the reduction operation defined by `aggregation` and put the
|
It runs the reduction operation defined by `reduce_op` and put the
|
||||||
result on `destinations`.
|
result on `destinations`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
aggregation: Indicates how a variable will be aggregated. Accepted values
|
reduce_op: Indicates how per_replica_value will be reduced. Accepted
|
||||||
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
|
values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
|
||||||
per_replica_value: a PerReplica object or a tensor with device set.
|
per_replica_value: a PerReplica object or a tensor with device set.
|
||||||
destinations: the reduction destinations.
|
destinations: the reduction destinations.
|
||||||
|
|
||||||
@ -201,17 +200,17 @@ class CrossDeviceOps(object):
|
|||||||
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
|
per_replica_value = _make_tensor_into_per_replica(per_replica_value)
|
||||||
|
|
||||||
validate_destinations(destinations)
|
validate_destinations(destinations)
|
||||||
return self._reduce(aggregation, per_replica_value, destinations)
|
return self._reduce(reduce_op, per_replica_value, destinations)
|
||||||
|
|
||||||
def batch_reduce(self, aggregation, value_destination_pairs):
|
def batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
"""Reduce PerReplica objects in a batch.
|
"""Reduce PerReplica objects in a batch.
|
||||||
|
|
||||||
Reduce each first element in `value_destination_pairs` to each second
|
Reduce each first element in `value_destination_pairs` to each second
|
||||||
element which indicates the destinations.
|
element which indicates the destinations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
aggregation: Indicates how a variable will be aggregated. Accepted values
|
reduce_op: Indicates how per_replica_value will be reduced. Accepted
|
||||||
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
|
values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
|
||||||
value_destination_pairs: a list or a tuple of tuples of PerReplica objects
|
value_destination_pairs: a list or a tuple of tuples of PerReplica objects
|
||||||
(or tensors with device set if there is one device) and destinations.
|
(or tensors with device set if there is one device) and destinations.
|
||||||
|
|
||||||
@ -231,7 +230,7 @@ class CrossDeviceOps(object):
|
|||||||
for _, d in value_destination_pairs:
|
for _, d in value_destination_pairs:
|
||||||
validate_destinations(d)
|
validate_destinations(d)
|
||||||
|
|
||||||
return self._batch_reduce(aggregation, value_destination_pairs)
|
return self._batch_reduce(reduce_op, value_destination_pairs)
|
||||||
|
|
||||||
def broadcast(self, tensor, destinations):
|
def broadcast(self, tensor, destinations):
|
||||||
"""Broadcast the `tensor` to destinations.
|
"""Broadcast the `tensor` to destinations.
|
||||||
@ -246,11 +245,11 @@ class CrossDeviceOps(object):
|
|||||||
validate_destinations(destinations)
|
validate_destinations(destinations)
|
||||||
return self._broadcast(tensor, destinations)
|
return self._broadcast(tensor, destinations)
|
||||||
|
|
||||||
def _reduce(self, aggregation, per_replica_value, destinations):
|
def _reduce(self, reduce_op, per_replica_value, destinations):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"_reduce method must be implemented in descendants.")
|
"_reduce method must be implemented in descendants.")
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"_batch_reduce method must be implemented in descendants.")
|
"_batch_reduce method must be implemented in descendants.")
|
||||||
|
|
||||||
@ -276,19 +275,19 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps):
|
|||||||
self.accumulation_fn = accumulation_fn
|
self.accumulation_fn = accumulation_fn
|
||||||
super(ReductionToOneDeviceCrossDeviceOps, self).__init__()
|
super(ReductionToOneDeviceCrossDeviceOps, self).__init__()
|
||||||
|
|
||||||
def _reduce(self, aggregation, per_replica_value, destinations):
|
def _reduce(self, reduce_op, per_replica_value, destinations):
|
||||||
if check_destinations(destinations):
|
if check_destinations(destinations):
|
||||||
devices = get_devices_from(destinations)
|
devices = get_devices_from(destinations)
|
||||||
else:
|
else:
|
||||||
devices = get_devices_from(per_replica_value)
|
devices = get_devices_from(per_replica_value)
|
||||||
reduce_to_device = self.reduce_to_device or devices[0]
|
reduce_to_device = self.reduce_to_device or devices[0]
|
||||||
reduced = _simple_reduce(per_replica_value, reduce_to_device,
|
reduced = _simple_reduce(per_replica_value, reduce_to_device,
|
||||||
self.accumulation_fn, aggregation)
|
self.accumulation_fn, reduce_op)
|
||||||
return self.broadcast(reduced, devices)
|
return self.broadcast(reduced, devices)
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
return [
|
return [
|
||||||
self._reduce(aggregation, t, destinations=v)
|
self._reduce(reduce_op, t, destinations=v)
|
||||||
for t, v in value_destination_pairs
|
for t, v in value_destination_pairs
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -323,20 +322,20 @@ def _group_value_by_device(per_replica_values):
|
|||||||
|
|
||||||
def _ungroup_and_make_mirrored(grouped_reduced,
|
def _ungroup_and_make_mirrored(grouped_reduced,
|
||||||
destinations,
|
destinations,
|
||||||
aggregation,
|
reduce_op,
|
||||||
num_between_graph_workers=1):
|
num_between_graph_workers=1):
|
||||||
"""Ungroup results from all-reduce and make Mirrored objects.
|
"""Ungroup results from all-reduce and make Mirrored objects.
|
||||||
|
|
||||||
Each all-reduce result will be divided by the number of destinations before
|
Each all-reduce result will be divided by the number of destinations before
|
||||||
Mirrored objects are created if aggregation is "mean".
|
Mirrored objects are created if reduce_op is "mean".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grouped_reduced: a list of lists, each sublist has components for each
|
grouped_reduced: a list of lists, each sublist has components for each
|
||||||
device, paired with a None. It is the result from
|
device, paired with a None. It is the result from
|
||||||
cross_tower_utils.aggregate_gradients_using*.
|
cross_tower_utils.aggregate_gradients_using*.
|
||||||
destinations: a list of device strings for returned Mirrored objects.
|
destinations: a list of device strings for returned Mirrored objects.
|
||||||
aggregation: Indicates how a variable will be aggregated. Accepted values
|
reduce_op: Indicates how values will be aggregated. Accepted values
|
||||||
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
|
are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
|
||||||
num_between_graph_workers: number of workers in the between-graph
|
num_between_graph_workers: number of workers in the between-graph
|
||||||
replication.
|
replication.
|
||||||
|
|
||||||
@ -346,7 +345,7 @@ def _ungroup_and_make_mirrored(grouped_reduced,
|
|||||||
index = [{} for _ in range(len(grouped_reduced[0]))]
|
index = [{} for _ in range(len(grouped_reduced[0]))]
|
||||||
for d, per_replica_reduced in enumerate(grouped_reduced):
|
for d, per_replica_reduced in enumerate(grouped_reduced):
|
||||||
for i, (v, _) in enumerate(per_replica_reduced):
|
for i, (v, _) in enumerate(per_replica_reduced):
|
||||||
if aggregation == vs.VariableAggregation.MEAN:
|
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||||
index[i][destinations[d]] = v / (
|
index[i][destinations[d]] = v / (
|
||||||
len(destinations) * num_between_graph_workers)
|
len(destinations) * num_between_graph_workers)
|
||||||
else:
|
else:
|
||||||
@ -557,13 +556,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
|||||||
self._agg_small_grads_max_group = agg_small_grads_max_group
|
self._agg_small_grads_max_group = agg_small_grads_max_group
|
||||||
super(AllReduceCrossDeviceOps, self).__init__()
|
super(AllReduceCrossDeviceOps, self).__init__()
|
||||||
|
|
||||||
def _reduce(self, aggregation, per_replica_value, destinations):
|
def _reduce(self, reduce_op, per_replica_value, destinations):
|
||||||
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
|
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
|
||||||
per_replica_value)
|
per_replica_value)
|
||||||
if (_devices_match(per_replica_value, destinations)
|
if (_devices_match(per_replica_value, destinations)
|
||||||
and not context.executing_eagerly()
|
and not context.executing_eagerly()
|
||||||
and not contains_indexed_slices):
|
and not contains_indexed_slices):
|
||||||
return self._batch_all_reduce(aggregation, [per_replica_value])[0]
|
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||||
else:
|
else:
|
||||||
if contains_indexed_slices:
|
if contains_indexed_slices:
|
||||||
logging.log_first_n(
|
logging.log_first_n(
|
||||||
@ -576,16 +575,16 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
|||||||
devices = get_devices_from(per_replica_value)
|
devices = get_devices_from(per_replica_value)
|
||||||
reduce_to_device = devices[0]
|
reduce_to_device = devices[0]
|
||||||
reduced = _simple_reduce(per_replica_value, reduce_to_device,
|
reduced = _simple_reduce(per_replica_value, reduce_to_device,
|
||||||
math_ops.add_n, aggregation)
|
math_ops.add_n, reduce_op)
|
||||||
return self.broadcast(reduced, devices)
|
return self.broadcast(reduced, devices)
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
all_devices_match = _all_devices_match(value_destination_pairs)
|
all_devices_match = _all_devices_match(value_destination_pairs)
|
||||||
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
|
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
|
||||||
value_destination_pairs)
|
value_destination_pairs)
|
||||||
if (all_devices_match and not context.executing_eagerly()
|
if (all_devices_match and not context.executing_eagerly()
|
||||||
and not contains_indexed_slices):
|
and not contains_indexed_slices):
|
||||||
return self._batch_all_reduce(aggregation,
|
return self._batch_all_reduce(reduce_op,
|
||||||
[v[0] for v in value_destination_pairs])
|
[v[0] for v in value_destination_pairs])
|
||||||
else:
|
else:
|
||||||
if not all_devices_match:
|
if not all_devices_match:
|
||||||
@ -595,11 +594,11 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
|||||||
10)
|
10)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self._reduce(aggregation, t, destinations=v)
|
self._reduce(reduce_op, t, destinations=v)
|
||||||
for t, v in value_destination_pairs
|
for t, v in value_destination_pairs
|
||||||
]
|
]
|
||||||
|
|
||||||
def _batch_all_reduce(self, aggregation, per_replica_values):
|
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||||
"""All reduce algorithm in a batch."""
|
"""All reduce algorithm in a batch."""
|
||||||
logging.log_first_n(
|
logging.log_first_n(
|
||||||
logging.INFO, "batch_all_reduce invoked for batches size = %d with "
|
logging.INFO, "batch_all_reduce invoked for batches size = %d with "
|
||||||
@ -630,7 +629,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
|||||||
|
|
||||||
reduced = _unpack_tensors(reduced, tensor_packer)
|
reduced = _unpack_tensors(reduced, tensor_packer)
|
||||||
return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices,
|
return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices,
|
||||||
aggregation)
|
reduce_op)
|
||||||
|
|
||||||
|
|
||||||
# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
|
# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
|
||||||
@ -713,7 +712,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
|
|||||||
validate_and_complete_spec(spec) for spec in all_reduce_spec
|
validate_and_complete_spec(spec) for spec in all_reduce_spec
|
||||||
]
|
]
|
||||||
|
|
||||||
def _batch_all_reduce(self, aggregation, per_replica_values):
|
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||||
"""All reduce algorithm in a batch."""
|
"""All reduce algorithm in a batch."""
|
||||||
logging.log_first_n(
|
logging.log_first_n(
|
||||||
logging.INFO,
|
logging.INFO,
|
||||||
@ -761,7 +760,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
|
|||||||
assert not remaining_grads
|
assert not remaining_grads
|
||||||
|
|
||||||
return _ungroup_and_make_mirrored(aggregated_grads, destinations,
|
return _ungroup_and_make_mirrored(aggregated_grads, destinations,
|
||||||
aggregation)
|
reduce_op)
|
||||||
|
|
||||||
|
|
||||||
# TODO(yuefengz): support in-graph collective all-reduce.
|
# TODO(yuefengz): support in-graph collective all-reduce.
|
||||||
@ -795,7 +794,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
super(CollectiveAllReduce, self).__init__()
|
super(CollectiveAllReduce, self).__init__()
|
||||||
|
|
||||||
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
|
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
|
||||||
def _reduce(self, aggregation, per_replica_value, destinations):
|
def _reduce(self, reduce_op, per_replica_value, destinations):
|
||||||
if cross_tower_utils.contains_indexed_slices(per_replica_value):
|
if cross_tower_utils.contains_indexed_slices(per_replica_value):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`IndexSlices` is not supported for Collective All-Reduce.")
|
"`IndexSlices` is not supported for Collective All-Reduce.")
|
||||||
@ -803,7 +802,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Eager execution is not supported for Collective All-Reduce")
|
"Eager execution is not supported for Collective All-Reduce")
|
||||||
|
|
||||||
all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0]
|
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||||
if _devices_match(per_replica_value, destinations):
|
if _devices_match(per_replica_value, destinations):
|
||||||
return all_reduced
|
return all_reduced
|
||||||
else:
|
else:
|
||||||
@ -819,7 +818,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
|
|
||||||
return value_lib.Mirrored(index)
|
return value_lib.Mirrored(index)
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
|
if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`IndexSlices` is not supported for Collective All-Reduce.")
|
"`IndexSlices` is not supported for Collective All-Reduce.")
|
||||||
@ -829,7 +828,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
|
|
||||||
all_devices_match = _all_devices_match(value_destination_pairs)
|
all_devices_match = _all_devices_match(value_destination_pairs)
|
||||||
if all_devices_match:
|
if all_devices_match:
|
||||||
return self._batch_all_reduce(aggregation,
|
return self._batch_all_reduce(reduce_op,
|
||||||
[v[0] for v in value_destination_pairs])
|
[v[0] for v in value_destination_pairs])
|
||||||
else:
|
else:
|
||||||
if not all_devices_match:
|
if not all_devices_match:
|
||||||
@ -838,11 +837,11 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
"destinations are different.", 10)
|
"destinations are different.", 10)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self._reduce(aggregation, t, destinations=v)
|
self._reduce(reduce_op, t, destinations=v)
|
||||||
for t, v in value_destination_pairs
|
for t, v in value_destination_pairs
|
||||||
]
|
]
|
||||||
|
|
||||||
def _batch_all_reduce(self, aggregation, per_replica_values):
|
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||||
"""All-reduce across all workers in a batch."""
|
"""All-reduce across all workers in a batch."""
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -883,7 +882,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
|||||||
return _ungroup_and_make_mirrored(
|
return _ungroup_and_make_mirrored(
|
||||||
new_device_grads,
|
new_device_grads,
|
||||||
per_replica_values[0].devices,
|
per_replica_values[0].devices,
|
||||||
aggregation,
|
reduce_op,
|
||||||
num_between_graph_workers=self._num_workers)
|
num_between_graph_workers=self._num_workers)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,13 +30,13 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
|
|||||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||||
from tensorflow.contrib.distribute.python import values as value_lib
|
from tensorflow.contrib.distribute.python import values as value_lib
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
|
||||||
from tensorflow.python.training import device_util
|
from tensorflow.python.training import device_util
|
||||||
|
|
||||||
|
|
||||||
@ -143,24 +143,24 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
|||||||
for destinations in all_destinations:
|
for destinations in all_destinations:
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.reduce(
|
cross_tower_ops.reduce(
|
||||||
vs.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.MEAN,
|
||||||
per_replica,
|
per_replica,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean, destinations))
|
_fake_mirrored(mean, destinations))
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.reduce(
|
cross_tower_ops.reduce(
|
||||||
vs.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.MEAN,
|
||||||
per_replica_2,
|
per_replica_2,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean_2, destinations))
|
_fake_mirrored(mean_2, destinations))
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.reduce(
|
cross_tower_ops.reduce(
|
||||||
vs.VariableAggregation.SUM, per_replica,
|
reduce_util.ReduceOp.SUM, per_replica,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean * len(devices), destinations))
|
_fake_mirrored(mean * len(devices), destinations))
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.reduce(
|
cross_tower_ops.reduce(
|
||||||
vs.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
per_replica_2,
|
per_replica_2,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean_2 * len(devices), destinations))
|
_fake_mirrored(mean_2 * len(devices), destinations))
|
||||||
@ -169,7 +169,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
|||||||
for d1, d2 in itertools.product(all_destinations, all_destinations):
|
for d1, d2 in itertools.product(all_destinations, all_destinations):
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.batch_reduce(
|
cross_tower_ops.batch_reduce(
|
||||||
vs.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.MEAN,
|
||||||
[(per_replica, d1), (per_replica_2, d2)]),
|
[(per_replica, d1), (per_replica_2, d2)]),
|
||||||
[
|
[
|
||||||
_fake_mirrored(mean, d1),
|
_fake_mirrored(mean, d1),
|
||||||
@ -177,7 +177,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
|||||||
])
|
])
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
cross_tower_ops.batch_reduce(
|
cross_tower_ops.batch_reduce(
|
||||||
vs.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
[(per_replica, d1), (per_replica_2, d2)]),
|
[(per_replica, d1), (per_replica_2, d2)]),
|
||||||
[
|
[
|
||||||
_fake_mirrored(mean * len(devices), d1),
|
_fake_mirrored(mean * len(devices), d1),
|
||||||
@ -281,7 +281,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
|||||||
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
|
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
|
||||||
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
|
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
|
||||||
result = cross_tower_ops_lib._simple_reduce(
|
result = cross_tower_ops_lib._simple_reduce(
|
||||||
per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
|
per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)
|
||||||
|
|
||||||
# Test that the result is semantically equal to both the concatenated
|
# Test that the result is semantically equal to both the concatenated
|
||||||
# IndexedSlices with and without duplicate indices.
|
# IndexedSlices with and without duplicate indices.
|
||||||
@ -302,11 +302,11 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
|||||||
"AllReduceCrossDeviceOps",
|
"AllReduceCrossDeviceOps",
|
||||||
cross_tower_ops_lib.AllReduceCrossDeviceOps())
|
cross_tower_ops_lib.AllReduceCrossDeviceOps())
|
||||||
],
|
],
|
||||||
aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN],
|
reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN],
|
||||||
batch_reduce=[True, False],
|
batch_reduce=[True, False],
|
||||||
mode=["graph", "eager"],
|
mode=["graph", "eager"],
|
||||||
required_gpus=1))
|
required_gpus=1))
|
||||||
def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation,
|
def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, reduce_op,
|
||||||
batch_reduce):
|
batch_reduce):
|
||||||
devices = ["/cpu:0", "/gpu:0"]
|
devices = ["/cpu:0", "/gpu:0"]
|
||||||
dense_shape = [5, 2]
|
dense_shape = [5, 2]
|
||||||
@ -317,19 +317,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
|||||||
|
|
||||||
if batch_reduce:
|
if batch_reduce:
|
||||||
result = cross_tower_ops_instance.batch_reduce(
|
result = cross_tower_ops_instance.batch_reduce(
|
||||||
aggregation, [(per_replica, devices)])
|
reduce_op, [(per_replica, devices)])
|
||||||
else:
|
else:
|
||||||
result = cross_tower_ops_instance.reduce(
|
result = cross_tower_ops_instance.reduce(
|
||||||
aggregation, per_replica, devices)
|
reduce_op, per_replica, devices)
|
||||||
|
|
||||||
total_indices_with_dups = [1, 1, 3]
|
total_indices_with_dups = [1, 1, 3]
|
||||||
total_indices_without_dups = [1, 3]
|
total_indices_without_dups = [1, 3]
|
||||||
|
|
||||||
if aggregation == vs.VariableAggregation.SUM:
|
if reduce_op == reduce_util.ReduceOp.SUM:
|
||||||
total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
|
total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
|
||||||
total_values_without_dups = [[4., 6.], [5., 6.]]
|
total_values_without_dups = [[4., 6.], [5., 6.]]
|
||||||
else:
|
else:
|
||||||
assert aggregation == vs.VariableAggregation.MEAN
|
assert reduce_op == reduce_util.ReduceOp.MEAN
|
||||||
total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
|
total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
|
||||||
total_values_without_dups = [[2., 3.], [2.5, 3.]]
|
total_values_without_dups = [[2., 3.], [2.5, 3.]]
|
||||||
|
|
||||||
@ -502,26 +502,26 @@ class MultiWorkerCollectiveAllReduceTest(
|
|||||||
for destinations in all_destinations:
|
for destinations in all_destinations:
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.reduce(
|
collective_all_reduce.reduce(
|
||||||
vs.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.MEAN,
|
||||||
per_replica,
|
per_replica,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean, destinations), sess)
|
_fake_mirrored(mean, destinations), sess)
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.reduce(
|
collective_all_reduce.reduce(
|
||||||
vs.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.MEAN,
|
||||||
per_replica_2,
|
per_replica_2,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean_2, destinations), sess)
|
_fake_mirrored(mean_2, destinations), sess)
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.reduce(
|
collective_all_reduce.reduce(
|
||||||
vs.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
per_replica,
|
per_replica,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean * len(devices) * num_workers, destinations),
|
_fake_mirrored(mean * len(devices) * num_workers, destinations),
|
||||||
sess)
|
sess)
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.reduce(
|
collective_all_reduce.reduce(
|
||||||
vs.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
per_replica_2,
|
per_replica_2,
|
||||||
destinations=destinations),
|
destinations=destinations),
|
||||||
_fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
|
_fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
|
||||||
@ -530,7 +530,7 @@ class MultiWorkerCollectiveAllReduceTest(
|
|||||||
# test batch_reduce()
|
# test batch_reduce()
|
||||||
for d1, d2 in itertools.product(all_destinations, all_destinations):
|
for d1, d2 in itertools.product(all_destinations, all_destinations):
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN,
|
collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN,
|
||||||
[(per_replica, d1),
|
[(per_replica, d1),
|
||||||
(per_replica_2, d2)]),
|
(per_replica_2, d2)]),
|
||||||
[
|
[
|
||||||
@ -538,7 +538,7 @@ class MultiWorkerCollectiveAllReduceTest(
|
|||||||
_fake_mirrored(mean_2, d2)
|
_fake_mirrored(mean_2, d2)
|
||||||
], sess)
|
], sess)
|
||||||
self._assert_values_equal(
|
self._assert_values_equal(
|
||||||
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
|
collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM,
|
||||||
[(per_replica, d1),
|
[(per_replica, d1),
|
||||||
(per_replica_2, d2)]),
|
(per_replica_2, d2)]),
|
||||||
[
|
[
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations
|
|||||||
from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example
|
from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example
|
||||||
from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
|
from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -484,7 +485,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(distribution.num_replicas_in_sync,
|
self.assertEqual(distribution.num_replicas_in_sync,
|
||||||
len(distribution.unwrap(loss_output)))
|
len(distribution.unwrap(loss_output)))
|
||||||
loss_output = distribution.reduce(
|
loss_output = distribution.reduce(
|
||||||
aggregation=variables_lib.VariableAggregation.MEAN,
|
aggregation=reduce_util.ReduceOp.MEAN,
|
||||||
value=loss_output, destinations="/device:CPU:0")
|
value=loss_output, destinations="/device:CPU:0")
|
||||||
|
|
||||||
unwrapped_output = distribution.unwrap(loss_output)
|
unwrapped_output = distribution.unwrap(loss_output)
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import shared_variable_creator
|
|||||||
from tensorflow.contrib.distribute.python import values
|
from tensorflow.contrib.distribute.python import values
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -178,7 +179,7 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
|
|||||||
return values.regroup({t.device: t.main_result for t in threads})
|
return values.regroup({t.device: t.main_result for t in threads})
|
||||||
|
|
||||||
|
|
||||||
def _reduce_non_distributed_value(distribution, aggregation, value,
|
def _reduce_non_distributed_value(distribution, reduce_op, value,
|
||||||
destinations):
|
destinations):
|
||||||
"""Reduce a non-DistributedValue `value` to `destinations`."""
|
"""Reduce a non-DistributedValue `value` to `destinations`."""
|
||||||
if isinstance(value, values.DistributedValues):
|
if isinstance(value, values.DistributedValues):
|
||||||
@ -190,21 +191,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
|
|||||||
# and equal to 0.
|
# and equal to 0.
|
||||||
if value == 0:
|
if value == 0:
|
||||||
return 0
|
return 0
|
||||||
# If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this
|
# If the reduce op is MEAN or ONLY_FIRST_REPLICA, then this
|
||||||
# essentially means that the same value should be on all destinations.
|
# essentially means that the same value should be on all destinations.
|
||||||
if aggregation in (
|
if reduce_op in (reduce_util.ReduceOp.MEAN,
|
||||||
variable_scope.VariableAggregation.MEAN,
|
reduce_util.ReduceOp.ONLY_FIRST_REPLICA):
|
||||||
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
cross_tower_ops_lib.validate_destinations(destinations)
|
cross_tower_ops_lib.validate_destinations(destinations)
|
||||||
# We do not support an aggregation type of SUM if the value is the same across
|
# We do not support a reduce op of SUM if the value is the same across
|
||||||
# all replicas. We call this as part of assign functions for MirroredVariables
|
# all replicas. We call this as part of assign functions for MirroredVariables
|
||||||
# and summing up identical values across replicas is not clearly defined.
|
# and summing up identical values across replicas is not clearly defined.
|
||||||
if (len(distribution.worker_devices) != 1 or
|
if (len(distribution.worker_devices) != 1 or
|
||||||
not cross_tower_ops_lib.check_destinations(destinations)):
|
not cross_tower_ops_lib.check_destinations(destinations)):
|
||||||
raise ValueError("A non-DistributedValues value %s cannot be reduced with "
|
raise ValueError("A non-DistributedValues value %s cannot be reduced with "
|
||||||
"the given aggregation %s." % (value, aggregation))
|
"the given reduce op %s." % (value, reduce_op))
|
||||||
# TODO(anjalisridhar): Moves these methods to a device utility file?
|
# TODO(anjalisridhar): Moves these methods to a device utility file?
|
||||||
devices = cross_tower_ops_lib.get_devices_from(destinations)
|
devices = cross_tower_ops_lib.get_devices_from(destinations)
|
||||||
if len(devices) == 1:
|
if len(devices) == 1:
|
||||||
@ -588,28 +588,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
|
|||||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps())
|
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps())
|
||||||
return self._cross_tower_ops
|
return self._cross_tower_ops
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
assert not isinstance(value, values.Mirrored)
|
assert not isinstance(value, values.Mirrored)
|
||||||
if not isinstance(value, values.DistributedValues):
|
if not isinstance(value, values.DistributedValues):
|
||||||
# This function handles reducing values that are not PerReplica or
|
# This function handles reducing values that are not PerReplica or
|
||||||
# Mirrored values. For example, the same value could be present on all
|
# Mirrored values. For example, the same value could be present on all
|
||||||
# replicas in which case `value` would be a single value or value could
|
# replicas in which case `value` would be a single value or value could
|
||||||
# be 0.
|
# be 0.
|
||||||
return _reduce_non_distributed_value(self, aggregation, value,
|
return _reduce_non_distributed_value(self, reduce_op, value,
|
||||||
destinations)
|
destinations)
|
||||||
if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA:
|
if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA:
|
||||||
value = value.get(self._devices[0])
|
value = value.get(self._devices[0])
|
||||||
if isinstance(value, (int, float)):
|
if isinstance(value, (int, float)):
|
||||||
return value
|
return value
|
||||||
return self.broadcast(value, destinations)
|
return self.broadcast(value, destinations)
|
||||||
return self._get_cross_tower_ops().reduce(
|
return self._get_cross_tower_ops().reduce(
|
||||||
aggregation, value, destinations=destinations)
|
reduce_op, value, destinations=destinations)
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA:
|
if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA:
|
||||||
return [self.broadcast(v.get(self._devices[0]), d)
|
return [self.broadcast(v.get(self._devices[0]), d)
|
||||||
for v, d in value_destination_pairs]
|
for v, d in value_destination_pairs]
|
||||||
return self._get_cross_tower_ops().batch_reduce(aggregation,
|
return self._get_cross_tower_ops().batch_reduce(reduce_op,
|
||||||
value_destination_pairs)
|
value_destination_pairs)
|
||||||
|
|
||||||
def _update(self, var, options, fn, *args, **kwargs):
|
def _update(self, var, options, fn, *args, **kwargs):
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import strategy_test_lib
|
|||||||
from tensorflow.contrib.distribute.python import values
|
from tensorflow.contrib.distribute.python import values
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
@ -117,7 +118,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
|
|||||||
with dist.scope():
|
with dist.scope():
|
||||||
result = dist.call_for_each_replica(_replica_id)
|
result = dist.call_for_each_replica(_replica_id)
|
||||||
reduced = dist.reduce(
|
reduced = dist.reduce(
|
||||||
variable_scope.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
result,
|
result,
|
||||||
destinations="/device:CPU:0")
|
destinations="/device:CPU:0")
|
||||||
unwrapped = dist.unwrap(reduced)
|
unwrapped = dist.unwrap(reduced)
|
||||||
@ -137,7 +138,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
|
|||||||
with dist.scope():
|
with dist.scope():
|
||||||
result = dist.call_for_each_replica(run_fn)
|
result = dist.call_for_each_replica(run_fn)
|
||||||
reduced = dist.reduce(
|
reduced = dist.reduce(
|
||||||
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
|
reduce_util.ReduceOp.ONLY_FIRST_REPLICA,
|
||||||
result,
|
result,
|
||||||
destinations="/device:CPU:0")
|
destinations="/device:CPU:0")
|
||||||
unwrapped = dist.unwrap(reduced)
|
unwrapped = dist.unwrap(reduced)
|
||||||
@ -157,7 +158,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
|
|||||||
dist = mirrored_strategy.MirroredStrategy(devices)
|
dist = mirrored_strategy.MirroredStrategy(devices)
|
||||||
with dist.scope():
|
with dist.scope():
|
||||||
reduced = dist.reduce(
|
reduced = dist.reduce(
|
||||||
variable_scope.VariableAggregation.SUM,
|
reduce_util.ReduceOp.SUM,
|
||||||
1.0,
|
1.0,
|
||||||
destinations=["/device:CPU:0", "/device:GPU:0"])
|
destinations=["/device:CPU:0", "/device:GPU:0"])
|
||||||
unwrapped = dist.unwrap(reduced)
|
unwrapped = dist.unwrap(reduced)
|
||||||
@ -912,7 +913,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
|
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
|
||||||
"with the given aggregation VariableAggregation.SUM."):
|
"with the given reduce op ReduceOp.SUM."):
|
||||||
self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn)))
|
self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn)))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes(config=config)
|
@test_util.run_in_graph_and_eager_modes(config=config)
|
||||||
|
@ -122,8 +122,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
|
|||||||
with ops.device(self._device), _OneDeviceReplicaContext(self):
|
with ops.device(self._device), _OneDeviceReplicaContext(self):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
del aggregation, destinations
|
del reduce_op, destinations
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _update(self, var, options, fn, *args, **kwargs):
|
def _update(self, var, options, fn, *args, **kwargs):
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
|
|||||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||||
from tensorflow.contrib.distribute.python import values
|
from tensorflow.contrib.distribute.python import values
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import device as tf_device
|
from tensorflow.python.framework import device as tf_device
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -307,24 +308,24 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
|
|||||||
"Cannot reduce to another worker: %r, current worker is %r" %
|
"Cannot reduce to another worker: %r, current worker is %r" %
|
||||||
(d, self._worker_device))
|
(d, self._worker_device))
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
self._verify_destinations_not_different_worker(destinations)
|
self._verify_destinations_not_different_worker(destinations)
|
||||||
if not isinstance(value, values.DistributedValues):
|
if not isinstance(value, values.DistributedValues):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return mirrored_strategy._reduce_non_distributed_value(
|
return mirrored_strategy._reduce_non_distributed_value(
|
||||||
self, aggregation, value, destinations)
|
self, reduce_op, value, destinations)
|
||||||
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA:
|
||||||
return self.broadcast(value.get(self._compute_devices[0]), destinations)
|
return self.broadcast(value.get(self._compute_devices[0]), destinations)
|
||||||
return self._cross_tower_ops.reduce(
|
return self._cross_tower_ops.reduce(
|
||||||
aggregation, value, destinations=destinations)
|
reduce_op, value, destinations=destinations)
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA:
|
||||||
return [self.broadcast(v.get(self._compute_devices[0]), d)
|
return [self.broadcast(v.get(self._compute_devices[0]), d)
|
||||||
for v, d in value_destination_pairs]
|
for v, d in value_destination_pairs]
|
||||||
for _, destinations in value_destination_pairs:
|
for _, destinations in value_destination_pairs:
|
||||||
self._verify_destinations_not_different_worker(destinations)
|
self._verify_destinations_not_different_worker(destinations)
|
||||||
return self._cross_tower_ops.batch_reduce(aggregation,
|
return self._cross_tower_ops.batch_reduce(reduce_op,
|
||||||
value_destination_pairs)
|
value_destination_pairs)
|
||||||
|
|
||||||
def _select_single_value(self, structured):
|
def _select_single_value(self, structured):
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import parameter_server_strategy
|
|||||||
from tensorflow.contrib.distribute.python import values
|
from tensorflow.contrib.distribute.python import values
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.estimator import run_config
|
from tensorflow.python.estimator import run_config
|
||||||
@ -473,7 +474,7 @@ class ParameterServerStrategyTestBase(
|
|||||||
with ops.control_dependencies([fetched]):
|
with ops.control_dependencies([fetched]):
|
||||||
# TODO(yuefengz): support non-Mirrored variable as destinations.
|
# TODO(yuefengz): support non-Mirrored variable as destinations.
|
||||||
g = d.reduce(
|
g = d.reduce(
|
||||||
variable_scope.VariableAggregation.SUM, g, destinations=v)
|
reduce_util.ReduceOp.SUM, g, destinations=v)
|
||||||
with ops.control_dependencies(
|
with ops.control_dependencies(
|
||||||
d.update(v, update, g, grouped=False)):
|
d.update(v, update, g, grouped=False)):
|
||||||
after_list.append(d.read_var(v))
|
after_list.append(d.read_var(v))
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
@ -26,7 +27,6 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.layers import core
|
from tensorflow.python.layers import core
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training import distribution_strategy_context
|
from tensorflow.python.training import distribution_strategy_context
|
||||||
from tensorflow.python.training import optimizer
|
from tensorflow.python.training import optimizer
|
||||||
@ -114,8 +114,7 @@ class DistributionTestBase(test.TestCase):
|
|||||||
before_list.append(fetched)
|
before_list.append(fetched)
|
||||||
# control_dependencies irrelevant but harmless in eager execution
|
# control_dependencies irrelevant but harmless in eager execution
|
||||||
with ops.control_dependencies([fetched]):
|
with ops.control_dependencies([fetched]):
|
||||||
g = d.reduce(
|
g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v)
|
||||||
variable_scope.VariableAggregation.SUM, g, destinations=v)
|
|
||||||
with ops.control_dependencies(d.update(
|
with ops.control_dependencies(d.update(
|
||||||
v, update, g, grouped=False)):
|
v, update, g, grouped=False)):
|
||||||
after_list.append(d.read_var(v))
|
after_list.append(d.read_var(v))
|
||||||
@ -169,8 +168,7 @@ class DistributionTestBase(test.TestCase):
|
|||||||
fetched = d.read_var(v)
|
fetched = d.read_var(v)
|
||||||
before_list.append(fetched)
|
before_list.append(fetched)
|
||||||
with ops.control_dependencies([fetched]):
|
with ops.control_dependencies([fetched]):
|
||||||
g = d.reduce(
|
g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v)
|
||||||
variable_scope.VariableAggregation.SUM, g, destinations=v)
|
|
||||||
with ops.control_dependencies(d.update(
|
with ops.control_dependencies(d.update(
|
||||||
v, update, g, grouped=False)):
|
v, update, g, grouped=False)):
|
||||||
after_list.append(d.read_var(v))
|
after_list.append(d.read_var(v))
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_
|
|||||||
from tensorflow.contrib.tpu.python.tpu import training_loop
|
from tensorflow.contrib.tpu.python.tpu import training_loop
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
from tensorflow.python.data.experimental.ops import batching
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -439,12 +440,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
|
|||||||
return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
|
return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
|
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
|
||||||
if aggregation == vs.VariableAggregation.MEAN:
|
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||||
# TODO(jhseu): Revisit once we support model-parallelism.
|
# TODO(jhseu): Revisit once we support model-parallelism.
|
||||||
value *= (1. / self.num_replicas_in_sync)
|
value *= (1. / self.num_replicas_in_sync)
|
||||||
elif aggregation != vs.VariableAggregation.SUM:
|
elif reduce_op != reduce_util.ReduceOp.SUM:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Currently only support sum & mean in TPUStrategy.")
|
"Currently only support sum & mean in TPUStrategy.")
|
||||||
return tpu_ops.cross_replica_sum(value)
|
return tpu_ops.cross_replica_sum(value)
|
||||||
@ -459,10 +460,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Multiple devices are not supported for TPUStrategy")
|
raise ValueError("Multiple devices are not supported for TPUStrategy")
|
||||||
|
|
||||||
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA:
|
||||||
return value[0]
|
return value[0]
|
||||||
output = math_ops.add_n(value)
|
output = math_ops.add_n(value)
|
||||||
if aggregation == vs.VariableAggregation.MEAN:
|
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||||
return output * (1. / len(value))
|
return output * (1. / len(value))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ import six
|
|||||||
from tensorflow.contrib.distribute.python import input_ops
|
from tensorflow.contrib.distribute.python import input_ops
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import multi_device_iterator_ops
|
from tensorflow.python.data.ops import multi_device_iterator_ops
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import device as tf_device
|
from tensorflow.python.framework import device as tf_device
|
||||||
@ -373,12 +374,13 @@ class MirroredVariable(DistributedVariable, Mirrored,
|
|||||||
if self._aggregation == vs.VariableAggregation.NONE:
|
if self._aggregation == vs.VariableAggregation.NONE:
|
||||||
raise ValueError("You must specify an aggregation method to update a "
|
raise ValueError("You must specify an aggregation method to update a "
|
||||||
"MirroredVariable in Replica Context.")
|
"MirroredVariable in Replica Context.")
|
||||||
|
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
|
||||||
|
self._aggregation)
|
||||||
|
|
||||||
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
||||||
return strategy.update(
|
return strategy.update(
|
||||||
self, f,
|
self, f,
|
||||||
strategy.reduce(
|
strategy.reduce(reduce_op, value=value, destinations=self),
|
||||||
aggregation=self._aggregation, value=value, destinations=self),
|
|
||||||
*other_args, **other_kwargs)
|
*other_args, **other_kwargs)
|
||||||
|
|
||||||
return distribution_strategy_context.get_replica_context().merge_call(
|
return distribution_strategy_context.get_replica_context().merge_call(
|
||||||
@ -614,12 +616,13 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
|
|||||||
if self._aggregation == vs.VariableAggregation.NONE:
|
if self._aggregation == vs.VariableAggregation.NONE:
|
||||||
raise ValueError("You must specify an aggregation method to update a "
|
raise ValueError("You must specify an aggregation method to update a "
|
||||||
"TPUMirroredVariable in Replica Context.")
|
"TPUMirroredVariable in Replica Context.")
|
||||||
|
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
|
||||||
|
self._aggregation)
|
||||||
|
|
||||||
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
||||||
return strategy.update(
|
return strategy.update(
|
||||||
self, f,
|
self, f,
|
||||||
strategy.reduce(
|
strategy.reduce(reduce_op, value=value, destinations=self),
|
||||||
aggregation=self._aggregation, value=value, destinations=self),
|
|
||||||
*other_args, **other_kwargs)
|
*other_args, **other_kwargs)
|
||||||
|
|
||||||
return distribution_strategy_context.get_replica_context().merge_call(
|
return distribution_strategy_context.get_replica_context().merge_call(
|
||||||
@ -1549,6 +1552,7 @@ class MultiStepContext(object):
|
|||||||
The aggregation method is also recorded in a dictionary
|
The aggregation method is also recorded in a dictionary
|
||||||
`_last_step_outputs_aggregations` for later interpreting of the
|
`_last_step_outputs_aggregations` for later interpreting of the
|
||||||
outputs as already reduced or not.
|
outputs as already reduced or not.
|
||||||
|
# TODO(priyag): Change aggregation type used here.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if distribution_strategy_context.get_cross_replica_context():
|
if distribution_strategy_context.get_cross_replica_context():
|
||||||
@ -1650,12 +1654,13 @@ class AggregatingVariable(checkpointable.CheckpointableBase):
|
|||||||
if self._aggregation == vs.VariableAggregation.NONE:
|
if self._aggregation == vs.VariableAggregation.NONE:
|
||||||
raise ValueError("You must specify an aggregation method to update a "
|
raise ValueError("You must specify an aggregation method to update a "
|
||||||
"a variable in Replica Context.")
|
"a variable in Replica Context.")
|
||||||
|
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
|
||||||
|
self._aggregation)
|
||||||
|
|
||||||
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
||||||
return strategy.update(
|
return strategy.update(
|
||||||
self, f,
|
self, f,
|
||||||
strategy.reduce(
|
strategy.reduce(reduce_op, value=value, destinations=self),
|
||||||
aggregation=self._aggregation, value=value, destinations=self),
|
|
||||||
*other_args, **other_kwargs)
|
*other_args, **other_kwargs)
|
||||||
|
|
||||||
return distribution_strategy_context.get_replica_context().merge_call(
|
return distribution_strategy_context.get_replica_context().merge_call(
|
||||||
|
@ -3590,6 +3590,7 @@ py_library(
|
|||||||
":util",
|
":util",
|
||||||
":variable_scope",
|
":variable_scope",
|
||||||
"//tensorflow/python/data",
|
"//tensorflow/python/data",
|
||||||
|
"//tensorflow/python/distribute:reduce_util",
|
||||||
"//tensorflow/python/ops/losses",
|
"//tensorflow/python/ops/losses",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -155,3 +155,11 @@ py_library(
|
|||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "reduce_util",
|
||||||
|
srcs = [
|
||||||
|
"reduce_util.py",
|
||||||
|
],
|
||||||
|
deps = [],
|
||||||
|
)
|
||||||
|
58
tensorflow/python/distribute/reduce_util.py
Normal file
58
tensorflow/python/distribute/reduce_util.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilites for reduce operations."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(priyag): Add this to tf.distribute namespace when it exists.
|
||||||
|
class ReduceOp(enum.Enum):
|
||||||
|
"""Indicates how a set of values should be reduced.
|
||||||
|
|
||||||
|
* `SUM`: Add all the values.
|
||||||
|
* `MEAN`: Take the arithmetic mean ("average") of the values.
|
||||||
|
* `ONLY_FIRST_REPLICA`: Return the value on the first replica.
|
||||||
|
|
||||||
|
TODO(priyag): Add the following types:
|
||||||
|
* `MIN`: Return the minimum of all values.
|
||||||
|
* `MAX`: Return the maximum of all values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUM = 0
|
||||||
|
MEAN = 1
|
||||||
|
ONLY_FIRST_REPLICA = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_variable_aggregation(aggregation):
|
||||||
|
mapping = {
|
||||||
|
variable_scope.VariableAggregation.SUM: ReduceOp.SUM,
|
||||||
|
variable_scope.VariableAggregation.MEAN: ReduceOp.MEAN,
|
||||||
|
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||||
|
ReduceOp.ONLY_FIRST_REPLICA
|
||||||
|
}
|
||||||
|
|
||||||
|
reduce_op = mapping.get(aggregation)
|
||||||
|
if not reduce_op:
|
||||||
|
raise ValueError("Could not convert from `tf.VariableAggregation` to"
|
||||||
|
"`tf.distribute.ReduceOp` type")
|
||||||
|
return reduce_op
|
||||||
|
|
||||||
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import threading
|
import threading
|
||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -755,9 +756,13 @@ class DistributionStrategy(object):
|
|||||||
"""Combine (via e.g. sum or mean) values across replicas.
|
"""Combine (via e.g. sum or mean) values across replicas.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
aggregation: Indicates how a variable will be aggregated. Accepted values
|
aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||||
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
|
DEPRECATED but still accepted values:
|
||||||
|
`tf.VariableAggregation.SUM`,
|
||||||
|
`tf.VariableAggregation.MEAN`,
|
||||||
`tf.VariableAggregation.ONLY_FIRST_REPLICA`.
|
`tf.VariableAggregation.ONLY_FIRST_REPLICA`.
|
||||||
|
# TODO(priyag): Rename this argument when moving the method to
|
||||||
|
# DSExtended.
|
||||||
value: A per-replica value with one value per replica.
|
value: A per-replica value with one value per replica.
|
||||||
destinations: A mirrored variable, a per-replica tensor, a device string,
|
destinations: A mirrored variable, a per-replica tensor, a device string,
|
||||||
or list of device strings. The return value will be copied to all
|
or list of device strings. The return value will be copied to all
|
||||||
@ -771,23 +776,32 @@ class DistributionStrategy(object):
|
|||||||
# TODO(josh11b): Return an unwrapped value if colocate_with is a
|
# TODO(josh11b): Return an unwrapped value if colocate_with is a
|
||||||
# single device.
|
# single device.
|
||||||
_require_cross_replica_context(self)
|
_require_cross_replica_context(self)
|
||||||
assert aggregation in [
|
|
||||||
variable_scope.VariableAggregation.SUM,
|
|
||||||
variable_scope.VariableAggregation.MEAN,
|
|
||||||
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
|
|
||||||
]
|
|
||||||
return self._reduce(aggregation, value, destinations)
|
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
# TODO(priyag): Remove this when all callers have been updated.
|
||||||
|
reduce_op = aggregation
|
||||||
|
if isinstance(aggregation, variable_scope.VariableAggregation):
|
||||||
|
assert aggregation in [
|
||||||
|
variable_scope.VariableAggregation.SUM,
|
||||||
|
variable_scope.VariableAggregation.MEAN,
|
||||||
|
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
|
||||||
|
]
|
||||||
|
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
|
||||||
|
return self._reduce(reduce_op, value, destinations)
|
||||||
|
|
||||||
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
raise NotImplementedError("must be implemented in descendants")
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
def batch_reduce(self, aggregation, value_destination_pairs):
|
def batch_reduce(self, aggregation, value_destination_pairs):
|
||||||
"""Combine multiple `reduce` calls into one for faster execution.
|
"""Combine multiple `reduce` calls into one for faster execution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
aggregation: Indicates how a variable will be aggregated. Accepted values
|
aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||||
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
|
DEPRECATED but still accepted values:
|
||||||
|
`tf.VariableAggregation.SUM`,
|
||||||
|
`tf.VariableAggregation.MEAN`,
|
||||||
`tf.VariableAggregation.ONLY_FIRST_REPLICA`.
|
`tf.VariableAggregation.ONLY_FIRST_REPLICA`.
|
||||||
|
# TODO(priyag): Rename this argument when moving the method to
|
||||||
|
# DSExtended.
|
||||||
value_destination_pairs: A sequence of (value, destinations)
|
value_destination_pairs: A sequence of (value, destinations)
|
||||||
pairs. See `reduce()` for a description.
|
pairs. See `reduce()` for a description.
|
||||||
|
|
||||||
@ -796,16 +810,21 @@ class DistributionStrategy(object):
|
|||||||
"""
|
"""
|
||||||
# TODO(josh11b): More docstring
|
# TODO(josh11b): More docstring
|
||||||
_require_cross_replica_context(self)
|
_require_cross_replica_context(self)
|
||||||
assert aggregation in [
|
|
||||||
variable_scope.VariableAggregation.SUM,
|
|
||||||
variable_scope.VariableAggregation.MEAN,
|
|
||||||
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
|
|
||||||
]
|
|
||||||
return self._batch_reduce(aggregation, value_destination_pairs)
|
|
||||||
|
|
||||||
def _batch_reduce(self, aggregation, value_destination_pairs):
|
# TODO(priyag): Remove this when all callers have been updated.
|
||||||
|
reduce_op = aggregation
|
||||||
|
if isinstance(aggregation, variable_scope.VariableAggregation):
|
||||||
|
assert aggregation in [
|
||||||
|
variable_scope.VariableAggregation.SUM,
|
||||||
|
variable_scope.VariableAggregation.MEAN,
|
||||||
|
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
|
||||||
|
]
|
||||||
|
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
|
||||||
|
return self._batch_reduce(reduce_op, value_destination_pairs)
|
||||||
|
|
||||||
|
def _batch_reduce(self, reduce_op, value_destination_pairs):
|
||||||
return [
|
return [
|
||||||
self.reduce(aggregation, t, destinations=v)
|
self.reduce(reduce_op, t, destinations=v)
|
||||||
for t, v in value_destination_pairs
|
for t, v in value_destination_pairs
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1154,9 +1173,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
|
|||||||
with ReplicaContext(self, replica_id=0):
|
with ReplicaContext(self, replica_id=0):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
def _reduce(self, aggregation, value, destinations):
|
def _reduce(self, reduce_op, value, destinations):
|
||||||
# TODO(josh11b): Use destinations?
|
# TODO(josh11b): Use destinations?
|
||||||
del aggregation, destinations
|
del reduce_op, destinations
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _update(self, var, options, fn, *args, **kwargs):
|
def _update(self, var, options, fn, *args, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user