From 67593ee405a919f1eec048bd991ddf22dcb1ac22 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Wed, 14 Nov 2018 11:07:12 -0800 Subject: [PATCH] Add a new enum for reduction type in distribution strategy. Currently supports SUM, MEAN and ONLY_FIRST_REPLICA Allow accepting the new enum in `reduce` and `batch_reduce` APIs. Change some callers to use the new enum. PiperOrigin-RevId: 221475510 --- tensorflow/contrib/distribute/python/BUILD | 5 ++ .../collective_all_reduce_strategy_test.py | 5 +- .../distribute/python/cross_tower_ops.py | 85 +++++++++---------- .../distribute/python/cross_tower_ops_test.py | 40 ++++----- .../distribute/python/minimize_loss_test.py | 3 +- .../distribute/python/mirrored_strategy.py | 28 +++--- .../python/mirrored_strategy_multigpu_test.py | 9 +- .../distribute/python/one_device_strategy.py | 4 +- .../python/parameter_server_strategy.py | 15 ++-- .../python/parameter_server_strategy_test.py | 3 +- .../distribute/python/strategy_test_lib.py | 8 +- .../contrib/distribute/python/tpu_strategy.py | 11 +-- .../contrib/distribute/python/values.py | 17 ++-- tensorflow/python/BUILD | 1 + tensorflow/python/distribute/BUILD | 8 ++ tensorflow/python/distribute/reduce_util.py | 58 +++++++++++++ tensorflow/python/training/distribute.py | 61 ++++++++----- 17 files changed, 230 insertions(+), 131 deletions(-) create mode 100644 tensorflow/python/distribute/reduce_util.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8e0866c505b..24bcb98e095 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -85,6 +85,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:tape", ], @@ -105,6 +106,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:context", ], ) @@ -151,6 +153,7 @@ py_library( "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -343,6 +346,7 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:util", + "//tensorflow/python/distribute:reduce_util", ], ) @@ -669,6 +673,7 @@ py_library( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:context", "@six_archive//:six", ], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index e3d919dd0d4..219c5f531c5 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -128,7 +129,7 @@ class CollectiveAllReduceStrategyTestBase( with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -224,7 +225,7 @@ class CollectiveAllReduceStrategyTestBase( x = distribution.call_for_each_replica(model_fn) reduced_x = distribution.unwrap( distribution.reduce( - variable_scope.VariableAggregation.MEAN, x, + reduce_util.ReduceOp.MEAN, x, destinations='/cpu:0'))[0] x = distribution.unwrap(x)[0] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index b5b349aa64e..994ed345d8a 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -24,12 +24,12 @@ import six from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.client import device_lib +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util @@ -150,7 +150,7 @@ def _simple_broadcast(value, destinations): def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, - aggregation): + reduce_op): # pylint: disable=g-missing-docstring all_values = [] count = 0 @@ -164,12 +164,11 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( all_values, accumulation_fn) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( reduced, count) - elif aggregation != vs.VariableAggregation.SUM: - raise ValueError("`aggregation` must be VariableAggregation.SUM " - "or VariableAggregation.MEAN.") + elif reduce_op != reduce_util.ReduceOp.SUM: + raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.") return reduced @@ -179,15 +178,15 @@ class CrossDeviceOps(object): def __init__(self): pass - def reduce(self, aggregation, per_replica_value, destinations): + def reduce(self, reduce_op, per_replica_value, destinations): """Reduce `per_replica_value` to `destinations`. - It runs the reduction operation defined by `aggregation` and put the + It runs the reduction operation defined by `reduce_op` and put the result on `destinations`. Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + reduce_op: Indicates how per_replica_value will be reduced. Accepted + values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. per_replica_value: a PerReplica object or a tensor with device set. destinations: the reduction destinations. 
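Example (not part of the patch): a minimal sketch of how a caller could exercise the updated `CrossDeviceOps.reduce` signature with the new enum. The device list and tensor values are assumptions for illustration; the `PerReplica` construction and the no-argument `ReductionToOneDeviceCrossDeviceOps()` constructor follow usages that appear elsewhere in this change.

from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

# Build a PerReplica value with one component per device (values assumed).
devices = ["/cpu:0", "/gpu:0"]
with ops.device(devices[0]):
  t0 = constant_op.constant([1.0, 2.0])
with ops.device(devices[1]):
  t1 = constant_op.constant([3.0, 4.0])
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})

cross_device_ops = cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()
# The first argument is now a ReduceOp enum instead of tf.VariableAggregation.
reduced = cross_device_ops.reduce(
    reduce_util.ReduceOp.MEAN, per_replica, destinations=devices)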
@@ -201,17 +200,17 @@ class CrossDeviceOps(object): per_replica_value = _make_tensor_into_per_replica(per_replica_value) validate_destinations(destinations) - return self._reduce(aggregation, per_replica_value, destinations) + return self._reduce(reduce_op, per_replica_value, destinations) - def batch_reduce(self, aggregation, value_destination_pairs): + def batch_reduce(self, reduce_op, value_destination_pairs): """Reduce PerReplica objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + reduce_op: Indicates how per_replica_value will be reduced. Accepted + values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. value_destination_pairs: a list or a tuple of tuples of PerReplica objects (or tensors with device set if there is one device) and destinations. @@ -231,7 +230,7 @@ class CrossDeviceOps(object): for _, d in value_destination_pairs: validate_destinations(d) - return self._batch_reduce(aggregation, value_destination_pairs) + return self._batch_reduce(reduce_op, value_destination_pairs) def broadcast(self, tensor, destinations): """Broadcast the `tensor` to destinations. @@ -246,11 +245,11 @@ class CrossDeviceOps(object): validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, aggregation, per_replica_value, destinations): + def _reduce(self, reduce_op, per_replica_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") - def _batch_reduce(self, aggregation, value_destination_pairs): + def _batch_reduce(self, reduce_op, value_destination_pairs): raise NotImplementedError( "_batch_reduce method must be implemented in descendants.") @@ -276,19 +275,19 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_replica_value, destinations): + def _reduce(self, reduce_op, per_replica_value, destinations): if check_destinations(destinations): devices = get_devices_from(destinations) else: devices = get_devices_from(per_replica_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_replica_value, reduce_to_device, - self.accumulation_fn, aggregation) + self.accumulation_fn, reduce_op) return self.broadcast(reduced, devices) - def _batch_reduce(self, aggregation, value_destination_pairs): + def _batch_reduce(self, reduce_op, value_destination_pairs): return [ - self._reduce(aggregation, t, destinations=v) + self._reduce(reduce_op, t, destinations=v) for t, v in value_destination_pairs ] @@ -323,20 +322,20 @@ def _group_value_by_device(per_replica_values): def _ungroup_and_make_mirrored(grouped_reduced, destinations, - aggregation, + reduce_op, num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before - Mirrored objects are created if aggregation is "mean". + Mirrored objects are created if reduce_op is "mean". Args: grouped_reduced: a list of lists, each sublist has components for each device, paired with a None. It is the result from cross_tower_utils.aggregate_gradients_using*. destinations: a list of device strings for returned Mirrored objects. 
- aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + reduce_op: Indicates how values will be aggregated. Accepted values + are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. num_between_graph_workers: number of workers in the between-graph replication. @@ -346,7 +345,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, index = [{} for _ in range(len(grouped_reduced[0]))] for d, per_replica_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_replica_reduced): - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: index[i][destinations[d]] = v / ( len(destinations) * num_between_graph_workers) else: @@ -557,13 +556,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_replica_value, destinations): + def _reduce(self, reduce_op, per_replica_value, destinations): contains_indexed_slices = cross_tower_utils.contains_indexed_slices( per_replica_value) if (_devices_match(per_replica_value, destinations) and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, [per_replica_value])[0] + return self._batch_all_reduce(reduce_op, [per_replica_value])[0] else: if contains_indexed_slices: logging.log_first_n( @@ -576,16 +575,16 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): devices = get_devices_from(per_replica_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_replica_value, reduce_to_device, - math_ops.add_n, aggregation) + math_ops.add_n, reduce_op) return self.broadcast(reduced, devices) - def _batch_reduce(self, aggregation, value_destination_pairs): + def _batch_reduce(self, reduce_op, value_destination_pairs): all_devices_match = _all_devices_match(value_destination_pairs) contains_indexed_slices = cross_tower_utils.contains_indexed_slices( value_destination_pairs) if (all_devices_match and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, + return self._batch_all_reduce(reduce_op, [v[0] for v in value_destination_pairs]) else: if not all_devices_match: @@ -595,11 +594,11 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): 10) return [ - self._reduce(aggregation, t, destinations=v) + self._reduce(reduce_op, t, destinations=v) for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_replica_values): + def _batch_all_reduce(self, reduce_op, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, "batch_all_reduce invoked for batches size = %d with " @@ -630,7 +629,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices, - aggregation) + reduce_op) # For compatibility with code using the old name of `AllReduceCrossDeviceOps`. 
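Numeric sketch (illustration only, not part of the patch) of the MEAN branch in `_ungroup_and_make_mirrored` above: each all-reduced component is already a SUM over every participating replica, so the mean divides by the number of local destinations times the number of between-graph workers. The counts and value below are assumed.

num_destinations = 2             # local devices per worker (assumed)
num_between_graph_workers = 2    # workers in between-graph replication (assumed)
all_reduced_component = 12.0     # example component after the SUM all-reduce
mean_component = all_reduced_component / (
    num_destinations * num_between_graph_workers)  # -> 3.0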
@@ -713,7 +712,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): validate_and_complete_spec(spec) for spec in all_reduce_spec ] - def _batch_all_reduce(self, aggregation, per_replica_values): + def _batch_all_reduce(self, reduce_op, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, @@ -761,7 +760,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): assert not remaining_grads return _ungroup_and_make_mirrored(aggregated_grads, destinations, - aggregation) + reduce_op) # TODO(yuefengz): support in-graph collective all-reduce. @@ -795,7 +794,7 @@ class CollectiveAllReduce(CrossDeviceOps): super(CollectiveAllReduce, self).__init__() # TODO(yuefengz, tucker): is indexed slices supported by collective ops? - def _reduce(self, aggregation, per_replica_value, destinations): + def _reduce(self, reduce_op, per_replica_value, destinations): if cross_tower_utils.contains_indexed_slices(per_replica_value): raise ValueError( "`IndexSlices` is not supported for Collective All-Reduce.") @@ -803,7 +802,7 @@ class CollectiveAllReduce(CrossDeviceOps): raise ValueError( "Eager execution is not supported for Collective All-Reduce") - all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0] + all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0] if _devices_match(per_replica_value, destinations): return all_reduced else: @@ -819,7 +818,7 @@ class CollectiveAllReduce(CrossDeviceOps): return value_lib.Mirrored(index) - def _batch_reduce(self, aggregation, value_destination_pairs): + def _batch_reduce(self, reduce_op, value_destination_pairs): if cross_tower_utils.contains_indexed_slices(value_destination_pairs): raise ValueError( "`IndexSlices` is not supported for Collective All-Reduce.") @@ -829,7 +828,7 @@ class CollectiveAllReduce(CrossDeviceOps): all_devices_match = _all_devices_match(value_destination_pairs) if all_devices_match: - return self._batch_all_reduce(aggregation, + return self._batch_all_reduce(reduce_op, [v[0] for v in value_destination_pairs]) else: if not all_devices_match: @@ -838,11 +837,11 @@ class CollectiveAllReduce(CrossDeviceOps): "destinations are different.", 10) return [ - self._reduce(aggregation, t, destinations=v) + self._reduce(reduce_op, t, destinations=v) for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_replica_values): + def _batch_all_reduce(self, reduce_op, per_replica_values): """All-reduce across all workers in a batch.""" if context.executing_eagerly(): raise ValueError( @@ -883,7 +882,7 @@ class CollectiveAllReduce(CrossDeviceOps): return _ungroup_and_make_mirrored( new_device_grads, per_replica_values[0].devices, - aggregation, + reduce_op, num_between_graph_workers=self._num_workers) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 3e274ba67ca..2e352360a43 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -30,13 +30,13 @@ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import 
constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util @@ -143,24 +143,24 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): for destinations in all_destinations: self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_replica, + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) @@ -169,7 +169,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( cross_tower_ops.batch_reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), @@ -177,7 +177,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): ]) self._assert_values_equal( cross_tower_ops.batch_reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices), d1), @@ -281,7 +281,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) result = cross_tower_ops_lib._simple_reduce( - per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. 
@@ -302,11 +302,11 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): "AllReduceCrossDeviceOps", cross_tower_ops_lib.AllReduceCrossDeviceOps()) ], - aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], batch_reduce=[True, False], mode=["graph", "eager"], required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, reduce_op, batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] @@ -317,19 +317,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): if batch_reduce: result = cross_tower_ops_instance.batch_reduce( - aggregation, [(per_replica, devices)]) + reduce_op, [(per_replica, devices)]) else: result = cross_tower_ops_instance.reduce( - aggregation, per_replica, devices) + reduce_op, per_replica, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if aggregation == vs.VariableAggregation.SUM: + if reduce_op == reduce_util.ReduceOp.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert aggregation == vs.VariableAggregation.MEAN + assert reduce_op == reduce_util.ReduceOp.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -502,26 +502,26 @@ class MultiWorkerCollectiveAllReduceTest( for destinations in all_destinations: self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), @@ -530,7 +530,7 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ @@ -538,7 +538,7 @@ class MultiWorkerCollectiveAllReduceTest( _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index c6562463edb..5d3b5d8922a 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from 
tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -484,7 +485,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual(distribution.num_replicas_in_sync, len(distribution.unwrap(loss_output))) loss_output = distribution.reduce( - aggregation=variables_lib.VariableAggregation.MEAN, + aggregation=reduce_util.ReduceOp.MEAN, value=loss_output, destinations="/device:CPU:0") unwrapped_output = distribution.unwrap(loss_output) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 62619a57563..bf065713a06 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import shared_variable_creator from tensorflow.contrib.distribute.python import values from tensorflow.python import pywrap_tensorflow from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op @@ -178,7 +179,7 @@ def _call_for_each_replica(distribution, fn, args, kwargs): return values.regroup({t.device: t.main_result for t in threads}) -def _reduce_non_distributed_value(distribution, aggregation, value, +def _reduce_non_distributed_value(distribution, reduce_op, value, destinations): """Reduce a non-DistributedValue `value` to `destinations`.""" if isinstance(value, values.DistributedValues): @@ -190,21 +191,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value, # and equal to 0. if value == 0: return 0 - # If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this + # If the reduce op is MEAN or ONLY_FIRST_REPLICA, then this # essentially means that the same value should be on all destinations. - if aggregation in ( - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA): + if reduce_op in (reduce_util.ReduceOp.MEAN, + reduce_util.ReduceOp.ONLY_FIRST_REPLICA): return value cross_tower_ops_lib.validate_destinations(destinations) - # We do not support an aggregation type of SUM if the value is the same across + # We do not support a reduce op of SUM if the value is the same across # all replicas. We call this as part of assign functions for MirroredVariables # and summing up identical values across replicas is not clearly defined. if (len(distribution.worker_devices) != 1 or not cross_tower_ops_lib.check_destinations(destinations)): raise ValueError("A non-DistributedValues value %s cannot be reduced with " - "the given aggregation %s." % (value, aggregation)) + "the given reduce op %s." % (value, reduce_op)) # TODO(anjalisridhar): Moves these methods to a device utility file? 
devices = cross_tower_ops_lib.get_devices_from(destinations) if len(devices) == 1: @@ -588,28 +588,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()) return self._cross_tower_ops - def _reduce(self, aggregation, value, destinations): + def _reduce(self, reduce_op, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.DistributedValues): # This function handles reducing values that are not PerReplica or # Mirrored values. For example, the same value could be present on all # replicas in which case `value` would be a single value or value could # be 0. - return _reduce_non_distributed_value(self, aggregation, value, + return _reduce_non_distributed_value(self, reduce_op, value, destinations) - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: + if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: value = value.get(self._devices[0]) if isinstance(value, (int, float)): return value return self.broadcast(value, destinations) return self._get_cross_tower_ops().reduce( - aggregation, value, destinations=destinations) + reduce_op, value, destinations=destinations) - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: + def _batch_reduce(self, reduce_op, value_destination_pairs): + if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: return [self.broadcast(v.get(self._devices[0]), d) for v, d in value_destination_pairs] - return self._get_cross_tower_ops().batch_reduce(aggregation, + return self._get_cross_tower_ops().batch_reduce(reduce_op, value_destination_pairs) def _update(self, var, options, fn, *args, **kwargs): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 73614a000f8..23379a72d97 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function @@ -117,7 +118,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): with dist.scope(): result = dist.call_for_each_replica(_replica_id) reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, result, destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) @@ -137,7 +138,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): with dist.scope(): result = dist.call_for_each_replica(run_fn) reduced = dist.reduce( - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, + reduce_util.ReduceOp.ONLY_FIRST_REPLICA, result, destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) @@ -157,7 +158,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, 1.0, destinations=["/device:CPU:0", 
"/device:GPU:0"]) unwrapped = dist.unwrap(reduced) @@ -912,7 +913,7 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " - "with the given aggregation VariableAggregation.SUM."): + "with the given reduce op ReduceOp.SUM."): self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index d623798d0cc..8ce8e115899 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -122,8 +122,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), _OneDeviceReplicaContext(self): return fn(*args, **kwargs) - def _reduce(self, aggregation, value, destinations): - del aggregation, destinations + def _reduce(self, reduce_op, value, destinations): + del reduce_op, destinations return value def _update(self, var, options, fn, *args, **kwargs): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 438b91bc8d3..49733dd7a68 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -307,24 +308,24 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) - def _reduce(self, aggregation, value, destinations): + def _reduce(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return mirrored_strategy._reduce_non_distributed_value( - self, aggregation, value, destinations) - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + self, reduce_op, value, destinations) + if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: return self.broadcast(value.get(self._compute_devices[0]), destinations) return self._cross_tower_ops.reduce( - aggregation, value, destinations=destinations) + reduce_op, value, destinations=destinations) - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + def _batch_reduce(self, reduce_op, value_destination_pairs): + if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: return [self.broadcast(v.get(self._compute_devices[0]), d) for v, d in value_destination_pairs] for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) - return self._cross_tower_ops.batch_reduce(aggregation, + return self._cross_tower_ops.batch_reduce(reduce_op, value_destination_pairs) def _select_single_value(self, structured): diff --git 
a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 81a23c89030..ec61f62dc7e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -473,7 +474,7 @@ class ParameterServerStrategyTestBase( with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 9fee75a476a..31d74902903 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -26,7 +27,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -114,8 +114,7 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -169,8 +168,7 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 5ef59bf74d8..ae5c55a3463 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -31,6 +31,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.data.experimental.ops import batching from 
tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op @@ -439,12 +440,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. value *= (1. / self.num_replicas_in_sync) - elif aggregation != vs.VariableAggregation.SUM: + elif reduce_op != reduce_util.ReduceOp.SUM: raise NotImplementedError( "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) @@ -459,10 +460,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy): else: raise ValueError("Multiple devices are not supported for TPUStrategy") - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: return value[0] output = math_ops.add_n(value) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: return output * (1. / len(value)) return output diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index ffb7b79839e..e2931fcacdc 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -30,6 +30,7 @@ import six from tensorflow.contrib.distribute.python import input_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device @@ -373,12 +374,13 @@ class MirroredVariable(DistributedVariable, Mirrored, if self._aggregation == vs.VariableAggregation.NONE: raise ValueError("You must specify an aggregation method to update a " "MirroredVariable in Replica Context.") + reduce_op = reduce_util.ReduceOp.from_variable_aggregation( + self._aggregation) def merge_fn(strategy, value, *other_args, **other_kwargs): return strategy.update( self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), + strategy.reduce(reduce_op, value=value, destinations=self), *other_args, **other_kwargs) return distribution_strategy_context.get_replica_context().merge_call( @@ -614,12 +616,13 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): if self._aggregation == vs.VariableAggregation.NONE: raise ValueError("You must specify an aggregation method to update a " "TPUMirroredVariable in Replica Context.") + reduce_op = reduce_util.ReduceOp.from_variable_aggregation( + self._aggregation) def merge_fn(strategy, value, *other_args, **other_kwargs): return strategy.update( self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), + strategy.reduce(reduce_op, value=value, destinations=self), *other_args, **other_kwargs) return distribution_strategy_context.get_replica_context().merge_call( @@ -1549,6 +1552,7 @@ class MultiStepContext(object): The aggregation method is also 
recorded in a dictionary `_last_step_outputs_aggregations` for later interpreting of the outputs as already reduced or not. + # TODO(priyag): Change aggregation type used here. """ if distribution_strategy_context.get_cross_replica_context(): @@ -1650,12 +1654,13 @@ class AggregatingVariable(checkpointable.CheckpointableBase): if self._aggregation == vs.VariableAggregation.NONE: raise ValueError("You must specify an aggregation method to update a " "a variable in Replica Context.") + reduce_op = reduce_util.ReduceOp.from_variable_aggregation( + self._aggregation) def merge_fn(strategy, value, *other_args, **other_kwargs): return strategy.update( self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), + strategy.reduce(reduce_op, value=value, destinations=self), *other_args, **other_kwargs) return distribution_strategy_context.get_replica_context().merge_call( diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9bd6ddd0838..a6e5c110b7f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3590,6 +3590,7 @@ py_library( ":util", ":variable_scope", "//tensorflow/python/data", + "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/ops/losses", ], ) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 216dcc8587d..13d797d9872 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -155,3 +155,11 @@ py_library( "//tensorflow/python:training", ], ) + +py_library( + name = "reduce_util", + srcs = [ + "reduce_util.py", + ], + deps = [], +) diff --git a/tensorflow/python/distribute/reduce_util.py b/tensorflow/python/distribute/reduce_util.py new file mode 100644 index 00000000000..8df3923dafa --- /dev/null +++ b/tensorflow/python/distribute/reduce_util.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilites for reduce operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum + +from tensorflow.python.ops import variable_scope + + +# TODO(priyag): Add this to tf.distribute namespace when it exists. +class ReduceOp(enum.Enum): + """Indicates how a set of values should be reduced. + + * `SUM`: Add all the values. + * `MEAN`: Take the arithmetic mean ("average") of the values. + * `ONLY_FIRST_REPLICA`: Return the value on the first replica. + + TODO(priyag): Add the following types: + * `MIN`: Return the minimum of all values. + * `MAX`: Return the maximum of all values. 
+ """ + + SUM = 0 + MEAN = 1 + ONLY_FIRST_REPLICA = 2 + + @staticmethod + def from_variable_aggregation(aggregation): + mapping = { + variable_scope.VariableAggregation.SUM: ReduceOp.SUM, + variable_scope.VariableAggregation.MEAN: ReduceOp.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: + ReduceOp.ONLY_FIRST_REPLICA + } + + reduce_op = mapping.get(aggregation) + if not reduce_op: + raise ValueError("Could not convert from `tf.VariableAggregation` to" + "`tf.distribute.ReduceOp` type") + return reduce_op + + diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index f1d2db0647d..4ed2fb19254 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -21,6 +21,7 @@ from __future__ import print_function import threading from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -755,9 +756,13 @@ class DistributionStrategy(object): """Combine (via e.g. sum or mean) values across replicas. Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum. + DEPRECATED but still accepted values: + `tf.VariableAggregation.SUM`, + `tf.VariableAggregation.MEAN`, `tf.VariableAggregation.ONLY_FIRST_REPLICA`. + # TODO(priyag): Rename this argument when moving the method to + # DSExtended. value: A per-replica value with one value per replica. destinations: A mirrored variable, a per-replica tensor, a device string, or list of device strings. The return value will be copied to all @@ -771,23 +776,32 @@ class DistributionStrategy(object): # TODO(josh11b): Return an unwrapped value if colocate_with is a # single device. _require_cross_replica_context(self) - assert aggregation in [ - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - ] - return self._reduce(aggregation, value, destinations) - def _reduce(self, aggregation, value, destinations): + # TODO(priyag): Remove this when all callers have been updated. + reduce_op = aggregation + if isinstance(aggregation, variable_scope.VariableAggregation): + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_REPLICA + ] + reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) + return self._reduce(reduce_op, value, destinations) + + def _reduce(self, reduce_op, value, destinations): raise NotImplementedError("must be implemented in descendants") def batch_reduce(self, aggregation, value_destination_pairs): """Combine multiple `reduce` calls into one for faster execution. Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum. + DEPRECATED but still accepted values: + `tf.VariableAggregation.SUM`, + `tf.VariableAggregation.MEAN`, `tf.VariableAggregation.ONLY_FIRST_REPLICA`. + # TODO(priyag): Rename this argument when moving the method to + # DSExtended. value_destination_pairs: A sequence of (value, destinations) pairs. 
See `reduce()` for a description. @@ -796,16 +810,21 @@ class DistributionStrategy(object): """ # TODO(josh11b): More docstring _require_cross_replica_context(self) - assert aggregation in [ - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - ] - return self._batch_reduce(aggregation, value_destination_pairs) - def _batch_reduce(self, aggregation, value_destination_pairs): + # TODO(priyag): Remove this when all callers have been updated. + reduce_op = aggregation + if isinstance(aggregation, variable_scope.VariableAggregation): + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_REPLICA + ] + reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) + return self._batch_reduce(reduce_op, value_destination_pairs) + + def _batch_reduce(self, reduce_op, value_destination_pairs): return [ - self.reduce(aggregation, t, destinations=v) + self.reduce(reduce_op, t, destinations=v) for t, v in value_destination_pairs ] @@ -1154,9 +1173,9 @@ class _DefaultDistributionStrategy(DistributionStrategy): with ReplicaContext(self, replica_id=0): return fn(*args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce(self, reduce_op, value, destinations): # TODO(josh11b): Use destinations? - del aggregation, destinations + del reduce_op, destinations return value def _update(self, var, options, fn, *args, **kwargs):
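Example (not part of the patch): a minimal sketch of the backward-compatibility path added to `DistributionStrategy.reduce` and `batch_reduce`. New callers pass the `tf.distribute.ReduceOp` enum directly; old callers that still pass `tf.VariableAggregation` values are converted through the `from_variable_aggregation` helper introduced in reduce_util.py. The `strategy` object mentioned in the comment is assumed.

from tensorflow.python.distribute import reduce_util
from tensorflow.python.ops import variable_scope

# New-style call, e.g. inside a strategy scope (strategy is assumed):
#   strategy.reduce(reduce_util.ReduceOp.SUM, value, destinations="/cpu:0")

# Deprecated VariableAggregation values map onto the new enum one-to-one.
assert (reduce_util.ReduceOp.from_variable_aggregation(
    variable_scope.VariableAggregation.SUM) == reduce_util.ReduceOp.SUM)
assert (reduce_util.ReduceOp.from_variable_aggregation(
    variable_scope.VariableAggregation.MEAN) == reduce_util.ReduceOp.MEAN)
assert (reduce_util.ReduceOp.from_variable_aggregation(
    variable_scope.VariableAggregation.ONLY_FIRST_REPLICA) ==
        reduce_util.ReduceOp.ONLY_FIRST_REPLICA)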