Add a new enum, `ReduceOp`, for the reduction type in distribution strategy. It currently supports SUM, MEAN and ONLY_FIRST_REPLICA.
Allow the `reduce` and `batch_reduce` APIs to accept the new enum, and change some callers to use it. PiperOrigin-RevId: 221475510
This commit is contained in:
		
							parent
							
								
									00774b2d5e
								
							
						
					
					
						commit
						67593ee405
					
				| @ -85,6 +85,7 @@ py_library( | ||||
|         "//tensorflow/python:variable_scope", | ||||
|         "//tensorflow/python:variables", | ||||
|         "//tensorflow/python/distribute:multi_worker_util", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "//tensorflow/python/eager:tape", | ||||
|     ], | ||||
| @ -105,6 +106,7 @@ py_library( | ||||
|         "//tensorflow/python:training", | ||||
|         "//tensorflow/python:util", | ||||
|         "//tensorflow/python/distribute:multi_worker_util", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|         "//tensorflow/python/eager:context", | ||||
|     ], | ||||
| ) | ||||
| @ -151,6 +153,7 @@ py_library( | ||||
|         "//tensorflow/python:distribute", | ||||
|         "//tensorflow/python:framework_ops", | ||||
|         "//tensorflow/python:math_ops", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "@six_archive//:six", | ||||
|     ], | ||||
| @ -343,6 +346,7 @@ py_library( | ||||
|         "//tensorflow/python:control_flow_ops", | ||||
|         "//tensorflow/python:framework_ops", | ||||
|         "//tensorflow/python:util", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| @ -669,6 +673,7 @@ py_library( | ||||
|         "//tensorflow/python:resource_variable_ops", | ||||
|         "//tensorflow/python:training", | ||||
|         "//tensorflow/python:variable_scope", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "@six_archive//:six", | ||||
|     ], | ||||
|  | ||||
| @ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import cross_tower_utils | ||||
| from tensorflow.contrib.distribute.python import multi_worker_test_base | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python import keras | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import dtypes | ||||
| @ -128,7 +129,7 @@ class CollectiveAllReduceStrategyTestBase( | ||||
|           with ops.control_dependencies([fetched]): | ||||
|             # TODO(yuefengz): support non-Mirrored variable as destinations. | ||||
|             g = d.reduce( | ||||
|                 variable_scope.VariableAggregation.SUM, g, destinations=v) | ||||
|                 reduce_util.ReduceOp.SUM, g, destinations=v) | ||||
|             with ops.control_dependencies( | ||||
|                 d.update(v, update, g, grouped=False)): | ||||
|               after_list.append(d.read_var(v)) | ||||
| @ -224,7 +225,7 @@ class CollectiveAllReduceStrategyTestBase( | ||||
|       x = distribution.call_for_each_replica(model_fn) | ||||
|       reduced_x = distribution.unwrap( | ||||
|           distribution.reduce( | ||||
|               variable_scope.VariableAggregation.MEAN, x, | ||||
|               reduce_util.ReduceOp.MEAN, x, | ||||
|               destinations='/cpu:0'))[0] | ||||
|       x = distribution.unwrap(x)[0] | ||||
| 
 | ||||
|  | ||||
| @ -24,12 +24,12 @@ import six | ||||
| from tensorflow.contrib.distribute.python import cross_tower_utils | ||||
| from tensorflow.contrib.distribute.python import values as value_lib | ||||
| from tensorflow.python.client import device_lib | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| from tensorflow.python.ops import resource_variable_ops | ||||
| from tensorflow.python.ops import variable_scope as vs | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training import device_util | ||||
| 
 | ||||
| @ -150,7 +150,7 @@ def _simple_broadcast(value, destinations): | ||||
| 
 | ||||
| 
 | ||||
| def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, | ||||
|                    aggregation): | ||||
|                    reduce_op): | ||||
|   # pylint: disable=g-missing-docstring | ||||
|   all_values = [] | ||||
|   count = 0 | ||||
| @ -164,12 +164,11 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, | ||||
|     with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): | ||||
|       reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( | ||||
|           all_values, accumulation_fn) | ||||
|       if aggregation == vs.VariableAggregation.MEAN: | ||||
|       if reduce_op == reduce_util.ReduceOp.MEAN: | ||||
|         reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( | ||||
|             reduced, count) | ||||
|       elif aggregation != vs.VariableAggregation.SUM: | ||||
|         raise ValueError("`aggregation` must be VariableAggregation.SUM " | ||||
|                          "or VariableAggregation.MEAN.") | ||||
|       elif reduce_op != reduce_util.ReduceOp.SUM: | ||||
|         raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.") | ||||
|   return reduced | ||||
| 
 | ||||
| 
 | ||||
| @ -179,15 +178,15 @@ class CrossDeviceOps(object): | ||||
|   def __init__(self): | ||||
|     pass | ||||
| 
 | ||||
|   def reduce(self, aggregation, per_replica_value, destinations): | ||||
|   def reduce(self, reduce_op, per_replica_value, destinations): | ||||
|     """Reduce `per_replica_value` to `destinations`. | ||||
| 
 | ||||
|     It runs the reduction operation defined by `aggregation` and put the | ||||
|     It runs the reduction operation defined by `reduce_op` and put the | ||||
|     result on `destinations`. | ||||
| 
 | ||||
|     Args: | ||||
|       aggregation: Indicates how a variable will be aggregated. Accepted values | ||||
|         are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. | ||||
|       reduce_op: Indicates how per_replica_value will be reduced. Accepted | ||||
|         values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. | ||||
|       per_replica_value: a PerReplica object or a tensor with device set. | ||||
|       destinations: the reduction destinations. | ||||
| 
 | ||||
| @ -201,17 +200,17 @@ class CrossDeviceOps(object): | ||||
|       per_replica_value = _make_tensor_into_per_replica(per_replica_value) | ||||
| 
 | ||||
|     validate_destinations(destinations) | ||||
|     return self._reduce(aggregation, per_replica_value, destinations) | ||||
|     return self._reduce(reduce_op, per_replica_value, destinations) | ||||
| 
 | ||||
|   def batch_reduce(self, aggregation, value_destination_pairs): | ||||
|   def batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     """Reduce PerReplica objects in a batch. | ||||
| 
 | ||||
