Use the new enum ReduceOp for all callers of strategy.reduce. Also use it for loss reduction and in OutputContext, both of which were previously using VariableAggregation.

PiperOrigin-RevId: 221558489
Priya Gupta 2018-11-14 19:53:57 -08:00 committed by TensorFlower Gardener
parent ddfa238d6e
commit 35228fbf0c
14 changed files with 73 additions and 76 deletions
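For orientation, here is a minimal, self-contained sketch (plain Python, not the TensorFlow implementation) of what the new ReduceOp enum encodes and how a reduce call uses it. The toy reduce_values helper and the per-replica loss values are made up for illustration; only the ReduceOp.SUM / ReduceOp.MEAN names follow the diffs below.

import enum


class ReduceOp(enum.Enum):
  # Stand-in for tensorflow.python.distribute.reduce_util.ReduceOp: it only
  # names *how* per-replica values should be combined.
  SUM = "SUM"
  MEAN = "MEAN"


def reduce_values(reduce_op, per_replica_values):
  # Hypothetical helper playing the role of strategy.reduce: combine one
  # value per replica into a single value according to reduce_op.
  total = sum(per_replica_values)
  if reduce_op == ReduceOp.MEAN:
    return total / len(per_replica_values)
  return total


# Per-replica losses, e.g. from a four-replica mirrored training step.
losses = [1.0, 2.0, 3.0, 4.0]
print(reduce_values(ReduceOp.MEAN, losses))  # 2.5
print(reduce_values(ReduceOp.SUM, losses))   # 10.0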

View File

@@ -403,9 +403,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       train_op = optimizer.minimize(loss_fn)
       loss = loss_fn()
       output_context.set_last_step_output(
-          name="replica_loss_agg",
+          name="replica_loss_reduced",
           output=loss,
-          aggregation=variables_lib.VariableAggregation.MEAN)
+          reduce_op=reduce_util.ReduceOp.MEAN)
       output_context.set_non_tensor_output(key1, value1)
       return (train_op, loss)

@@ -413,11 +413,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       (train_op, loss) = distribution.call_for_each_replica(
           model_fn, args=(output_context,) + inputs)
       output_context.set_last_step_output(
-          name="cross_replica_loss_agg",
+          name="cross_replica_loss_reduced",
           output=loss,
-          aggregation=variables_lib.VariableAggregation.MEAN)
+          reduce_op=reduce_util.ReduceOp.MEAN)
       output_context.set_last_step_output(
-          name="cross_replica_loss_noagg",
+          name="cross_replica_loss_not_reduced",
           output=loss)
       return distribution.group(train_op)

@@ -425,16 +425,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def run_step():
       initial_loss = lambda: constant_op.constant(1e7)
-      # Initial values corresponding to aggregated losses are just single
-      # tensors. But for non aggregated losses, we need to have initial
+      # Initial values corresponding to reduced losses are just single
+      # tensors. But for non reduced losses, we need to have initial
       # values that are of the same structure as non reduced losses. In
       # MirroredStrategy, this will be a list of losses, in TPUStrategy
       # it will be single tensor. Using `broadcast` followed by `unwrap`
       # gives us the desired initial value structure.
       initial_loop_values = {
-          "replica_loss_agg": initial_loss(),
-          "cross_replica_loss_agg": initial_loss(),
-          "cross_replica_loss_noagg":
+          "replica_loss_reduced": initial_loss(),
+          "cross_replica_loss_reduced": initial_loss(),
+          "cross_replica_loss_not_reduced":
               distribution.unwrap(distribution.broadcast(initial_loss()))
       }
       ctx = distribution.run_steps_on_dataset(

@@ -444,17 +444,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["replica_loss_agg"],
-          aggregated=True, distribution=distribution)
+          loss_output=ctx.last_step_outputs["replica_loss_reduced"],
+          reduced=True, distribution=distribution)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["cross_replica_loss_agg"],
-          aggregated=True, distribution=distribution)
+          loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
+          reduced=True, distribution=distribution)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"],
-          aggregated=False, distribution=distribution)
+          loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
+          reduced=False, distribution=distribution)
-      return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"])
+      return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

     self.evaluate(distribution.initialize())
     if not context.executing_eagerly():

@@ -479,17 +479,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
     self.assertTrue(error_is_not_increasing)

-  def _verify_loss_output(self, initial_loss, loss_output, aggregated,
+  def _verify_loss_output(self, initial_loss, loss_output, reduced,
                           distribution):
-    if not aggregated:
-      self.assertEqual(distribution.num_replicas_in_sync,
-                       len(distribution.unwrap(loss_output)))
+    if not reduced:
+      self.assertLen(distribution.unwrap(loss_output),
+                     distribution.num_replicas_in_sync)
       loss_output = distribution.reduce(
-          aggregation=reduce_util.ReduceOp.MEAN,
-          value=loss_output, destinations="/device:CPU:0")
+          reduce_util.ReduceOp.MEAN, loss_output, destinations="/device:CPU:0")

     unwrapped_output = distribution.unwrap(loss_output)
-    self.assertEqual(1, len(unwrapped_output))
+    self.assertLen(unwrapped_output, 1)
     loss_tensor = unwrapped_output[0]
     self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
     self.assertEqual(initial_loss.shape, loss_tensor.shape)

View File

@@ -36,7 +36,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib

@@ -532,11 +531,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, wrap them in a Mirrored
+      # For outputs that have already been reduced, wrap them in a Mirrored
       # container, else in a PerReplica container.
-      if aggregation is variables_lib.VariableAggregation.NONE:
+      if reduce_op is None:
         last_step_tensor_outputs_dict[name] = values.regroup(
             {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
       else:

View File

@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest

@@ -360,14 +359,14 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, take the first value
+      # For outputs that have already been reduced, take the first value
       # from the list as each value should be the same. Else return the full
       # list of values.
-      # TODO(josh11b): If aggregation is NONE, we should return a PerReplica
+      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
       # value.
-      if aggregation is not variables_lib.VariableAggregation.NONE:
+      if reduce_op is not None:
         # TODO(priyag): Should this return the element or a list with 1 element
         last_step_tensor_outputs_dict[name] = output[0]
     ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

View File

@@ -41,7 +41,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context

@@ -1508,7 +1507,7 @@ class MultiStepContext(object):
       A context object.
     """
     self._last_step_outputs = {}
-    self._last_step_outputs_aggregations = {}
+    self._last_step_outputs_reduce_ops = {}
     self._non_tensor_outputs = {}

   @property

@@ -1518,8 +1517,8 @@ class MultiStepContext(object):
     Keys in the dictionary are names of tensors to be captured, as specified
     when `set_last_step_output` is called.
     Values in the dictionary are the tensors themselves. If
-    `set_last_step_output` was called with an `aggregation` for this output,
-    then the value is the aggregated value.
+    `set_last_step_output` was called with a `reduce_op` for this output,
+    then the value is the reduced value.

     Returns:
       A dictionary with last step outputs.

@@ -1532,8 +1531,7 @@ class MultiStepContext(object):
       raise ValueError("Need a dictionary to set last_step_outputs.")
     self._last_step_outputs = outputs

-  def set_last_step_output(self, name, output,
-                           aggregation=variables_lib.VariableAggregation.NONE):
+  def set_last_step_output(self, name, output, reduce_op=None):
     """Set `output` with `name` to be outputted from the last step.

     Args:

@@ -1541,37 +1539,35 @@ class MultiStepContext(object):
         name.
       output: The tensors that should be outputted with `name`. See below for
         actual types supported.
-      aggregation: Aggregation method to use to aggregate outputs from multiple
+      reduce_op: Reduction method to use to reduce outputs from multiple
         replicas. Required if `set_last_step_output` is called in a replica
         context. Optional in cross_replica_context.
-        When present, the outputs from all the replicas are aggregated using the
+        When present, the outputs from all the replicas are reduced using the
         current distribution strategy's `reduce` method. Hence, the type of
         `output` must be what's supported by the corresponding `reduce` method.
-        For e.g. if using MirroredStrategy and aggregation is set, output
+        For e.g. if using MirroredStrategy and reduction is set, output
         must be a `PerReplica` value.
-        The aggregation method is also recorded in a dictionary
-        `_last_step_outputs_aggregations` for later interpreting of the
+        The reduce method is also recorded in a dictionary
+        `_last_step_outputs_reduce_ops` for later interpreting of the
         outputs as already reduced or not.
-        # TODO(priyag): Change aggregation type used here.
     """
     if distribution_strategy_context.get_cross_replica_context():
-      self._last_step_outputs_aggregations[name] = aggregation
-      if aggregation is variables_lib.VariableAggregation.NONE:
+      self._last_step_outputs_reduce_ops[name] = reduce_op
+      if reduce_op is None:
         self._last_step_outputs[name] = output
       else:
         distribution = distribution_strategy_context.get_distribution_strategy()
         self._last_step_outputs[name] = distribution.reduce(
-            aggregation, output, destinations="/device:CPU:0")
+            reduce_op, output, destinations="/device:CPU:0")
     else:
-      assert aggregation is not variables_lib.VariableAggregation.NONE
+      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
-            aggregation, value, destinations="/device:CPU:0")
+            reduce_op, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
-        self._last_step_outputs_aggregations[name] = aggregation
+        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, output)

@@ -1588,7 +1584,7 @@ class MultiStepContext(object):
     else:
       def merge_fn(distribution, value):
         # NOTE(priyag): For non tensor outputs, we simply return all the values
-        # in a list as aggregation doesn't make sense on non tensors.
+        # in a list as reduction doesn't make sense on non tensors.
         self._non_tensor_outputs[name] = distribution.unwrap(value)
       distribution_strategy_context.get_replica_context().merge_call(
           merge_fn, output)
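As a rough, standalone illustration of the bookkeeping changed above (a toy sketch, not the TensorFlow class: only the cross-replica branch is modeled, plain strings stand in for reduce_util.ReduceOp, and _reduce stands in for the strategy's reduce method):

class ToyMultiStepContext(object):
  """Toy model of the cross-replica branch of set_last_step_output above."""

  def __init__(self):
    self.last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}

  def _reduce(self, reduce_op, per_replica_values):
    # Stand-in for distribution.reduce(reduce_op, output, destinations=...).
    total = sum(per_replica_values)
    return total / len(per_replica_values) if reduce_op == "MEAN" else total

  def set_last_step_output(self, name, output, reduce_op=None):
    # Record the reduce op so later code can tell whether the stored value
    # is already reduced (a single value) or still one value per replica.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
      self.last_step_outputs[name] = output
    else:
      self.last_step_outputs[name] = self._reduce(reduce_op, output)


ctx = ToyMultiStepContext()
ctx.set_last_step_output("loss_reduced", [1.0, 3.0], reduce_op="MEAN")
ctx.set_last_step_output("loss_not_reduced", [1.0, 3.0])
print(ctx.last_step_outputs)
# {'loss_reduced': 2.0, 'loss_not_reduced': [1.0, 3.0]}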

View File

@@ -56,6 +56,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )

View File

@@ -24,6 +24,7 @@ import abc
 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes

@@ -848,8 +849,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
     """Scale loss for the number of replicas."""
     if scale_loss_by_num_replicas is None:
       scale_loss_by_num_replicas = (
-          distribute_lib.get_loss_reduction() == variable_scope
-          .VariableAggregation.MEAN)
+          distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
     if scale_loss_by_num_replicas:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync

@@ -928,7 +928,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
   def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
     """`apply_gradients` for use with a `DistributionStrategy`."""
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)

View File

@@ -3476,6 +3476,7 @@ py_library(
         "@six_archive//:six",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:distribute_coordinator_context",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",

View File

@@ -146,6 +146,7 @@ py_library(
     deps = [
         ":backend",
         "//tensorflow/python/data",
+        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/keras/optimizer_v2",
         "//tensorflow/python/training/checkpointable:data_structures",
         "//tensorflow/tools/docs:doc_controls",

View File

@@ -22,6 +22,7 @@ from __future__ import print_function
 import enum
 import numpy as np

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -34,7 +35,6 @@ from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest

@@ -288,12 +288,12 @@ def _experimental_fit_loop(
   for label, output in zip(out_labels, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)

   # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
   # feed_dict, session kwargs, run options, run_metadata for now. These should

@@ -571,12 +571,12 @@ def _experimental_test_loop(model, iterator, verbose=0, steps=None,
   for label, output in zip(model.metrics_names, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)

   return combined_fn.updates_op

View File

@@ -27,6 +27,7 @@ py_library(
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )

View File

@@ -24,6 +24,7 @@ import abc
 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes

@@ -32,7 +33,6 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.ops import gradients
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribution_strategy_context

@@ -580,7 +580,7 @@ def merge_grads(grads_and_vars):
   def merge_grad_fn(strategy, grads_and_vars):
     reduced_grads = strategy.batch_reduce(
-        variable_scope.VariableAggregation.MEAN, grads_and_vars)
+        ds_reduce_util.ReduceOp.MEAN, grads_and_vars)
     return reduced_grads

   return distribution_strategy_context.get_replica_context().merge_call(

View File

@@ -71,11 +71,11 @@ class UpdateContext(object):
 def get_loss_reduction():
-  """Reduce `aggregation` corresponding to the last loss reduction."""
+  """Reduce op corresponding to the last loss reduction."""
   loss_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
   if loss_reduction == losses_impl.Reduction.SUM:
-    return variable_scope.VariableAggregation.SUM
-  return variable_scope.VariableAggregation.MEAN
+    return reduce_util.ReduceOp.SUM
+  return reduce_util.ReduceOp.MEAN


 # ------------------------------------------------------------------------------
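The optimizer hunks in this commit consume this return value: when get_loss_reduction reports ReduceOp.MEAN, the loss is scaled down by the number of replicas in sync before gradients are summed across replicas. Below is a standalone sketch of that decision; the 1/num_replicas scaling reflects my reading of the _scale_loss hunks in this commit, and every name here is a stand-in rather than the TensorFlow API.

import enum


class ReduceOp(enum.Enum):
  SUM = "SUM"
  MEAN = "MEAN"


def scale_loss(loss_value, loss_reduce_op, num_replicas_in_sync):
  # If the last loss reduction corresponds to ReduceOp.MEAN, divide the
  # per-replica loss by the number of replicas so that the cross-replica
  # SUM of gradients behaves like an average.
  if loss_reduce_op == ReduceOp.MEAN and num_replicas_in_sync > 1:
    return loss_value * (1.0 / num_replicas_in_sync)
  return loss_value


print(scale_loss(4.0, ReduceOp.MEAN, num_replicas_in_sync=4))  # 1.0
print(scale_loss(4.0, ReduceOp.SUM, num_replicas_in_sync=4))   # 4.0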

View File

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops

@@ -95,8 +96,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
     # In a replica context, we update variable using the mean of value across
     # replicas.
     def merge_fn(strategy, v, value):
-      value = strategy.reduce(
-          variable_scope.VariableAggregation.MEAN, value, v)
+      value = strategy.reduce(ds_reduce_util.ReduceOp.MEAN, value, v)
       return strategy.update(v, update_fn, value)

     return replica_context.merge_call(merge_fn, variable, value)
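The hunk above reduces the per-replica value to its mean with ReduceOp.MEAN before a single moving-average update is applied. A self-contained sketch of the same arithmetic with toy numbers (a plain exponential moving average; the zero_debias path and TensorFlow's actual update op are not modeled):

def mean_across_replicas(per_replica_values):
  # Stand-in for strategy.reduce(ds_reduce_util.ReduceOp.MEAN, value, v).
  return sum(per_replica_values) / len(per_replica_values)


def moving_average_update(variable, value, decay):
  # The variable moves toward the (already reduced) value by (1 - decay).
  return variable - (1.0 - decay) * (variable - value)


per_replica_value = [2.0, 4.0]                   # one statistic per replica
value = mean_across_replicas(per_replica_value)  # 3.0
print(moving_average_update(10.0, value, decay=0.5))  # 6.5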

View File

@@ -24,6 +24,7 @@ import abc
 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes

@@ -520,8 +521,7 @@ class Optimizer(
   @staticmethod
   def _scale_loss(loss_value):
-    if (distribute_lib.get_loss_reduction() ==
-        variable_scope.VariableAggregation.MEAN):
+    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync
       if num_replicas > 1:

@@ -658,10 +658,10 @@ class Optimizer(
     Returns:
       An `Operation` that applies the specified gradients across all
       replicas. If `global_step` was not None, that operation also
-      increments `global_step`.
+      increments `global_step`
     """
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)

     # Note that this is called in a cross-replica context.