Add single tensor gradient support for DynamicLossScale
PiperOrigin-RevId: 257209230
commit 45925ad3d0
parent 67a48c79a3
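In short: `LossScale.update` previously expected a list of unscaled gradients; after this change the input is flattened with `nest`, so a single gradient tensor (or any nested structure) is also accepted. A minimal sketch of the resulting call pattern, assuming the `tf.train.experimental.DynamicLossScale` export available in this era of TensorFlow (the gradient values are illustrative, not taken from the commit):

import tensorflow as tf

# DynamicLossScale as publicly exported around TF 1.14 / 2.0; this is the
# class the diff below modifies.
loss_scale = tf.train.experimental.DynamicLossScale()

# Previously supported: a list of unscaled gradients.
update_op, should_apply = loss_scale.update([tf.constant(4.0)])

# Newly supported: a single gradient tensor, with no list wrapping.
update_op, should_apply = loss_scale.update(tf.constant(4.0))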
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -93,10 +94,10 @@ class LossScale(trackable.Trackable):
       cross-replica context.
 
     Args:
-      grads: A list of unscaled gradients, each which is the gradient of the
-        loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function. 'None'
-        gradients are accepted, and are ignored.
+      grads: A nested structure of unscaled gradients, each which is the
+        gradient of the loss with respect to a weight. The gradients should have
+        already been divided by the loss scale being before passed to this
+        function. 'None' gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -328,6 +329,7 @@ class DynamicLossScale(LossScale):
 
   def update(self, grads):
     """Updates loss scale based on if gradients are finite in current step."""
+    grads = nest.flatten(grads)
     if distribution_strategy_context.has_strategy():
       distribution = distribution_strategy_context.get_cross_replica_context()
 
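The one-line `nest.flatten` call added above is what makes single-tensor input work: whatever structure the caller passes is normalized to a flat Python list before the finiteness check. A standalone sketch of that normalization, using the public `tf.nest` counterpart of the internal `nest` module imported in this commit (example values are illustrative):

import tensorflow as tf

single = tf.constant(4.0)
nested = {'w': tf.constant(1.0), 'b': [tf.constant(2.0), None]}

# A lone tensor becomes a one-element list, so the rest of update() can keep
# iterating over a flat list of gradients.
print(tf.nest.flatten(single))   # -> [<tf.Tensor 4.0>]

# Nested structures are flattened deterministically (dict keys in sorted
# order); None entries are preserved here and ignored by the finiteness check.
print(tf.nest.flatten(nested))   # -> [<tf.Tensor 2.0>, None, <tf.Tensor 1.0>]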
@@ -264,6 +264,15 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
     self._test_helper(inputs, expected_outputs, init_loss_scale)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_single_tensor_gradient(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale()
+      grad = constant_op.constant(4.0)
+      _, should_apply = loss_scale.update(grad)
+      self.assertTrue(self.evaluate(should_apply))
+
   @test_util.run_in_graph_and_eager_modes
   def test_serialization(self):
     loss_scale = loss_scale_module.DynamicLossScale(