|     Reduce each first element in `value_destination_pairs` to each second | ||||
|     element which indicates the destinations. | ||||
| 
 | ||||
|     Args: | ||||
|       aggregation: Indicates how a variable will be aggregated. Accepted values | ||||
|         are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. | ||||
|       reduce_op: Indicates how per_replica_value will be reduced. Accepted | ||||
|         values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. | ||||
|       value_destination_pairs: a list or a tuple of tuples of PerReplica objects | ||||
|         (or tensors with device set if there is one device) and destinations. | ||||
| 
 | ||||
| @ -231,7 +230,7 @@ class CrossDeviceOps(object): | ||||
|     for _, d in value_destination_pairs: | ||||
|       validate_destinations(d) | ||||
| 
 | ||||
|     return self._batch_reduce(aggregation, value_destination_pairs) | ||||
|     return self._batch_reduce(reduce_op, value_destination_pairs) | ||||
| 
 | ||||
|   def broadcast(self, tensor, destinations): | ||||
|     """Broadcast the `tensor` to destinations. | ||||
| @ -246,11 +245,11 @@ class CrossDeviceOps(object): | ||||
|     validate_destinations(destinations) | ||||
|     return self._broadcast(tensor, destinations) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, per_replica_value, destinations): | ||||
|   def _reduce(self, reduce_op, per_replica_value, destinations): | ||||
|     raise NotImplementedError( | ||||
|         "_reduce method must be implemented in descendants.") | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     raise NotImplementedError( | ||||
|         "_batch_reduce method must be implemented in descendants.") | ||||
| 
 | ||||
| @ -276,19 +275,19 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): | ||||
|     self.accumulation_fn = accumulation_fn | ||||
|     super(ReductionToOneDeviceCrossDeviceOps, self).__init__() | ||||
| 
 | ||||
|   def _reduce(self, aggregation, per_replica_value, destinations): | ||||
|   def _reduce(self, reduce_op, per_replica_value, destinations): | ||||
|     if check_destinations(destinations): | ||||
|       devices = get_devices_from(destinations) | ||||
|     else: | ||||
|       devices = get_devices_from(per_replica_value) | ||||
|     reduce_to_device = self.reduce_to_device or devices[0] | ||||
|     reduced = _simple_reduce(per_replica_value, reduce_to_device, | ||||
|                              self.accumulation_fn, aggregation) | ||||
|                              self.accumulation_fn, reduce_op) | ||||
|     return self.broadcast(reduced, devices) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     return [ | ||||
|         self._reduce(aggregation, t, destinations=v) | ||||
|         self._reduce(reduce_op, t, destinations=v) | ||||
|         for t, v in value_destination_pairs | ||||
|     ] | ||||
| 
 | ||||
| @ -323,20 +322,20 @@ def _group_value_by_device(per_replica_values): | ||||
| 
 | ||||
| def _ungroup_and_make_mirrored(grouped_reduced, | ||||
|                                destinations, | ||||
|                                aggregation, | ||||
|                                reduce_op, | ||||
|                                num_between_graph_workers=1): | ||||
|   """Ungroup results from all-reduce and make Mirrored objects. | ||||
| 
 | ||||
|   Each all-reduce result will be divided by the number of destinations before | ||||
|   Mirrored objects are created if aggregation is "mean". | ||||
|   Mirrored objects are created if reduce_op is "mean". | ||||
| 
 | ||||
|   Args: | ||||
|     grouped_reduced: a list of lists, each sublist has components for each | ||||
|       device, paired with a None. It is the result from | ||||
|       cross_tower_utils.aggregate_gradients_using*. | ||||
|     destinations: a list of device strings for returned Mirrored objects. | ||||
|     aggregation: Indicates how a variable will be aggregated. Accepted values | ||||
|       are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. | ||||
|     reduce_op: Indicates how values will be aggregated. Accepted values | ||||
|       are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. | ||||
|     num_between_graph_workers: number of workers in the between-graph | ||||
|       replication. | ||||
| 
 | ||||
| @ -346,7 +345,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, | ||||
|   index = [{} for _ in range(len(grouped_reduced[0]))] | ||||
|   for d, per_replica_reduced in enumerate(grouped_reduced): | ||||
|     for i, (v, _) in enumerate(per_replica_reduced): | ||||
|       if aggregation == vs.VariableAggregation.MEAN: | ||||
|       if reduce_op == reduce_util.ReduceOp.MEAN: | ||||
|         index[i][destinations[d]] = v / ( | ||||
|             len(destinations) * num_between_graph_workers) | ||||
|       else: | ||||
| @ -557,13 +556,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): | ||||
|     self._agg_small_grads_max_group = agg_small_grads_max_group | ||||
|     super(AllReduceCrossDeviceOps, self).__init__() | ||||
| 
 | ||||
|   def _reduce(self, aggregation, per_replica_value, destinations): | ||||
|   def _reduce(self, reduce_op, per_replica_value, destinations): | ||||
|     contains_indexed_slices = cross_tower_utils.contains_indexed_slices( | ||||
|         per_replica_value) | ||||
|     if (_devices_match(per_replica_value, destinations) | ||||
|         and not context.executing_eagerly() | ||||
|         and not contains_indexed_slices): | ||||
|       return self._batch_all_reduce(aggregation, [per_replica_value])[0] | ||||
|       return self._batch_all_reduce(reduce_op, [per_replica_value])[0] | ||||
|     else: | ||||
|       if contains_indexed_slices: | ||||
|         logging.log_first_n( | ||||
| @ -576,16 +575,16 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): | ||||
|         devices = get_devices_from(per_replica_value) | ||||
|       reduce_to_device = devices[0] | ||||
|       reduced = _simple_reduce(per_replica_value, reduce_to_device, | ||||
|                                math_ops.add_n, aggregation) | ||||
|                                math_ops.add_n, reduce_op) | ||||
|       return self.broadcast(reduced, devices) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     all_devices_match = _all_devices_match(value_destination_pairs) | ||||
|     contains_indexed_slices = cross_tower_utils.contains_indexed_slices( | ||||
|         value_destination_pairs) | ||||
|     if (all_devices_match and not context.executing_eagerly() | ||||
|         and not contains_indexed_slices): | ||||
|       return self._batch_all_reduce(aggregation, | ||||
|       return self._batch_all_reduce(reduce_op, | ||||
|                                     [v[0] for v in value_destination_pairs]) | ||||
|     else: | ||||
|       if not all_devices_match: | ||||
| @ -595,11 +594,11 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): | ||||
|                             10) | ||||
| 
 | ||||
