Use new enum ReduceOp for all callers of strategy.reduce.
Also use it for loss reduction and for OutputContext, both of which previously used VariableAggregation.
PiperOrigin-RevId: 221558489
commit 35228fbf0c
parent ddfa238d6e
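For illustration, a minimal sketch of the call-site change this commit makes. The enums and ToyStrategy below are simplified stand-ins written for this note, not TensorFlow's real classes; only the shape of the `reduce(reduce_op, value, destinations)` call mirrors the diff:

import enum


class VariableAggregation(enum.Enum):
  # Stand-in for the old enum that callers previously passed to reduce().
  NONE = 0
  SUM = 1
  MEAN = 2


class ReduceOp(enum.Enum):
  # Stand-in for the new reduce_util.ReduceOp enum introduced here.
  SUM = "SUM"
  MEAN = "MEAN"


class ToyStrategy(object):
  """Illustrative strategy whose reduce() now takes a ReduceOp."""

  def reduce(self, reduce_op, per_replica_values, destinations=None):
    # Sum the per-replica values; average them when MEAN is requested.
    total = sum(per_replica_values)
    if reduce_op is ReduceOp.MEAN:
      return total / len(per_replica_values)
    return total


strategy = ToyStrategy()
# Old-style callers would have passed VariableAggregation.MEAN here;
# new callers pass ReduceOp.MEAN instead.
print(strategy.reduce(ReduceOp.MEAN, [1.0, 3.0], destinations="/device:CPU:0"))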
@@ -403,9 +403,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       train_op = optimizer.minimize(loss_fn)
       loss = loss_fn()
       output_context.set_last_step_output(
-          name="replica_loss_agg",
+          name="replica_loss_reduced",
           output=loss,
-          aggregation=variables_lib.VariableAggregation.MEAN)
+          reduce_op=reduce_util.ReduceOp.MEAN)
       output_context.set_non_tensor_output(key1, value1)
       return (train_op, loss)
 
@@ -413,11 +413,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       (train_op, loss) = distribution.call_for_each_replica(
           model_fn, args=(output_context,) + inputs)
       output_context.set_last_step_output(
-          name="cross_replica_loss_agg",
+          name="cross_replica_loss_reduced",
           output=loss,
-          aggregation=variables_lib.VariableAggregation.MEAN)
+          reduce_op=reduce_util.ReduceOp.MEAN)
       output_context.set_last_step_output(
-          name="cross_replica_loss_noagg",
+          name="cross_replica_loss_not_reduced",
           output=loss)
       return distribution.group(train_op)
 
@@ -425,16 +425,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 
     def run_step():
       initial_loss = lambda: constant_op.constant(1e7)
-      # Initial values corresponding to aggregated losses are just single
-      # tensors. But for non aggregated losses, we need to have initial
+      # Initial values corresponding to reduced losses are just single
+      # tensors. But for non reduced losses, we need to have initial
       # values that are of the same structure as non reduced losses. In
       # MirroredStrategy, this will be a list of losses, in TPUStrategy
       # it will be single tensor. Using `broadcast` followed by `unwrap`
       # gives us the desired initial value structure.
       initial_loop_values = {
-          "replica_loss_agg": initial_loss(),
-          "cross_replica_loss_agg": initial_loss(),
-          "cross_replica_loss_noagg":
+          "replica_loss_reduced": initial_loss(),
+          "cross_replica_loss_reduced": initial_loss(),
+          "cross_replica_loss_not_reduced":
               distribution.unwrap(distribution.broadcast(initial_loss()))
       }
       ctx = distribution.run_steps_on_dataset(
@@ -444,17 +444,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["replica_loss_agg"],
-          aggregated=True, distribution=distribution)
+          loss_output=ctx.last_step_outputs["replica_loss_reduced"],
+          reduced=True, distribution=distribution)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["cross_replica_loss_agg"],
-          aggregated=True, distribution=distribution)
+          loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
+          reduced=True, distribution=distribution)
       self._verify_loss_output(
           initial_loss(),
-          loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"],
-          aggregated=False, distribution=distribution)
-      return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"])
+          loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
+          reduced=False, distribution=distribution)
+      return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])
 
     self.evaluate(distribution.initialize())
     if not context.executing_eagerly():
@@ -479,17 +479,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
     self.assertTrue(error_is_not_increasing)
 
-  def _verify_loss_output(self, initial_loss, loss_output, aggregated,
+  def _verify_loss_output(self, initial_loss, loss_output, reduced,
                           distribution):
-    if not aggregated:
-      self.assertEqual(distribution.num_replicas_in_sync,
-                       len(distribution.unwrap(loss_output)))
+    if not reduced:
+      self.assertLen(distribution.unwrap(loss_output),
+                     distribution.num_replicas_in_sync)
       loss_output = distribution.reduce(
-          aggregation=reduce_util.ReduceOp.MEAN,
-          value=loss_output, destinations="/device:CPU:0")
+          reduce_util.ReduceOp.MEAN, loss_output, destinations="/device:CPU:0")
 
     unwrapped_output = distribution.unwrap(loss_output)
-    self.assertEqual(1, len(unwrapped_output))
+    self.assertLen(unwrapped_output, 1)
     loss_tensor = unwrapped_output[0]
     self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
     self.assertEqual(initial_loss.shape, loss_tensor.shape)
@@ -36,7 +36,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
@@ -532,11 +531,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
 
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, wrap them in a Mirrored
+      # For outputs that have already been reduced, wrap them in a Mirrored
       # container, else in a PerReplica container.
-      if aggregation is variables_lib.VariableAggregation.NONE:
+      if reduce_op is None:
         last_step_tensor_outputs_dict[name] = values.regroup(
             {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
       else:
@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -360,14 +359,14 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
     last_step_tensor_outputs_dict = nest.pack_sequence_as(
         ctx.last_step_outputs, last_step_tensor_outputs)
 
-    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
       output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been aggregated, take the first value
+      # For outputs that have already been reduced, take the first value
       # from the list as each value should be the same. Else return the full
       # list of values.
-      # TODO(josh11b): If aggregation is NONE, we should return a PerReplica
+      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
       # value.
-      if aggregation is not variables_lib.VariableAggregation.NONE:
+      if reduce_op is not None:
         # TODO(priyag): Should this return the element or a list with 1 element
         last_step_tensor_outputs_dict[name] = output[0]
     ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
@@ -41,7 +41,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context
@@ -1508,7 +1507,7 @@ class MultiStepContext(object):
       A context object.
     """
     self._last_step_outputs = {}
-    self._last_step_outputs_aggregations = {}
+    self._last_step_outputs_reduce_ops = {}
     self._non_tensor_outputs = {}
 
   @property
@@ -1518,8 +1517,8 @@ class MultiStepContext(object):
     Keys in the dictionary are names of tensors to be captured, as specified
     when `set_last_step_output` is called.
     Values in the dictionary are the tensors themselves. If
-    `set_last_step_output` was called with an `aggregation` for this output,
-    then the value is the aggregated value.
+    `set_last_step_output` was called with a `reduce_op` for this output,
+    then the value is the reduced value.
 
     Returns:
       A dictionary with last step outputs.
@@ -1532,8 +1531,7 @@ class MultiStepContext(object):
       raise ValueError("Need a dictionary to set last_step_outputs.")
     self._last_step_outputs = outputs
 
-  def set_last_step_output(self, name, output,
-                           aggregation=variables_lib.VariableAggregation.NONE):
+  def set_last_step_output(self, name, output, reduce_op=None):
     """Set `output` with `name` to be outputted from the last step.
 
     Args:
@@ -1541,37 +1539,35 @@ class MultiStepContext(object):
         name.
       output: The tensors that should be outputted with `name`. See below for
         actual types supported.
-      aggregation: Aggregation method to use to aggregate outputs from multiple
+      reduce_op: Reduction method to use to reduce outputs from multiple
         replicas. Required if `set_last_step_output` is called in a replica
         context. Optional in cross_replica_context.
-        When present, the outputs from all the replicas are aggregated using the
+        When present, the outputs from all the replicas are reduced using the
         current distribution strategy's `reduce` method. Hence, the type of
         `output` must be what's supported by the corresponding `reduce` method.
-        For e.g. if using MirroredStrategy and aggregation is set, output
+        For e.g. if using MirroredStrategy and reduction is set, output
         must be a `PerReplica` value.
-        The aggregation method is also recorded in a dictionary
-        `_last_step_outputs_aggregations` for later interpreting of the
+        The reduce method is also recorded in a dictionary
+        `_last_step_outputs_reduce_ops` for later interpreting of the
         outputs as already reduced or not.
-        # TODO(priyag): Change aggregation type used here.
-
     """
     if distribution_strategy_context.get_cross_replica_context():
-      self._last_step_outputs_aggregations[name] = aggregation
-      if aggregation is variables_lib.VariableAggregation.NONE:
+      self._last_step_outputs_reduce_ops[name] = reduce_op
+      if reduce_op is None:
         self._last_step_outputs[name] = output
       else:
         distribution = distribution_strategy_context.get_distribution_strategy()
         self._last_step_outputs[name] = distribution.reduce(
-            aggregation, output, destinations="/device:CPU:0")
+            reduce_op, output, destinations="/device:CPU:0")
     else:
-      assert aggregation is not variables_lib.VariableAggregation.NONE
+      assert reduce_op is not None
       def merge_fn(distribution, value):
         self._last_step_outputs[name] = distribution.reduce(
-            aggregation, value, destinations="/device:CPU:0")
+            reduce_op, value, destinations="/device:CPU:0")
         # Setting this inside the `merge_fn` because all replicas share the same
         # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
-        self._last_step_outputs_aggregations[name] = aggregation
+        self._last_step_outputs_reduce_ops[name] = reduce_op
 
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, output)
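A rough usage sketch of the new `reduce_op` parameter described in the docstring above. `ToyContext` and `ToyStrategy` are simplified stand-ins for MultiStepContext and a DistributionStrategy, written only for illustration: passing `reduce_op=None` keeps the per-replica structure, while passing a reduce op asks the strategy to reduce the output first.

import enum


class ReduceOp(enum.Enum):  # stand-in for reduce_util.ReduceOp
  SUM = "SUM"
  MEAN = "MEAN"


class ToyStrategy(object):
  def reduce(self, reduce_op, per_replica_values, destinations=None):
    total = sum(per_replica_values)
    return total / len(per_replica_values) if reduce_op is ReduceOp.MEAN else total


class ToyContext(object):
  """Simplified sketch of the reduce_op bookkeeping shown in the hunk above."""

  def __init__(self, strategy):
    self._strategy = strategy
    self.last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}

  def set_last_step_output(self, name, output, reduce_op=None):
    # Record how each output was (or was not) reduced.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
      # Keep the full per-replica structure.
      self.last_step_outputs[name] = output
    else:
      self.last_step_outputs[name] = self._strategy.reduce(reduce_op, output)


ctx = ToyContext(ToyStrategy())
ctx.set_last_step_output("loss_reduced", [0.5, 1.5], reduce_op=ReduceOp.MEAN)
ctx.set_last_step_output("loss_not_reduced", [0.5, 1.5])  # reduce_op=None
print(ctx.last_step_outputs)  # {'loss_reduced': 1.0, 'loss_not_reduced': [0.5, 1.5]}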
@@ -1588,7 +1584,7 @@ class MultiStepContext(object):
     else:
       def merge_fn(distribution, value):
         # NOTE(priyag): For non tensor outputs, we simply return all the values
-        # in a list as aggregation doesn't make sense on non tensors.
+        # in a list as reduction doesn't make sense on non tensors.
         self._non_tensor_outputs[name] = distribution.unwrap(value)
       distribution_strategy_context.get_replica_context().merge_call(
           merge_fn, output)
@@ -56,6 +56,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )
 
@@ -24,6 +24,7 @@ import abc
 
 import six
 
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -848,8 +849,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
     """Scale loss for the number of replicas."""
     if scale_loss_by_num_replicas is None:
       scale_loss_by_num_replicas = (
-          distribute_lib.get_loss_reduction() == variable_scope
-          .VariableAggregation.MEAN)
+          distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
     if scale_loss_by_num_replicas:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync
@@ -928,7 +928,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
   def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
     """`apply_gradients` for use with a `DistributionStrategy`."""
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
 
@@ -3476,6 +3476,7 @@ py_library(
         "@six_archive//:six",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:distribute_coordinator_context",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
@@ -146,6 +146,7 @@ py_library(
     deps = [
         ":backend",
         "//tensorflow/python/data",
+        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/training/checkpointable:data_structures",
        "//tensorflow/tools/docs:doc_controls",
@@ -22,6 +22,7 @@ from __future__ import print_function
 import enum
 import numpy as np
 
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -34,7 +35,6 @@ from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -288,12 +288,12 @@ def _experimental_fit_loop(
 
   for label, output in zip(out_labels, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)
 
   # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
   # feed_dict, session kwargs, run options, run_metadata for now. These should
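A tiny sketch of the per-output choice made in the loop above: the 'loss' output follows whatever reduce op `get_loss_reduction()` returns, while every other metric is reduced with MEAN as a temporary default. The helper and labels here are hypothetical, written only to illustrate the branch.

import enum


class ReduceOp(enum.Enum):  # stand-in for ds_reduce_util.ReduceOp
  SUM = "SUM"
  MEAN = "MEAN"


def pick_reduce_op(label, loss_reduce_op):
  # Loss follows the recorded loss reduction; other metrics default to MEAN.
  return loss_reduce_op if label == "loss" else ReduceOp.MEAN


for label in ("loss", "acc"):
  print(label, pick_reduce_op(label, ReduceOp.SUM))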
@@ -571,12 +571,12 @@ def _experimental_test_loop(model, iterator, verbose=0, steps=None,
 
   for label, output in zip(model.metrics_names, combined_fn.outputs):
     if label == 'loss':
-      aggregation = distribute_lib.get_loss_reduction()
+      reduce_op = distribute_lib.get_loss_reduction()
     else:
-      # We aggregate all other metrics using mean for now. This is temporary
+      # We reduce all other metrics using mean for now. This is temporary
       # workaround until new metrics are in place.
-      aggregation = variable_scope.VariableAggregation.MEAN
-    ctx.set_last_step_output(label, output, aggregation)
+      reduce_op = ds_reduce_util.ReduceOp.MEAN
+    ctx.set_last_step_output(label, output, reduce_op)
 
   return combined_fn.updates_op
 
@@ -27,6 +27,7 @@ py_library(
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:reduce_util",
     ],
 )
 
@@ -24,6 +24,7 @@ import abc
 
 import six
 
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -32,7 +33,6 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.ops import gradients
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribution_strategy_context
@@ -580,7 +580,7 @@ def merge_grads(grads_and_vars):
 
   def merge_grad_fn(strategy, grads_and_vars):
     reduced_grads = strategy.batch_reduce(
-        variable_scope.VariableAggregation.MEAN, grads_and_vars)
+        ds_reduce_util.ReduceOp.MEAN, grads_and_vars)
     return reduced_grads
 
   return distribution_strategy_context.get_replica_context().merge_call(
@@ -71,11 +71,11 @@ class UpdateContext(object):
 
 
 def get_loss_reduction():
-  """Reduce `aggregation` corresponding to the last loss reduction."""
+  """Reduce op corresponding to the last loss reduction."""
   loss_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
   if loss_reduction == losses_impl.Reduction.SUM:
-    return variable_scope.VariableAggregation.SUM
-  return variable_scope.VariableAggregation.MEAN
+    return reduce_util.ReduceOp.SUM
+  return reduce_util.ReduceOp.MEAN
 
 
 # ------------------------------------------------------------------------------
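The hunk above changes only the return type of `get_loss_reduction`: a loss `Reduction.SUM` maps to `ReduceOp.SUM`, and everything else maps to `ReduceOp.MEAN`. A self-contained sketch of that mapping, with stand-in enums in place of `losses_impl.Reduction` and `reduce_util.ReduceOp`:

import enum


class Reduction(enum.Enum):  # stand-in for losses_impl.Reduction
  SUM = "SUM"
  SUM_OVER_BATCH_SIZE = "SUM_OVER_BATCH_SIZE"


class ReduceOp(enum.Enum):  # stand-in for reduce_util.ReduceOp
  SUM = "SUM"
  MEAN = "MEAN"


def loss_reduction_to_reduce_op(last_loss_reduction):
  # Mirrors the branch in the diff: SUM stays SUM, everything else is MEAN.
  if last_loss_reduction is Reduction.SUM:
    return ReduceOp.SUM
  return ReduceOp.MEAN


assert loss_reduction_to_reduce_op(Reduction.SUM) is ReduceOp.SUM
assert loss_reduction_to_reduce_op(Reduction.SUM_OVER_BATCH_SIZE) is ReduceOp.MEAN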
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
@@ -95,8 +96,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
     # In a replica context, we update variable using the mean of value across
     # replicas.
     def merge_fn(strategy, v, value):
-      value = strategy.reduce(
-          variable_scope.VariableAggregation.MEAN, value, v)
+      value = strategy.reduce(ds_reduce_util.ReduceOp.MEAN, value, v)
       return strategy.update(v, update_fn, value)
 
     return replica_context.merge_call(merge_fn, variable, value)
@@ -24,6 +24,7 @@ import abc
 
 import six
 
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -520,8 +521,7 @@ class Optimizer(
 
   @staticmethod
   def _scale_loss(loss_value):
-    if (distribute_lib.get_loss_reduction() ==
-        variable_scope.VariableAggregation.MEAN):
+    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
       num_replicas = \
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync
       if num_replicas > 1:
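The branch above scales the loss only when the recorded loss reduction maps to MEAN. A small illustrative sketch of that rule follows; the helper and the numbers are made up for this note, not the optimizer's actual code: dividing each replica's loss by the replica count keeps the summed cross-replica gradient equivalent to an average.

import enum


class ReduceOp(enum.Enum):  # stand-in for ds_reduce_util.ReduceOp
  SUM = "SUM"
  MEAN = "MEAN"


def scale_loss(loss_value, loss_reduction, num_replicas):
  # Scale a per-replica loss the way the MEAN branch above does.
  if loss_reduction is ReduceOp.MEAN and num_replicas > 1:
    return loss_value * (1.0 / num_replicas)
  return loss_value


print(scale_loss(6.0, ReduceOp.MEAN, num_replicas=3))  # 2.0
print(scale_loss(6.0, ReduceOp.SUM, num_replicas=3))   # 6.0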
@@ -658,10 +658,10 @@ class Optimizer(
     Returns:
       An `Operation` that applies the specified gradients across all
       replicas. If `global_step` was not None, that operation also
-      increments `global_step`.
+      increments `global_step`
     """
     reduced_grads = distribution.batch_reduce(
-        variable_scope.VariableAggregation.SUM, grads_and_vars)
+        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
     # Note that this is called in a cross-replica context.