Use new enum ReduceOp for all callers of strategy.reduce.

Also use it for loss reduction and for OutputContext, both of which previously used VariableAggregation.

PiperOrigin-RevId: 221558489
commit 35228fbf0c (parent ddfa238d6e)
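As background for the hunks below, here is a minimal sketch of the ReduceOp enum the commit message refers to. The real definition lives in tensorflow/python/distribute/reduce_util.py; only the two variants this change exercises are shown, and the body is an illustration rather than the actual source.

import enum

class ReduceOp(enum.Enum):
  """Sketch: how values coming from multiple replicas are combined."""
  SUM = "SUM"    # add the per-replica values
  MEAN = "MEAN"  # average the per-replica values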
@@ -403,9 +403,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 train_op = optimizer.minimize(loss_fn)
 loss = loss_fn()
 output_context.set_last_step_output(
-name="replica_loss_agg",
+name="replica_loss_reduced",
 output=loss,
-aggregation=variables_lib.VariableAggregation.MEAN)
+reduce_op=reduce_util.ReduceOp.MEAN)
 output_context.set_non_tensor_output(key1, value1)
 return (train_op, loss)

@@ -413,11 +413,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 (train_op, loss) = distribution.call_for_each_replica(
 model_fn, args=(output_context,) + inputs)
 output_context.set_last_step_output(
-name="cross_replica_loss_agg",
+name="cross_replica_loss_reduced",
 output=loss,
-aggregation=variables_lib.VariableAggregation.MEAN)
+reduce_op=reduce_util.ReduceOp.MEAN)
 output_context.set_last_step_output(
-name="cross_replica_loss_noagg",
+name="cross_replica_loss_not_reduced",
 output=loss)
 return distribution.group(train_op)

@@ -425,16 +425,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

 def run_step():
 initial_loss = lambda: constant_op.constant(1e7)
-# Initial values corresponding to aggregated losses are just single
-# tensors. But for non aggregated losses, we need to have initial
+# Initial values corresponding to reduced losses are just single
+# tensors. But for non reduced losses, we need to have initial
 # values that are of the same structure as non reduced losses. In
 # MirroredStrategy, this will be a list of losses, in TPUStrategy
 # it will be single tensor. Using `broadcast` followed by `unwrap`
 # gives us the desired initial value structure.
 initial_loop_values = {
-"replica_loss_agg": initial_loss(),
-"cross_replica_loss_agg": initial_loss(),
-"cross_replica_loss_noagg":
+"replica_loss_reduced": initial_loss(),
+"cross_replica_loss_reduced": initial_loss(),
+"cross_replica_loss_not_reduced":
 distribution.unwrap(distribution.broadcast(initial_loss()))
 }
 ctx = distribution.run_steps_on_dataset(
@@ -444,17 +444,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs)
 self._verify_loss_output(
 initial_loss(),
-loss_output=ctx.last_step_outputs["replica_loss_agg"],
-aggregated=True, distribution=distribution)
+loss_output=ctx.last_step_outputs["replica_loss_reduced"],
+reduced=True, distribution=distribution)
 self._verify_loss_output(
 initial_loss(),
-loss_output=ctx.last_step_outputs["cross_replica_loss_agg"],
-aggregated=True, distribution=distribution)
+loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
+reduced=True, distribution=distribution)
 self._verify_loss_output(
 initial_loss(),
-loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"],
-aggregated=False, distribution=distribution)
-return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"])
+loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
+reduced=False, distribution=distribution)
+return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

 self.evaluate(distribution.initialize())
 if not context.executing_eagerly():
@@ -479,17 +479,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
 self.assertTrue(error_is_not_increasing)

-def _verify_loss_output(self, initial_loss, loss_output, aggregated,
+def _verify_loss_output(self, initial_loss, loss_output, reduced,
 distribution):
-if not aggregated:
-self.assertEqual(distribution.num_replicas_in_sync,
-len(distribution.unwrap(loss_output)))
+if not reduced:
+self.assertLen(distribution.unwrap(loss_output),
+distribution.num_replicas_in_sync)
 loss_output = distribution.reduce(
-aggregation=reduce_util.ReduceOp.MEAN,
-value=loss_output, destinations="/device:CPU:0")
+reduce_util.ReduceOp.MEAN, loss_output, destinations="/device:CPU:0")

 unwrapped_output = distribution.unwrap(loss_output)
-self.assertEqual(1, len(unwrapped_output))
+self.assertLen(unwrapped_output, 1)
 loss_tensor = unwrapped_output[0]
 self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
 self.assertEqual(initial_loss.shape, loss_tensor.shape)
@@ -36,7 +36,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
@@ -532,11 +531,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
 last_step_tensor_outputs_dict = nest.pack_sequence_as(
 ctx.last_step_outputs, last_step_tensor_outputs)

-for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access
+for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
 output = last_step_tensor_outputs_dict[name]
-# For outputs that have already been aggregated, wrap them in a Mirrored
+# For outputs that have already been reduced, wrap them in a Mirrored
 # container, else in a PerReplica container.
-if aggregation is variables_lib.VariableAggregation.NONE:
+if reduce_op is None:
 last_step_tensor_outputs_dict[name] = values.regroup(
 {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
 else:
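An illustrative note, not part of the commit: the substitutions above and below imply a direct correspondence between the old VariableAggregation values and the new ReduceOp values. The helper below is hypothetical, written only to make that mapping explicit; it reuses the ReduceOp sketch shown after the commit message.

def to_reduce_op(aggregation_name):
  """Hypothetical mapping from VariableAggregation names to ReduceOp values."""
  mapping = {
      "NONE": None,           # VariableAggregation.NONE -> no reduce_op
      "SUM": ReduceOp.SUM,    # VariableAggregation.SUM  -> ReduceOp.SUM
      "MEAN": ReduceOp.MEAN,  # VariableAggregation.MEAN -> ReduceOp.MEAN
  }
  return mapping[aggregation_name]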
@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -360,14 +359,14 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
 last_step_tensor_outputs_dict = nest.pack_sequence_as(
 ctx.last_step_outputs, last_step_tensor_outputs)

-for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access
+for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
 output = last_step_tensor_outputs_dict[name]
-# For outputs that have already been aggregated, take the first value
+# For outputs that have already been reduced, take the first value
 # from the list as each value should be the same. Else return the full
 # list of values.
-# TODO(josh11b): If aggregation is NONE, we should return a PerReplica
+# TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
 # value.
-if aggregation is not variables_lib.VariableAggregation.NONE:
+if reduce_op is not None:
 # TODO(priyag): Should this return the element or a list with 1 element
 last_step_tensor_outputs_dict[name] = output[0]
 ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
@@ -41,7 +41,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context
@@ -1508,7 +1507,7 @@ class MultiStepContext(object):
 A context object.
 """
 self._last_step_outputs = {}
-self._last_step_outputs_aggregations = {}
+self._last_step_outputs_reduce_ops = {}
 self._non_tensor_outputs = {}

 @property
@@ -1518,8 +1517,8 @@ class MultiStepContext(object):
 Keys in the dictionary are names of tensors to be captured, as specified
 when `set_last_step_output` is called.
 Values in the dictionary are the tensors themselves. If
-`set_last_step_output` was called with an `aggregation` for this output,
-then the value is the aggregated value.
+`set_last_step_output` was called with a `reduce_op` for this output,
+then the value is the reduced value.

 Returns:
 A dictionary with last step outputs.
@@ -1532,8 +1531,7 @@ class MultiStepContext(object):
 raise ValueError("Need a dictionary to set last_step_outputs.")
 self._last_step_outputs = outputs

-def set_last_step_output(self, name, output,
-aggregation=variables_lib.VariableAggregation.NONE):
+def set_last_step_output(self, name, output, reduce_op=None):
 """Set `output` with `name` to be outputted from the last step.

 Args:
@@ -1541,37 +1539,35 @@ class MultiStepContext(object):
 name.
 output: The tensors that should be outputted with `name`. See below for
 actual types supported.
-aggregation: Aggregation method to use to aggregate outputs from multiple
+reduce_op: Reduction method to use to reduce outputs from multiple
 replicas. Required if `set_last_step_output` is called in a replica
 context. Optional in cross_replica_context.
-When present, the outputs from all the replicas are aggregated using the
+When present, the outputs from all the replicas are reduced using the
 current distribution strategy's `reduce` method. Hence, the type of
 `output` must be what's supported by the corresponding `reduce` method.
-For e.g. if using MirroredStrategy and aggregation is set, output
+For e.g. if using MirroredStrategy and reduction is set, output
 must be a `PerReplica` value.
-The aggregation method is also recorded in a dictionary
-`_last_step_outputs_aggregations` for later interpreting of the
+The reduce method is also recorded in a dictionary
+`_last_step_outputs_reduce_ops` for later interpreting of the
 outputs as already reduced or not.
-# TODO(priyag): Change aggregation type used here.
-
 """
 if distribution_strategy_context.get_cross_replica_context():
-self._last_step_outputs_aggregations[name] = aggregation
-if aggregation is variables_lib.VariableAggregation.NONE:
+self._last_step_outputs_reduce_ops[name] = reduce_op
+if reduce_op is None:
 self._last_step_outputs[name] = output
 else:
 distribution = distribution_strategy_context.get_distribution_strategy()
 self._last_step_outputs[name] = distribution.reduce(
-aggregation, output, destinations="/device:CPU:0")
+reduce_op, output, destinations="/device:CPU:0")
 else:
-assert aggregation is not variables_lib.VariableAggregation.NONE
+assert reduce_op is not None
 def merge_fn(distribution, value):
 self._last_step_outputs[name] = distribution.reduce(
-aggregation, value, destinations="/device:CPU:0")
+reduce_op, value, destinations="/device:CPU:0")
 # Setting this inside the `merge_fn` because all replicas share the same
 # context object, so it's more robust to set it only once (even if all
 # the replicas are trying to set the same value).
-self._last_step_outputs_aggregations[name] = aggregation
+self._last_step_outputs_reduce_ops[name] = reduce_op

 distribution_strategy_context.get_replica_context().merge_call(
 merge_fn, output)
@@ -1588,7 +1584,7 @@ class MultiStepContext(object):
 else:
 def merge_fn(distribution, value):
 # NOTE(priyag): For non tensor outputs, we simply return all the values
-# in a list as aggregation doesn't make sense on non tensors.
+# in a list as reduction doesn't make sense on non tensors.
 self._non_tensor_outputs[name] = distribution.unwrap(value)
 distribution_strategy_context.get_replica_context().merge_call(
 merge_fn, output)
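An illustrative note, not part of the commit: the MultiStepContext hunks above record a reduce_op per output name so that strategies can later tell reduced outputs apart from unreduced ones. The pure-Python stand-in below mimics only that bookkeeping; the class, its string reduce_op values, and the list-based "replicas" are invented for illustration, while the real code delegates to distribution.reduce(reduce_op, output, destinations=...).

class _StepOutputsSketch(object):
  """Stand-in for MultiStepContext's last-step-output bookkeeping."""

  def __init__(self):
    self.last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}

  def set_last_step_output(self, name, output, reduce_op=None):
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
      # No reduction requested: keep the per-replica values as-is.
      self.last_step_outputs[name] = output
    else:
      # Reduction requested: combine the per-replica values.
      combined = sum(output)
      if reduce_op == "MEAN":
        combined /= len(output)
      self.last_step_outputs[name] = combined


ctx = _StepOutputsSketch()
ctx.set_last_step_output("loss_not_reduced", [2.0, 4.0])
ctx.set_last_step_output("loss_reduced", [2.0, 4.0], reduce_op="MEAN")
assert ctx.last_step_outputs["loss_not_reduced"] == [2.0, 4.0]
assert ctx.last_step_outputs["loss_reduced"] == 3.0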
@@ -56,6 +56,7 @@ py_library(
 "//tensorflow/python:training",
 "//tensorflow/python:variable_scope",
 "//tensorflow/python:variables",
+"//tensorflow/python/distribute:reduce_util",
 ],
 )

@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -848,8 +849,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
 """Scale loss for the number of replicas."""
 if scale_loss_by_num_replicas is None:
 scale_loss_by_num_replicas = (
-distribute_lib.get_loss_reduction() == variable_scope
-.VariableAggregation.MEAN)
+distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
 if scale_loss_by_num_replicas:
 num_replicas = \
 distribute_ctx.get_distribution_strategy().num_replicas_in_sync
@@ -928,7 +928,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
 def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
 """`apply_gradients` for use with a `DistributionStrategy`."""
 reduced_grads = distribution.batch_reduce(
-variable_scope.VariableAggregation.SUM, grads_and_vars)
+ds_reduce_util.ReduceOp.SUM, grads_and_vars)
 var_list = [v for _, v in grads_and_vars]
 grads_and_vars = zip(reduced_grads, var_list)

@@ -3476,6 +3476,7 @@ py_library(
 "@six_archive//:six",
 "//tensorflow/core:protos_all_py",
 "//tensorflow/python/data/ops:dataset_ops",
+"//tensorflow/python/distribute:reduce_util",
 "//tensorflow/python/distribute:distribute_coordinator_context",
 "//tensorflow/python/eager:backprop",
 "//tensorflow/python/eager:context",
@@ -146,6 +146,7 @@ py_library(
 deps = [
 ":backend",
 "//tensorflow/python/data",
+"//tensorflow/python/distribute:reduce_util",
 "//tensorflow/python/keras/optimizer_v2",
 "//tensorflow/python/training/checkpointable:data_structures",
 "//tensorflow/tools/docs:doc_controls",
@@ -22,6 +22,7 @@ from __future__ import print_function
 import enum
 import numpy as np

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -34,7 +35,6 @@ from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
@@ -288,12 +288,12 @@ def _experimental_fit_loop(

 for label, output in zip(out_labels, combined_fn.outputs):
 if label == 'loss':
-aggregation = distribute_lib.get_loss_reduction()
+reduce_op = distribute_lib.get_loss_reduction()
 else:
-# We aggregate all other metrics using mean for now. This is temporary
+# We reduce all other metrics using mean for now. This is temporary
 # workaround until new metrics are in place.
-aggregation = variable_scope.VariableAggregation.MEAN
-ctx.set_last_step_output(label, output, aggregation)
+reduce_op = ds_reduce_util.ReduceOp.MEAN
+ctx.set_last_step_output(label, output, reduce_op)

 # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
 # feed_dict, session kwargs, run options, run_metadata for now. These should
@@ -571,12 +571,12 @@ def _experimental_test_loop(model, iterator, verbose=0, steps=None,

 for label, output in zip(model.metrics_names, combined_fn.outputs):
 if label == 'loss':
-aggregation = distribute_lib.get_loss_reduction()
+reduce_op = distribute_lib.get_loss_reduction()
 else:
-# We aggregate all other metrics using mean for now. This is temporary
+# We reduce all other metrics using mean for now. This is temporary
 # workaround until new metrics are in place.
-aggregation = variable_scope.VariableAggregation.MEAN
-ctx.set_last_step_output(label, output, aggregation)
+reduce_op = ds_reduce_util.ReduceOp.MEAN
+ctx.set_last_step_output(label, output, reduce_op)

 return combined_fn.updates_op

@@ -27,6 +27,7 @@ py_library(
 "//tensorflow/python:state_ops",
 "//tensorflow/python:variable_scope",
 "//tensorflow/python:variables",
+"//tensorflow/python/distribute:reduce_util",
 ],
 )

@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -32,7 +33,6 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.ops import gradients
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribution_strategy_context
@@ -580,7 +580,7 @@ def merge_grads(grads_and_vars):

 def merge_grad_fn(strategy, grads_and_vars):
 reduced_grads = strategy.batch_reduce(
-variable_scope.VariableAggregation.MEAN, grads_and_vars)
+ds_reduce_util.ReduceOp.MEAN, grads_and_vars)
 return reduced_grads

 return distribution_strategy_context.get_replica_context().merge_call(
@@ -71,11 +71,11 @@ class UpdateContext(object):


 def get_loss_reduction():
-"""Reduce `aggregation` corresponding to the last loss reduction."""
+"""Reduce op corresponding to the last loss reduction."""
 loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
 if loss_reduction == losses_impl.Reduction.SUM:
-return variable_scope.VariableAggregation.SUM
-return variable_scope.VariableAggregation.MEAN
+return reduce_util.ReduceOp.SUM
+return reduce_util.ReduceOp.MEAN


 # ------------------------------------------------------------------------------
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
@@ -95,8 +96,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
 # In a replica context, we update variable using the mean of value across
 # replicas.
 def merge_fn(strategy, v, value):
-value = strategy.reduce(
-variable_scope.VariableAggregation.MEAN, value, v)
+value = strategy.reduce(ds_reduce_util.ReduceOp.MEAN, value, v)
 return strategy.update(v, update_fn, value)

 return replica_context.merge_call(merge_fn, variable, value)
@@ -24,6 +24,7 @@ import abc

 import six

+from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -520,8 +521,7 @@ class Optimizer(

 @staticmethod
 def _scale_loss(loss_value):
-if (distribute_lib.get_loss_reduction() ==
-variable_scope.VariableAggregation.MEAN):
+if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
 num_replicas = \
 distribute_ctx.get_distribution_strategy().num_replicas_in_sync
 if num_replicas > 1:
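An illustrative aside, not part of the commit: _scale_loss above now keys off ReduceOp.MEAN. The pure-Python sketch below shows the scaling rule this check is generally paired with; the 1/N factor itself is not visible in this diff, so treat it as an assumption about the surrounding code.

def scale_loss(per_replica_losses, loss_reduction_is_mean):
  """Sketch: scale each replica's loss so a SUM of gradients matches the mean."""
  n = len(per_replica_losses)
  if loss_reduction_is_mean and n > 1:
    return [loss / n for loss in per_replica_losses]
  return list(per_replica_losses)

# With 2 replicas and a MEAN loss reduction, each replica's loss is halved.
assert scale_loss([2.0, 4.0], loss_reduction_is_mean=True) == [1.0, 2.0]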
@@ -658,10 +658,10 @@ class Optimizer(
 Returns:
 An `Operation` that applies the specified gradients across all
 replicas. If `global_step` was not None, that operation also
-increments `global_step`.
+increments `global_step`
 """
 reduced_grads = distribution.batch_reduce(
-variable_scope.VariableAggregation.SUM, grads_and_vars)
+ds_reduce_util.ReduceOp.SUM, grads_and_vars)
 var_list = [v for _, v in grads_and_vars]
 grads_and_vars = zip(reduced_grads, var_list)
 # Note that this is called in a cross-replica context.
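A final illustrative note, not part of the commit: batch_reduce(ReduceOp.SUM, grads_and_vars) above combines each gradient across replicas by summation, with the paired variable serving as the destination for the result. The stand-in below shows only that arithmetic with plain Python values; the variable names and the 2-replica setup are invented for illustration.

# Per-replica gradients for two variables on a hypothetical 2-replica setup.
per_replica_grads = {
    "var0": [1.0, 3.0],
    "var1": [0.5, 0.5],
}

# ReduceOp.SUM combines the per-replica values for each variable by addition.
summed = {name: sum(grads) for name, grads in per_replica_grads.items()}
assert summed == {"var0": 4.0, "var1": 1.0}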