|       return [ | ||||
|           self._reduce(aggregation, t, destinations=v) | ||||
|           self._reduce(reduce_op, t, destinations=v) | ||||
|           for t, v in value_destination_pairs | ||||
|       ] | ||||
| 
 | ||||
|   def _batch_all_reduce(self, aggregation, per_replica_values): | ||||
|   def _batch_all_reduce(self, reduce_op, per_replica_values): | ||||
|     """All reduce algorithm in a batch.""" | ||||
|     logging.log_first_n( | ||||
|         logging.INFO, "batch_all_reduce invoked for batches size = %d with " | ||||
| @ -630,7 +629,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): | ||||
| 
 | ||||
|     reduced = _unpack_tensors(reduced, tensor_packer) | ||||
|     return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices, | ||||
|                                       aggregation) | ||||
|                                       reduce_op) | ||||
| 
 | ||||
| 
 | ||||
| # For compatibility with code using the old name of `AllReduceCrossDeviceOps`. | ||||
| @ -713,7 +712,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): | ||||
|           validate_and_complete_spec(spec) for spec in all_reduce_spec | ||||
|       ] | ||||
| 
 | ||||
|   def _batch_all_reduce(self, aggregation, per_replica_values): | ||||
|   def _batch_all_reduce(self, reduce_op, per_replica_values): | ||||
|     """All reduce algorithm in a batch.""" | ||||
|     logging.log_first_n( | ||||
|         logging.INFO, | ||||
| @ -761,7 +760,7 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): | ||||
|     assert not remaining_grads | ||||
| 
 | ||||
|     return _ungroup_and_make_mirrored(aggregated_grads, destinations, | ||||
|                                       aggregation) | ||||
|                                       reduce_op) | ||||
| 
 | ||||
| 
 | ||||
| # TODO(yuefengz): support in-graph collective all-reduce. | ||||
| @ -795,7 +794,7 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
|     super(CollectiveAllReduce, self).__init__() | ||||
| 
 | ||||
|   # TODO(yuefengz, tucker): is indexed slices supported by collective ops? | ||||
|   def _reduce(self, aggregation, per_replica_value, destinations): | ||||
|   def _reduce(self, reduce_op, per_replica_value, destinations): | ||||
|     if cross_tower_utils.contains_indexed_slices(per_replica_value): | ||||
|       raise ValueError( | ||||
|           "`IndexSlices` is not supported for Collective All-Reduce.") | ||||
| @ -803,7 +802,7 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
|       raise ValueError( | ||||
|           "Eager execution is not supported for Collective All-Reduce") | ||||
| 
 | ||||
|     all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0] | ||||
|     all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0] | ||||
|     if _devices_match(per_replica_value, destinations): | ||||
|       return all_reduced | ||||
|     else: | ||||
| @ -819,7 +818,7 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
| 
 | ||||
|       return value_lib.Mirrored(index) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     if cross_tower_utils.contains_indexed_slices(value_destination_pairs): | ||||
|       raise ValueError( | ||||
|           "`IndexSlices` is not supported for Collective All-Reduce.") | ||||
| @ -829,7 +828,7 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
| 
 | ||||
|     all_devices_match = _all_devices_match(value_destination_pairs) | ||||
|     if all_devices_match: | ||||
|       return self._batch_all_reduce(aggregation, | ||||
|       return self._batch_all_reduce(reduce_op, | ||||
|                                     [v[0] for v in value_destination_pairs]) | ||||
|     else: | ||||
|       if not all_devices_match: | ||||
| @ -838,11 +837,11 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
|             "destinations are different.", 10) | ||||
| 
 | ||||
|       return [ | ||||
|           self._reduce(aggregation, t, destinations=v) | ||||
|           self._reduce(reduce_op, t, destinations=v) | ||||
|           for t, v in value_destination_pairs | ||||
|       ] | ||||
| 
 | ||||
