From 35228fbf0caf18a820a7306809f1fae5e111a64e Mon Sep 17 00:00:00 2001
From: Priya Gupta
Date: Wed, 14 Nov 2018 19:53:57 -0800
Subject: [PATCH] Use new enum `ReduceOp` for all callers of strategy.reduce.

Also use it for loss reduction and in OutputContext, both of which previously
used VariableAggregation.

PiperOrigin-RevId: 221558489
---
 .../distribute/python/minimize_loss_test.py   | 47 +++++++++----------
 .../distribute/python/mirrored_strategy.py    |  7 ++-
 .../contrib/distribute/python/tpu_strategy.py |  9 ++--
 .../contrib/distribute/python/values.py       | 36 +++++++-------
 tensorflow/contrib/optimizer_v2/BUILD         |  1 +
 .../contrib/optimizer_v2/optimizer_v2.py      |  6 +--
 tensorflow/python/BUILD                       |  1 +
 tensorflow/python/keras/BUILD                 |  1 +
 .../keras/engine/training_distributed.py      | 18 +++----
 tensorflow/python/keras/optimizer_v2/BUILD    |  1 +
 .../python/keras/optimizer_v2/optimizer_v2.py |  4 +-
 tensorflow/python/training/distribute.py      |  6 +--
 tensorflow/python/training/moving_averages.py |  4 +-
 tensorflow/python/training/optimizer.py       |  8 ++--
 14 files changed, 73 insertions(+), 76 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 5d3b5d8922a..1f57dd1754c 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -403,9 +403,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         train_op = optimizer.minimize(loss_fn)
         loss = loss_fn()
         output_context.set_last_step_output(
-            name="replica_loss_agg",
+            name="replica_loss_reduced",
             output=loss,
-            aggregation=variables_lib.VariableAggregation.MEAN)
+            reduce_op=reduce_util.ReduceOp.MEAN)
         output_context.set_non_tensor_output(key1, value1)
         return (train_op, loss)

@@ -413,11 +413,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         (train_op, loss) = distribution.call_for_each_replica(
             model_fn, args=(output_context,) + inputs)
         output_context.set_last_step_output(
-            name="cross_replica_loss_agg",
+            name="cross_replica_loss_reduced",
             output=loss,
-            aggregation=variables_lib.VariableAggregation.MEAN)
+            reduce_op=reduce_util.ReduceOp.MEAN)
         output_context.set_last_step_output(
-            name="cross_replica_loss_noagg",
+            name="cross_replica_loss_not_reduced",
             output=loss)
         return distribution.group(train_op)

@@ -425,16 +425,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

       def run_step():
         initial_loss = lambda: constant_op.constant(1e7)
-        # Initial values corresponding to aggregated losses are just single
-        # tensors. But for non aggregated losses, we need to have initial
+        # Initial values corresponding to reduced losses are just single
+        # tensors. But for non reduced losses, we need to have initial
         # values that are of the same structure as non reduced losses. In
         # MirroredStrategy, this will be a list of losses, in TPUStrategy
         # it will be single tensor. Using `broadcast` followed by `unwrap`
         # gives us the desired initial value structure.
         initial_loop_values = {
-            "replica_loss_agg": initial_loss(),
-            "cross_replica_loss_agg": initial_loss(),
-            "cross_replica_loss_noagg":
+            "replica_loss_reduced": initial_loss(),
+            "cross_replica_loss_reduced": initial_loss(),
+            "cross_replica_loss_not_reduced":
                 distribution.unwrap(distribution.broadcast(initial_loss()))
         }
         ctx = distribution.run_steps_on_dataset(
@@ -444,17 +444,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs)
         self._verify_loss_output(
             initial_loss(),
-            loss_output=ctx.last_step_outputs["replica_loss_agg"],
-            aggregated=True, distribution=distribution)
+            loss_output=ctx.last_step_outputs["replica_loss_reduced"],
+            reduced=True, distribution=distribution)
         self._verify_loss_output(
             initial_loss(),
-            loss_output=ctx.last_step_outputs["cross_replica_loss_agg"],
-            aggregated=True, distribution=distribution)
+            loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
+            reduced=True, distribution=distribution)
         self._verify_loss_output(
             initial_loss(),
-            loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"],
-            aggregated=False, distribution=distribution)
+            loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
+            reduced=False, distribution=distribution)
-        return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"])
+        return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

       self.evaluate(distribution.initialize())
       if not context.executing_eagerly():
@@ -479,17 +479,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
     self.assertTrue(error_is_not_increasing)

-  def _verify_loss_output(self, initial_loss, loss_output, aggregated,
+  def _verify_loss_output(self, initial_loss, loss_output, reduced,
                           distribution):
-    if not aggregated:
-      self.assertEqual(distribution.num_replicas_in_sync,
-                       len(distribution.unwrap(loss_output)))
+    if not reduced:
+      self.assertLen(distribution.unwrap(loss_output),
+                     distribution.num_replicas_in_sync)
     loss_output = distribution.reduce(
-        aggregation=reduce_util.ReduceOp.MEAN,
-        value=loss_output, destinations="/device:CPU:0")
+        reduce_util.ReduceOp.MEAN, loss_output, destinations="/device:CPU:0")
     unwrapped_output = distribution.unwrap(loss_output)
-    self.assertEqual(1, len(unwrapped_output))
+    self.assertLen(unwrapped_output, 1)
     loss_tensor = unwrapped_output[0]
     self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
     self.assertEqual(initial_loss.shape, loss_tensor.shape)

diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 4bb2cb1990b..9795df8d131 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -36,7 +36,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
@@ -532,11 +531,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):

     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, wrap them in a Mirrored
+      # For outputs that have already been reduced, wrap them in a Mirrored
       # container, else in a PerReplica container.
-      if aggregation is variables_lib.VariableAggregation.NONE:
+      if reduce_op is None:
         last_step_tensor_outputs_dict[name] = values.regroup(
             {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
       else:
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index adc075011d4..dcc2c644388 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -360,14 +359,14 @@ class TPUStrategy(distribute_lib.DistributionStrategy):

     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, take the first value
+      # For outputs that have already been reduced, take the first value
       # from the list as each value should be the same. Else return the full
       # list of values.
-      # TODO(josh11b): If aggregation is NONE, we should return a PerReplica
+      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
       # value.
-      if aggregation is not variables_lib.VariableAggregation.NONE:
+      if reduce_op is not None:
         # TODO(priyag): Should this return the element or a list with 1 element
         last_step_tensor_outputs_dict[name] = output[0]
     ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index f565381b492..bf8c958efaf 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -41,7 +41,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context
@@ -1508,7 +1507,7 @@ class MultiStepContext(object):
       A context object.
""" self._last_step_outputs = {} - self._last_step_outputs_aggregations = {} + self._last_step_outputs_reduce_ops = {} self._non_tensor_outputs = {} @property @@ -1518,8 +1517,8 @@ class MultiStepContext(object): Keys in the dictionary are names of tensors to be captured, as specified when `set_last_step_output` is called. Values in the dictionary are the tensors themselves. If - `set_last_step_output` was called with an `aggregation` for this output, - then the value is the aggregated value. + `set_last_step_output` was called with a `reduce_op` for this output, + then the value is the reduced value. Returns: A dictionary with last step outputs. @@ -1532,8 +1531,7 @@ class MultiStepContext(object): raise ValueError("Need a dictionary to set last_step_outputs.") self._last_step_outputs = outputs - def set_last_step_output(self, name, output, - aggregation=variables_lib.VariableAggregation.NONE): + def set_last_step_output(self, name, output, reduce_op=None): """Set `output` with `name` to be outputted from the last step. Args: @@ -1541,37 +1539,35 @@ class MultiStepContext(object): name. output: The tensors that should be outputted with `name`. See below for actual types supported. - aggregation: Aggregation method to use to aggregate outputs from multiple + reduce_op: Reduction method to use to reduce outputs from multiple replicas. Required if `set_last_step_output` is called in a replica context. Optional in cross_replica_context. - When present, the outputs from all the replicas are aggregated using the + When present, the outputs from all the replicas are reduced using the current distribution strategy's `reduce` method. Hence, the type of `output` must be what's supported by the corresponding `reduce` method. - For e.g. if using MirroredStrategy and aggregation is set, output + For e.g. if using MirroredStrategy and reduction is set, output must be a `PerReplica` value. - The aggregation method is also recorded in a dictionary - `_last_step_outputs_aggregations` for later interpreting of the + The reduce method is also recorded in a dictionary + `_last_step_outputs_reduce_ops` for later interpreting of the outputs as already reduced or not. - # TODO(priyag): Change aggregation type used here. - """ if distribution_strategy_context.get_cross_replica_context(): - self._last_step_outputs_aggregations[name] = aggregation - if aggregation is variables_lib.VariableAggregation.NONE: + self._last_step_outputs_reduce_ops[name] = reduce_op + if reduce_op is None: self._last_step_outputs[name] = output else: distribution = distribution_strategy_context.get_distribution_strategy() self._last_step_outputs[name] = distribution.reduce( - aggregation, output, destinations="/device:CPU:0") + reduce_op, output, destinations="/device:CPU:0") else: - assert aggregation is not variables_lib.VariableAggregation.NONE + assert reduce_op is not None def merge_fn(distribution, value): self._last_step_outputs[name] = distribution.reduce( - aggregation, value, destinations="/device:CPU:0") + reduce_op, value, destinations="/device:CPU:0") # Setting this inside the `merge_fn` because all replicas share the same # context object, so it's more robust to set it only once (even if all # the replicas are trying to set the same value). 
-        self._last_step_outputs_aggregations[name] = aggregation
+        self._last_step_outputs_reduce_ops[name] = reduce_op
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, output)
@@ -1588,7 +1584,7 @@ class MultiStepContext(object):
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
-        # in a list as aggregation doesn't make sense on non tensors.
+        # in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = distribution.unwrap(value)
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, output)
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 3ba3ee29ec7..835fb4aec4f 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -56,6 +56,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )

diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 467dd86d8fd..1b6e70d3a03 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -848,8 +849,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
     """Scale loss for the number of replicas."""
     if scale_loss_by_num_replicas is None:
       scale_loss_by_num_replicas = (
-          distribute_lib.get_loss_reduction() == variable_scope
-          .VariableAggregation.MEAN)
+          distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
     if scale_loss_by_num_replicas:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync
@@ -928,7 +928,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
   def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
     """`apply_gradients` for use with a `DistributionStrategy`."""
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a6e5c110b7f..18f226beb5c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3476,6 +3476,7 @@ py_library(
         "@six_archive//:six",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:distribute_coordinator_context",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index db78eff86e4..365ae4746d1 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -146,6 +146,7 @@ py_library(
     deps = [
         ":backend",
         "//tensorflow/python/data",
+        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/keras/optimizer_v2",
         "//tensorflow/python/training/checkpointable:data_structures",
         "//tensorflow/tools/docs:doc_controls",
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 808d7c9f333..3a2373b4cf0 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 import enum
 import numpy as np

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -34,7 +35,6 @@ from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -288,12 +288,12 @@ def _experimental_fit_loop(

   for label, output in zip(out_labels, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)

   # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
   # feed_dict, session kwargs, run options, run_metadata for now. These should
@@ -571,12 +571,12 @@ def _experimental_test_loop(model, iterator, verbose=0, steps=None,

   for label, output in zip(model.metrics_names, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)

   return combined_fn.updates_op

diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index eaa764992f6..e0ff5875496 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -27,6 +27,7 @@ py_library(
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )

diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 68b5368711a..8e3aea28465 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -32,7 +33,6 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.ops import gradients
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribution_strategy_context
@@ -580,7 +580,7 @@ def merge_grads(grads_and_vars):

   def merge_grad_fn(strategy, grads_and_vars):
     reduced_grads = strategy.batch_reduce(
-        variable_scope.VariableAggregation.MEAN, grads_and_vars)
+        ds_reduce_util.ReduceOp.MEAN, grads_and_vars)
     return reduced_grads

   return distribution_strategy_context.get_replica_context().merge_call(
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 883a22389a6..aed2a413ae5 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -71,11 +71,11 @@ class UpdateContext(object):


 def get_loss_reduction():
-  """Reduce `aggregation` corresponding to the last loss reduction."""
+  """Reduce op corresponding to the last loss reduction."""
   loss_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
   if loss_reduction == losses_impl.Reduction.SUM:
-    return variable_scope.VariableAggregation.SUM
-  return variable_scope.VariableAggregation.MEAN
+    return reduce_util.ReduceOp.SUM
+  return reduce_util.ReduceOp.MEAN


 # ------------------------------------------------------------------------------
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index fc9eb479cc3..957c8810ac1 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
@@ -95,8 +96,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
       # In a replica context, we update variable using the mean of value across
       # replicas.
       def merge_fn(strategy, v, value):
-        value = strategy.reduce(
-            variable_scope.VariableAggregation.MEAN, value, v)
+        value = strategy.reduce(ds_reduce_util.ReduceOp.MEAN, value, v)
         return strategy.update(v, update_fn, value)
       return replica_context.merge_call(merge_fn, variable, value)

diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 9dfa9d2afb2..ada8a7d6163 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -520,8 +521,7 @@ class Optimizer(

   @staticmethod
   def _scale_loss(loss_value):
-    if (distribute_lib.get_loss_reduction() ==
-        variable_scope.VariableAggregation.MEAN):
+    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync
       if num_replicas > 1:
@@ -658,10 +658,10 @@ class Optimizer(
     Returns:
       An `Operation` that applies the specified gradients across all
       replicas. If `global_step` was not None, that operation also
-      increments `global_step`.
+      increments `global_step`
     """
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
     # Note that this is called in a cross-replica context.
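
Usage note (illustrative only, not part of the patch): a minimal sketch of the calling convention this change moves callers to, passing a `ReduceOp` member to `strategy.reduce` instead of a `VariableAggregation` value. The helper name `mean_over_replicas` and the CPU destination string are assumptions for the example; the import path and the `reduce(reduce_op, value, destinations=...)` signature are taken from the hunks above.

    from tensorflow.python.distribute import reduce_util as ds_reduce_util

    def mean_over_replicas(distribution, per_replica_value):
      # After this change, callers pass a ReduceOp member (e.g. MEAN or SUM)
      # rather than a VariableAggregation value when asking the strategy to
      # combine per-replica tensors into a single value on a target device.
      return distribution.reduce(
          ds_reduce_util.ReduceOp.MEAN, per_replica_value,
          destinations="/device:CPU:0")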