|   def _batch_all_reduce(self, aggregation, per_replica_values): | ||||
|   def _batch_all_reduce(self, reduce_op, per_replica_values): | ||||
|     """All-reduce across all workers in a batch.""" | ||||
|     if context.executing_eagerly(): | ||||
|       raise ValueError( | ||||
| @ -883,7 +882,7 @@ class CollectiveAllReduce(CrossDeviceOps): | ||||
|     return _ungroup_and_make_mirrored( | ||||
|         new_device_grads, | ||||
|         per_replica_values[0].devices, | ||||
|         aggregation, | ||||
|         reduce_op, | ||||
|         num_between_graph_workers=self._num_workers) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -30,13 +30,13 @@ from tensorflow.contrib.distribute.python import mirrored_strategy | ||||
| from tensorflow.contrib.distribute.python import multi_worker_test_base | ||||
| from tensorflow.contrib.distribute.python import values as value_lib | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import test | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| from tensorflow.python.ops import variable_scope as vs | ||||
| from tensorflow.python.training import device_util | ||||
| 
 | ||||
| 
 | ||||
| @ -143,24 +143,24 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): | ||||
|     for destinations in all_destinations: | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.reduce( | ||||
|               vs.VariableAggregation.MEAN, | ||||
|               reduce_util.ReduceOp.MEAN, | ||||
|               per_replica, | ||||
|               destinations=destinations), | ||||
|           _fake_mirrored(mean, destinations)) | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.reduce( | ||||
|               vs.VariableAggregation.MEAN, | ||||
|               reduce_util.ReduceOp.MEAN, | ||||
|               per_replica_2, | ||||
|               destinations=destinations), | ||||
|           _fake_mirrored(mean_2, destinations)) | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.reduce( | ||||
|               vs.VariableAggregation.SUM, per_replica, | ||||
|               reduce_util.ReduceOp.SUM, per_replica, | ||||
|               destinations=destinations), | ||||
|           _fake_mirrored(mean * len(devices), destinations)) | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.reduce( | ||||
|               vs.VariableAggregation.SUM, | ||||
|               reduce_util.ReduceOp.SUM, | ||||
|               per_replica_2, | ||||
|               destinations=destinations), | ||||
|           _fake_mirrored(mean_2 * len(devices), destinations)) | ||||
| @ -169,7 +169,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): | ||||
|     for d1, d2 in itertools.product(all_destinations, all_destinations): | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.batch_reduce( | ||||
|               vs.VariableAggregation.MEAN, | ||||
|               reduce_util.ReduceOp.MEAN, | ||||
|               [(per_replica, d1), (per_replica_2, d2)]), | ||||
|           [ | ||||
|               _fake_mirrored(mean, d1), | ||||
| @ -177,7 +177,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): | ||||
|           ]) | ||||
|       self._assert_values_equal( | ||||
|           cross_tower_ops.batch_reduce( | ||||
|               vs.VariableAggregation.SUM, | ||||
|               reduce_util.ReduceOp.SUM, | ||||
|               [(per_replica, d1), (per_replica_2, d2)]), | ||||
|           [ | ||||
|               _fake_mirrored(mean * len(devices), d1), | ||||
| @ -281,7 +281,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): | ||||
|     t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) | ||||
|     per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) | ||||
|     result = cross_tower_ops_lib._simple_reduce( | ||||
|         per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) | ||||
|         per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) | ||||
| 
 | ||||
|     # Test that the result is semantically equal to both the concatenated | ||||
|     # IndexedSlices with and without duplicate indices. | ||||
| @ -302,11 +302,11 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): | ||||
|                   "AllReduceCrossDeviceOps", | ||||
|                   cross_tower_ops_lib.AllReduceCrossDeviceOps()) | ||||
|           ], | ||||
|           aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], | ||||
|           reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], | ||||
|           batch_reduce=[True, False], | ||||
|           mode=["graph", "eager"], | ||||
|           required_gpus=1)) | ||||
|   def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, | ||||
|   def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, reduce_op, | ||||
|                                  batch_reduce): | ||||
|     devices = ["/cpu:0", "/gpu:0"] | ||||
|     dense_shape = [5, 2] | ||||
| @ -317,19 +317,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): | ||||
| 
 | ||||
|     if batch_reduce: | ||||
|       result = cross_tower_ops_instance.batch_reduce( | ||||
|           aggregation, [(per_replica, devices)]) | ||||
|           reduce_op, [(per_replica, devices)]) | ||||
|     else: | ||||
|       result = cross_tower_ops_instance.reduce( | ||||
|           aggregation, per_replica, devices) | ||||
|           reduce_op, per_replica, devices) | ||||
| 
 | ||||
|     total_indices_with_dups = [1, 1, 3] | ||||
|     total_indices_without_dups = [1, 3] | ||||
| 
 | ||||
|     if aggregation == vs.VariableAggregation.SUM: | ||||
|     if reduce_op == reduce_util.ReduceOp.SUM: | ||||
|       total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] | ||||
|       total_values_without_dups = [[4., 6.], [5., 6.]] | ||||
|     else: | ||||
|       assert aggregation == vs.VariableAggregation.MEAN | ||||
|       assert reduce_op == reduce_util.ReduceOp.MEAN | ||||
|       total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] | ||||
|       total_values_without_dups = [[2., 3.], [2.5, 3.]] | ||||
| 
 | ||||
| @ -502,26 +502,26 @@ class MultiWorkerCollectiveAllReduceTest( | ||||
|       for destinations in all_destinations: | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.reduce( | ||||
|                 vs.VariableAggregation.MEAN, | ||||
|                 reduce_util.ReduceOp.MEAN, | ||||
|                 per_replica, | ||||
|                 destinations=destinations), | ||||
|             _fake_mirrored(mean, destinations), sess) | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.reduce( | ||||
|                 vs.VariableAggregation.MEAN, | ||||
|                 reduce_util.ReduceOp.MEAN, | ||||
|                 per_replica_2, | ||||
|                 destinations=destinations), | ||||
|             _fake_mirrored(mean_2, destinations), sess) | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.reduce( | ||||
|                 vs.VariableAggregation.SUM, | ||||
|                 reduce_util.ReduceOp.SUM, | ||||
|                 per_replica, | ||||
|                 destinations=destinations), | ||||
|             _fake_mirrored(mean * len(devices) * num_workers, destinations), | ||||
|             sess) | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.reduce( | ||||
|                 vs.VariableAggregation.SUM, | ||||
|                 reduce_util.ReduceOp.SUM, | ||||
|                 per_replica_2, | ||||
|                 destinations=destinations), | ||||
|             _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), | ||||
| @ -530,7 +530,7 @@ class MultiWorkerCollectiveAllReduceTest( | ||||
|       # test batch_reduce() | ||||
|       for d1, d2 in itertools.product(all_destinations, all_destinations): | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, | ||||
|             collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, | ||||
|                                                [(per_replica, d1), | ||||
|                                                 (per_replica_2, d2)]), | ||||
|             [ | ||||
| @ -538,7 +538,7 @@ class MultiWorkerCollectiveAllReduceTest( | ||||
|                 _fake_mirrored(mean_2, d2) | ||||
|             ], sess) | ||||
|         self._assert_values_equal( | ||||
|             collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, | ||||
|             collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, | ||||
|                                                [(per_replica, d1), | ||||
|                                                 (per_replica_2, d2)]), | ||||
|             [ | ||||
|  | ||||
| @ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations | ||||
| from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example | ||||
| from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import test | ||||
| from tensorflow.python.framework import constant_op | ||||
| @ -484,7 +485,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|       self.assertEqual(distribution.num_replicas_in_sync, | ||||
|                        len(distribution.unwrap(loss_output))) | ||||
|       loss_output = distribution.reduce( | ||||
|           aggregation=variables_lib.VariableAggregation.MEAN, | ||||
|           aggregation=reduce_util.ReduceOp.MEAN, | ||||
|           value=loss_output, destinations="/device:CPU:0") | ||||
| 
 | ||||
|     unwrapped_output = distribution.unwrap(loss_output) | ||||
|  | ||||
| @ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python import shared_variable_creator | ||||
| from tensorflow.contrib.distribute.python import values | ||||
| from tensorflow.python import pywrap_tensorflow | ||||
| from tensorflow.python.distribute import multi_worker_util | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import tape | ||||
| from tensorflow.python.framework import constant_op | ||||
| @ -178,7 +179,7 @@ def _call_for_each_replica(distribution, fn, args, kwargs): | ||||
|   return values.regroup({t.device: t.main_result for t in threads}) | ||||
| 
 | ||||
| 
 | ||||
| def _reduce_non_distributed_value(distribution, aggregation, value, | ||||
| def _reduce_non_distributed_value(distribution, reduce_op, value, | ||||
|                                   destinations): | ||||
|   """Reduce a non-DistributedValue `value` to `destinations`.""" | ||||
|   if isinstance(value, values.DistributedValues): | ||||
| @ -190,21 +191,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value, | ||||
|   # and equal to 0. | ||||
|   if value == 0: | ||||
|     return 0 | ||||
|   # If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this | ||||
|   # If the reduce op is MEAN or ONLY_FIRST_REPLICA, then this | ||||
|   # essentially means that the same value should be on all destinations. | ||||
|   if aggregation in ( | ||||
|       variable_scope.VariableAggregation.MEAN, | ||||
|       variable_scope.VariableAggregation.ONLY_FIRST_REPLICA): | ||||
|   if reduce_op in (reduce_util.ReduceOp.MEAN, | ||||
|                    reduce_util.ReduceOp.ONLY_FIRST_REPLICA): | ||||
|     return value | ||||
| 
 | ||||
|   cross_tower_ops_lib.validate_destinations(destinations) | ||||
|   # We do not support an aggregation type of SUM if the value is the same across | ||||
|   # We do not support a reduce op of SUM if the value is the same across | ||||
|   # all replicas. We call this as part of assign functions for MirroredVariables | ||||
|   # and summing up identical values across replicas is not clearly defined. | ||||
|   if (len(distribution.worker_devices) != 1 or | ||||
|       not cross_tower_ops_lib.check_destinations(destinations)): | ||||
|     raise ValueError("A non-DistributedValues value %s cannot be reduced with " | ||||
|                      "the given aggregation %s." % (value, aggregation)) | ||||
|                      "the given reduce op %s." % (value, reduce_op)) | ||||
|   # TODO(anjalisridhar): Move these methods to a device utility file? | ||||
|   devices = cross_tower_ops_lib.get_devices_from(destinations) | ||||
|   if len(devices) == 1: | ||||
| @ -588,28 +588,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): | ||||
|           cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()) | ||||
|     return self._cross_tower_ops | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     assert not isinstance(value, values.Mirrored) | ||||
|     if not isinstance(value, values.DistributedValues): | ||||
|       # This function handles reducing values that are not PerReplica or | ||||
|       # Mirrored values. For example, the same value could be present on all | ||||
|       # replicas in which case `value` would be a single value or value could | ||||
|       # be 0. | ||||
|       return _reduce_non_distributed_value(self, aggregation, value, | ||||
|       return _reduce_non_distributed_value(self, reduce_op, value, | ||||
|                                            destinations) | ||||
|     if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|     if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: | ||||
|       value = value.get(self._devices[0]) | ||||
|       if isinstance(value, (int, float)): | ||||
|         return value | ||||
|       return self.broadcast(value, destinations) | ||||
|     return self._get_cross_tower_ops().reduce( | ||||
|         aggregation, value, destinations=destinations) | ||||
|         reduce_op, value, destinations=destinations) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|     if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: | ||||
|       return [self.broadcast(v.get(self._devices[0]), d) | ||||
|               for v, d in value_destination_pairs] | ||||
|     return self._get_cross_tower_ops().batch_reduce(aggregation, | ||||
|     return self._get_cross_tower_ops().batch_reduce(reduce_op, | ||||
|                                                     value_destination_pairs) | ||||
| 
 | ||||
|   def _update(self, var, options, fn, *args, **kwargs): | ||||
|  | ||||
| @ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import strategy_test_lib | ||||
| from tensorflow.contrib.distribute.python import values | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import backprop | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import function | ||||
| @ -117,7 +118,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): | ||||
|     with dist.scope(): | ||||
|       result = dist.call_for_each_replica(_replica_id) | ||||
|       reduced = dist.reduce( | ||||
|           variable_scope.VariableAggregation.SUM, | ||||
|           reduce_util.ReduceOp.SUM, | ||||
|           result, | ||||
|           destinations="/device:CPU:0") | ||||
|       unwrapped = dist.unwrap(reduced) | ||||
| @ -137,7 +138,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): | ||||
|     with dist.scope(): | ||||
|       result = dist.call_for_each_replica(run_fn) | ||||
|       reduced = dist.reduce( | ||||
|           variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, | ||||
|           reduce_util.ReduceOp.ONLY_FIRST_REPLICA, | ||||
|           result, | ||||
|           destinations="/device:CPU:0") | ||||
|       unwrapped = dist.unwrap(reduced) | ||||
| @ -157,7 +158,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): | ||||
|     dist = mirrored_strategy.MirroredStrategy(devices) | ||||
|     with dist.scope(): | ||||
|       reduced = dist.reduce( | ||||
|           variable_scope.VariableAggregation.SUM, | ||||
|           reduce_util.ReduceOp.SUM, | ||||
|           1.0, | ||||
|           destinations=["/device:CPU:0", "/device:GPU:0"]) | ||||
|       unwrapped = dist.unwrap(reduced) | ||||
| @ -912,7 +913,7 @@ class MirroredVariableUpdateTest(test.TestCase): | ||||
| 
 | ||||
|       with self.assertRaisesRegexp( | ||||
|           ValueError, "A non-DistributedValues value 5.0 cannot be reduced " | ||||
|           "with the given aggregation VariableAggregation.SUM."): | ||||
|           "with the given reduce op ReduceOp.SUM."): | ||||
|         self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) | ||||
| 
 | ||||
|   @test_util.run_in_graph_and_eager_modes(config=config) | ||||
|  | ||||
| @ -122,8 +122,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): | ||||
|     with ops.device(self._device), _OneDeviceReplicaContext(self): | ||||
|       return fn(*args, **kwargs) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|     del aggregation, destinations | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     del reduce_op, destinations | ||||
|     return value | ||||
| 
 | ||||
|   def _update(self, var, options, fn, *args, **kwargs): | ||||
|  | ||||
| @ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ | ||||
| from tensorflow.contrib.distribute.python import mirrored_strategy | ||||
| from tensorflow.contrib.distribute.python import values | ||||
| from tensorflow.python.distribute import multi_worker_util | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import device as tf_device | ||||
| from tensorflow.python.framework import ops | ||||
| @ -307,24 +308,24 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): | ||||
|             "Cannot reduce to another worker: %r, current worker is %r" % | ||||
|             (d, self._worker_device)) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     self._verify_destinations_not_different_worker(destinations) | ||||
|     if not isinstance(value, values.DistributedValues): | ||||
|       # pylint: disable=protected-access | ||||
|       return mirrored_strategy._reduce_non_distributed_value( | ||||
|           self, aggregation, value, destinations) | ||||
|     if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|           self, reduce_op, value, destinations) | ||||
|     if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: | ||||
|       return self.broadcast(value.get(self._compute_devices[0]), destinations) | ||||
|     return self._cross_tower_ops.reduce( | ||||
|         aggregation, value, destinations=destinations) | ||||
|         reduce_op, value, destinations=destinations) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|     if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: | ||||
|       return [self.broadcast(v.get(self._compute_devices[0]), d) | ||||
|               for v, d in value_destination_pairs] | ||||
|     for _, destinations in value_destination_pairs: | ||||
|       self._verify_destinations_not_different_worker(destinations) | ||||
|     return self._cross_tower_ops.batch_reduce(aggregation, | ||||
|     return self._cross_tower_ops.batch_reduce(reduce_op, | ||||
|                                               value_destination_pairs) | ||||
| 
 | ||||
|   def _select_single_value(self, structured): | ||||
|  | ||||
| @ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import parameter_server_strategy | ||||
| from tensorflow.contrib.distribute.python import values | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python.distribute import multi_worker_util | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import backprop | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.estimator import run_config | ||||
| @ -473,7 +474,7 @@ class ParameterServerStrategyTestBase( | ||||
|           with ops.control_dependencies([fetched]): | ||||
|             # TODO(yuefengz): support non-Mirrored variable as destinations. | ||||
|             g = d.reduce( | ||||
|                 variable_scope.VariableAggregation.SUM, g, destinations=v) | ||||
|                 reduce_util.ReduceOp.SUM, g, destinations=v) | ||||
|             with ops.control_dependencies( | ||||
|                 d.update(v, update, g, grouped=False)): | ||||
|               after_list.append(d.read_var(v)) | ||||
|  | ||||
| @ -19,6 +19,7 @@ from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.core.protobuf import config_pb2 | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import backprop | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import test | ||||
| @ -26,7 +27,6 @@ from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.layers import core | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import variable_scope | ||||
| from tensorflow.python.ops import variables | ||||
| from tensorflow.python.training import distribution_strategy_context | ||||
| from tensorflow.python.training import optimizer | ||||
| @ -114,8 +114,7 @@ class DistributionTestBase(test.TestCase): | ||||
|           before_list.append(fetched) | ||||
|           # control_dependencies irrelevant but harmless in eager execution | ||||
|           with ops.control_dependencies([fetched]): | ||||
|             g = d.reduce( | ||||
|                 variable_scope.VariableAggregation.SUM, g, destinations=v) | ||||
|             g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) | ||||
|             with ops.control_dependencies(d.update( | ||||
|                 v, update, g, grouped=False)): | ||||
|               after_list.append(d.read_var(v)) | ||||
| @ -169,8 +168,7 @@ class DistributionTestBase(test.TestCase): | ||||
|           fetched = d.read_var(v) | ||||
|           before_list.append(fetched) | ||||
|           with ops.control_dependencies([fetched]): | ||||
|             g = d.reduce( | ||||
|                 variable_scope.VariableAggregation.SUM, g, destinations=v) | ||||
|             g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) | ||||
|             with ops.control_dependencies(d.update( | ||||
|                 v, update, g, grouped=False)): | ||||
|               after_list.append(d.read_var(v)) | ||||
|  | ||||
| @ -31,6 +31,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_ | ||||
| from tensorflow.contrib.tpu.python.tpu import training_loop | ||||
| from tensorflow.python.data.experimental.ops import batching | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import tape | ||||
| from tensorflow.python.framework import constant_op | ||||
| @ -439,12 +440,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): | ||||
|     return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, | ||||
|                                          **kwargs) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access | ||||
|       if aggregation == vs.VariableAggregation.MEAN: | ||||
|       if reduce_op == reduce_util.ReduceOp.MEAN: | ||||
|         # TODO(jhseu):  Revisit once we support model-parallelism. | ||||
|         value *= (1. / self.num_replicas_in_sync) | ||||
|       elif aggregation != vs.VariableAggregation.SUM: | ||||
|       elif reduce_op != reduce_util.ReduceOp.SUM: | ||||
|         raise NotImplementedError( | ||||
|             "Currently only support sum & mean in TPUStrategy.") | ||||
|       return tpu_ops.cross_replica_sum(value) | ||||
| @ -459,10 +460,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy): | ||||
|     else: | ||||
|       raise ValueError("Multiple devices are not supported for TPUStrategy") | ||||
| 
 | ||||
|     if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|     if reduce_op == reduce_util.ReduceOp.ONLY_FIRST_REPLICA: | ||||
|       return value[0] | ||||
|     output = math_ops.add_n(value) | ||||
|     if aggregation == vs.VariableAggregation.MEAN: | ||||
|     if reduce_op == reduce_util.ReduceOp.MEAN: | ||||
|       return output * (1. / len(value)) | ||||
|     return output | ||||
| 
 | ||||
|  | ||||
| @ -30,6 +30,7 @@ import six | ||||
| from tensorflow.contrib.distribute.python import input_ops | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.data.ops import multi_device_iterator_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import tape | ||||
| from tensorflow.python.framework import device as tf_device | ||||
| @ -373,12 +374,13 @@ class MirroredVariable(DistributedVariable, Mirrored, | ||||
|       if self._aggregation == vs.VariableAggregation.NONE: | ||||
|         raise ValueError("You must specify an aggregation method to update a " | ||||
|                          "MirroredVariable in Replica Context.") | ||||
|       reduce_op = reduce_util.ReduceOp.from_variable_aggregation( | ||||
|           self._aggregation) | ||||
| 
 | ||||
|       def merge_fn(strategy, value, *other_args, **other_kwargs): | ||||
|         return strategy.update( | ||||
|             self, f, | ||||
|             strategy.reduce( | ||||
|                 aggregation=self._aggregation, value=value, destinations=self), | ||||
|             strategy.reduce(reduce_op, value=value, destinations=self), | ||||
|             *other_args, **other_kwargs) | ||||
| 
 | ||||
|       return distribution_strategy_context.get_replica_context().merge_call( | ||||
| @ -614,12 +616,13 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): | ||||
|       if self._aggregation == vs.VariableAggregation.NONE: | ||||
|         raise ValueError("You must specify an aggregation method to update a " | ||||
|                          "TPUMirroredVariable in Replica Context.") | ||||
|       reduce_op = reduce_util.ReduceOp.from_variable_aggregation( | ||||
|           self._aggregation) | ||||
| 
 | ||||
|       def merge_fn(strategy, value, *other_args, **other_kwargs): | ||||
|         return strategy.update( | ||||
|             self, f, | ||||
|             strategy.reduce( | ||||
|                 aggregation=self._aggregation, value=value, destinations=self), | ||||
|             strategy.reduce(reduce_op, value=value, destinations=self), | ||||
|             *other_args, **other_kwargs) | ||||
| 
 | ||||
|       return distribution_strategy_context.get_replica_context().merge_call( | ||||
| @ -1549,6 +1552,7 @@ class MultiStepContext(object): | ||||
|         The aggregation method is also recorded in a dictionary | ||||
|         `_last_step_outputs_aggregations` for later interpreting of the | ||||
|         outputs as already reduced or not. | ||||
|         # TODO(priyag): Change aggregation type used here. | ||||
| 
 | ||||
|     """ | ||||
|     if distribution_strategy_context.get_cross_replica_context(): | ||||
| @ -1650,12 +1654,13 @@ class AggregatingVariable(checkpointable.CheckpointableBase): | ||||
|       if self._aggregation == vs.VariableAggregation.NONE: | ||||
|         raise ValueError("You must specify an aggregation method to update a " | ||||
|                          "variable in Replica Context.") | ||||
|       reduce_op = reduce_util.ReduceOp.from_variable_aggregation( | ||||
|           self._aggregation) | ||||
| 
 | ||||
|       def merge_fn(strategy, value, *other_args, **other_kwargs): | ||||
|         return strategy.update( | ||||
|             self, f, | ||||
|             strategy.reduce( | ||||
|                 aggregation=self._aggregation, value=value, destinations=self), | ||||
|             strategy.reduce(reduce_op, value=value, destinations=self), | ||||
|             *other_args, **other_kwargs) | ||||
| 
 | ||||
|       return distribution_strategy_context.get_replica_context().merge_call( | ||||
|  | ||||
| @ -3590,6 +3590,7 @@ py_library( | ||||
|         ":util", | ||||
|         ":variable_scope", | ||||
|         "//tensorflow/python/data", | ||||
|         "//tensorflow/python/distribute:reduce_util", | ||||
|         "//tensorflow/python/ops/losses", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -155,3 +155,11 @@ py_library( | ||||
|         "//tensorflow/python:training", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "reduce_util", | ||||
|     srcs = [ | ||||
|         "reduce_util.py", | ||||
|     ], | ||||
|     deps = [], | ||||
| ) | ||||
|  | ||||
							
								
								
									
										58
									
								
								tensorflow/python/distribute/reduce_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								tensorflow/python/distribute/reduce_util.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Utilities for reduce operations.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import enum | ||||
| 
 | ||||
| from tensorflow.python.ops import variable_scope | ||||
| 
 | ||||
| 
 | ||||
| # TODO(priyag): Add this to tf.distribute namespace when it exists. | ||||
| class ReduceOp(enum.Enum): | ||||
|   """Indicates how a set of values should be reduced. | ||||
| 
 | ||||
|   * `SUM`: Add all the values. | ||||
|   * `MEAN`: Take the arithmetic mean ("average") of the values. | ||||
|   * `ONLY_FIRST_REPLICA`: Return the value on the first replica. | ||||
| 
 | ||||
|   TODO(priyag): Add the following types: | ||||
|   * `MIN`: Return the minimum of all values. | ||||
|   * `MAX`: Return the maximum of all values. | ||||
|   """ | ||||
| 
 | ||||
|   SUM = 0 | ||||
|   MEAN = 1 | ||||
|   ONLY_FIRST_REPLICA = 2 | ||||
| 
 | ||||
|   @staticmethod | ||||
|   def from_variable_aggregation(aggregation): | ||||
|     mapping = { | ||||
|         variable_scope.VariableAggregation.SUM: ReduceOp.SUM, | ||||
|         variable_scope.VariableAggregation.MEAN: ReduceOp.MEAN, | ||||
|         variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: | ||||
|             ReduceOp.ONLY_FIRST_REPLICA | ||||
|     } | ||||
| 
 | ||||
|     reduce_op = mapping.get(aggregation) | ||||
|     if not reduce_op: | ||||
|       raise ValueError("Could not convert from `tf.VariableAggregation` to " | ||||
|                        "`tf.distribute.ReduceOp` type") | ||||
|     return reduce_op | ||||
| 
 | ||||
| 
 | ||||
| @ -21,6 +21,7 @@ from __future__ import print_function | ||||
| import threading | ||||
| 
 | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| @ -755,9 +756,13 @@ class DistributionStrategy(object): | ||||
|     """Combine (via e.g. sum or mean) values across replicas. | ||||
| 
 | ||||
|     Args: | ||||
|       aggregation: Indicates how a variable will be aggregated. Accepted values | ||||
|         are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, | ||||
|       aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum. | ||||
|         DEPRECATED but still accepted values: | ||||
|         `tf.VariableAggregation.SUM`, | ||||
|         `tf.VariableAggregation.MEAN`, | ||||
|         `tf.VariableAggregation.ONLY_FIRST_REPLICA`. | ||||
|         # TODO(priyag): Rename this argument when moving the method to | ||||
|         # DSExtended. | ||||
|       value: A per-replica value with one value per replica. | ||||
|       destinations: A mirrored variable, a per-replica tensor, a device string, | ||||
|         or list of device strings. The return value will be copied to all | ||||
| @ -771,23 +776,32 @@ class DistributionStrategy(object): | ||||
|     # TODO(josh11b): Return an unwrapped value if colocate_with is a | ||||
|     # single device. | ||||
|     _require_cross_replica_context(self) | ||||
|     assert aggregation in [ | ||||
|         variable_scope.VariableAggregation.SUM, | ||||
|         variable_scope.VariableAggregation.MEAN, | ||||
|         variable_scope.VariableAggregation.ONLY_FIRST_REPLICA | ||||
|     ] | ||||
|     return self._reduce(aggregation, value, destinations) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|     # TODO(priyag): Remove this when all callers have been updated. | ||||
|     reduce_op = aggregation | ||||
|     if isinstance(aggregation, variable_scope.VariableAggregation): | ||||
|       assert aggregation in [ | ||||
|           variable_scope.VariableAggregation.SUM, | ||||
|           variable_scope.VariableAggregation.MEAN, | ||||
|           variable_scope.VariableAggregation.ONLY_FIRST_REPLICA | ||||
|       ] | ||||
|       reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) | ||||
|     return self._reduce(reduce_op, value, destinations) | ||||
| 
 | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     raise NotImplementedError("must be implemented in descendants") | ||||
| 
 | ||||
|   def batch_reduce(self, aggregation, value_destination_pairs): | ||||
|     """Combine multiple `reduce` calls into one for faster execution. | ||||
| 
 | ||||
|     Args: | ||||
|       aggregation: Indicates how a variable will be aggregated. Accepted values | ||||
|         are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, | ||||
|       aggregation: Reduction type, an instance of `tf.distribute.ReduceOp` enum. | ||||
|         DEPRECATED but still accepted values: | ||||
|         `tf.VariableAggregation.SUM`, | ||||
|         `tf.VariableAggregation.MEAN`, | ||||
|         `tf.VariableAggregation.ONLY_FIRST_REPLICA`. | ||||
|         # TODO(priyag): Rename this argument when moving the method to | ||||
|         # DSExtended. | ||||
|       value_destination_pairs: A sequence of (value, destinations) | ||||
|         pairs. See `reduce()` for a description. | ||||
| 
 | ||||
| @ -796,16 +810,21 @@ class DistributionStrategy(object): | ||||
|     """ | ||||
|     # TODO(josh11b): More docstring | ||||
|     _require_cross_replica_context(self) | ||||
|     assert aggregation in [ | ||||
|         variable_scope.VariableAggregation.SUM, | ||||
|         variable_scope.VariableAggregation.MEAN, | ||||
|         variable_scope.VariableAggregation.ONLY_FIRST_REPLICA | ||||
|     ] | ||||
|     return self._batch_reduce(aggregation, value_destination_pairs) | ||||
| 
 | ||||
|   def _batch_reduce(self, aggregation, value_destination_pairs): | ||||
|     # TODO(priyag): Remove this when all callers have been updated. | ||||
|     reduce_op = aggregation | ||||
|     if isinstance(aggregation, variable_scope.VariableAggregation): | ||||
|       assert aggregation in [ | ||||
|           variable_scope.VariableAggregation.SUM, | ||||
|           variable_scope.VariableAggregation.MEAN, | ||||
|           variable_scope.VariableAggregation.ONLY_FIRST_REPLICA | ||||
|       ] | ||||
|       reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) | ||||
|     return self._batch_reduce(reduce_op, value_destination_pairs) | ||||
| 
 | ||||
|   def _batch_reduce(self, reduce_op, value_destination_pairs): | ||||
|     return [ | ||||
|         self.reduce(aggregation, t, destinations=v) | ||||
|         self.reduce(reduce_op, t, destinations=v) | ||||
|         for t, v in value_destination_pairs | ||||
|     ] | ||||
| 
 | ||||
| @ -1154,9 +1173,9 @@ class _DefaultDistributionStrategy(DistributionStrategy): | ||||
|     with ReplicaContext(self, replica_id=0): | ||||
|       return fn(*args, **kwargs) | ||||
| 
 | ||||
|   def _reduce(self, aggregation, value, destinations): | ||||
|   def _reduce(self, reduce_op, value, destinations): | ||||
|     # TODO(josh11b): Use destinations? | ||||
|     del aggregation, destinations | ||||
|     del reduce_op, destinations | ||||
|     return value | ||||
| 
 | ||||
|   def _update(self, var, options, fn, *args, **kwargs): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user