Merge pull request #37919 from reedwm/none_grad_fix
2.2-rc2 cherry-pick request: Fix crash in Model.fit() if a gradient is None
Commit e6e5d6df2a
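For context, here is a minimal repro sketch of the failure this 2.2-rc2 cherry-pick addresses, written against the public tf.keras API rather than the internal test utilities used in the hunks below. The class name DenseWithUnusedWeight is invented for illustration; the pattern matches the tests added in this change: a weight created in build() but never used in the forward pass gets a None gradient, which previously crashed Model.fit().

```python
# Hedged repro sketch (not part of the diff). Assumes TensorFlow 2.x, eager mode.
import numpy as np
import tensorflow as tf


class DenseWithUnusedWeight(tf.keras.layers.Dense):
  """Dense layer with an extra weight that never affects the output."""

  def build(self, input_shape):
    super(DenseWithUnusedWeight, self).build(input_shape)
    # Never used in call(), so its gradient during fit() is None.
    self.unused_weight = self.add_weight('unused_weight', shape=(),
                                         initializer='ones')


model = tf.keras.Sequential([DenseWithUnusedWeight(4, input_shape=(4,))])
# clipnorm/clipvalue exercise the clipping code path patched below.
model.compile(tf.keras.optimizers.Adam(clipnorm=1.0, clipvalue=1.0), 'mse')
model.fit(np.random.normal(size=(64, 4)), np.random.normal(size=(64, 4)))
```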
@@ -804,6 +804,33 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
           atol=1e-4,
           rtol=1e-4)
 
+  @combinations.generate(all_strategy_combinations_plus_run_distributed())
+  def test_gradients_are_none(self, distribution):
+
+    if not context.executing_eagerly():
+      self.skipTest('None gradients are not supported in graph mode')
+
+    class DenseWithExtraWeight(keras.layers.Dense):
+
+      def build(self, input_shape):
+        # Gradients w.r.t. extra_weights are None
+        self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(),
+                                              initializer='ones')
+        super(DenseWithExtraWeight, self).build(input_shape)
+        self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(),
+                                              initializer='ones')
+
+    with distribution.scope():
+      model = keras.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
+      model.compile('adam', 'mse')
+
+    inputs = np.random.normal(size=(64, 4))
+    targets = np.random.normal(size=(64, 4))
+    old_kernel = model.get_weights()[1]
+    model.fit(inputs, targets)
+    new_kernel = model.get_weights()[1]
+    self.assertNotAllEqual(old_kernel, new_kernel)
+
 
 class TestDistributionStrategyWithDatasets(test.TestCase,
                                            parameterized.TestCase):
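The test above relies on a basic autodiff property: tf.GradientTape.gradient() returns None for variables that do not influence the target. A short standalone illustration of that behavior (not part of the diff):

```python
import tensorflow as tf

used = tf.Variable(2.0)
unused = tf.Variable(3.0)  # plays the role of extra_weight_1/_2 above

with tf.GradientTape() as tape:
  loss = used * used  # 'unused' never enters the computation

print(tape.gradient(loss, [used, unused]))  # [tf.Tensor(4.0, ...), None]
```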
@@ -1364,6 +1364,30 @@ class TrainingTest(keras_parameterized.TestCase):
     model.fit(x, y)
     self.assertEqual(model.optimizer.aggregate_gradients_called, True)
 
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_gradients_are_none(self):
+
+    class DenseWithExtraWeight(keras.layers.Dense):
+
+      def build(self, input_shape):
+        # Gradients w.r.t. extra_weights are None
+        self.extra_weight_1 = self.add_weight('extra_weight_1', shape=(),
+                                              initializer='ones')
+        super(DenseWithExtraWeight, self).build(input_shape)
+        self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(),
+                                              initializer='ones')
+
+    model = keras.models.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
+    # Test clipping can handle None gradients
+    opt = keras.optimizer_v2.adam.Adam(clipnorm=1.0, clipvalue=1.0)
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    inputs = np.random.normal(size=(64, 4))
+    targets = np.random.normal(size=(64, 4))
+    old_kernel = model.get_weights()[1]
+    model.fit(inputs, targets)
+    new_kernel = model.get_weights()[1]
+    self.assertNotAllEqual(old_kernel, new_kernel)
+
 
 class TestExceptionsAndWarnings(keras_parameterized.TestCase):
@@ -344,15 +344,16 @@ class OptimizerV2(trackable.Trackable):
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+      grads = [None if g is None else clip_ops.clip_by_norm(g, self.clipnorm)
+               for g in grads]
     if self.clipvalue is not None:
       if distribute_ctx.has_strategy():
         raise ValueError("Gradient clipping in the optimizer "
                          "(by setting clipnorm or clipvalue) is currently "
                          "unsupported when using a distribution strategy.")
+      v = self.clipvalue
       grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
+          None if g is None else clip_ops.clip_by_value(g, -v, v) for g in grads
       ]
     return grads
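The hunk above guards each element of grads before clipping. The same None-safe pattern can be reproduced with the public clip ops; the helper name clip_gradients_sketch below is hypothetical and only illustrates the list comprehensions in the hunk:

```python
import tensorflow as tf


def clip_gradients_sketch(grads, clipnorm=None, clipvalue=None):
  """Illustrative standalone version of the None-safe clipping above."""
  if clipnorm is not None:
    grads = [None if g is None else tf.clip_by_norm(g, clipnorm)
             for g in grads]
  if clipvalue is not None:
    grads = [None if g is None else tf.clip_by_value(g, -clipvalue, clipvalue)
             for g in grads]
  return grads


print(clip_gradients_sketch([tf.constant([3.0, 4.0]), None], clipnorm=1.0))
# [<tf.Tensor ... [0.6, 0.8]>, None] -- the None entry passes through untouched
```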
@@ -521,6 +522,7 @@ class OptimizerV2(trackable.Trackable):
       A list of all-reduced gradients.
     """
     grads_and_vars = list(grads_and_vars)
+    filtered_grads_and_vars = _filter_grads(grads_and_vars)
     def all_reduce_fn(distribution, grads_and_vars):
       return distribution.extended.batch_reduce_to(
           ds_reduce_util.ReduceOp.SUM, grads_and_vars)
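_filter_grads is a pre-existing helper in the OptimizerV2 module; the hunk above only adds a call to it. What the diff itself implies is that it drops (gradient, variable) pairs whose gradient is None so that only real gradients reach the all-reduce. A simplified stand-in under that assumption (the real helper does more, e.g. validation and warnings):

```python
def filter_grads_sketch(grads_and_vars):
  """Assumed behavior: keep only pairs whose gradient is not None."""
  return [(g, v) for g, v in grads_and_vars if g is not None]


print(filter_grads_sketch([(1.0, 'a'), (None, 'b'), (3.0, 'c')]))
# [(1.0, 'a'), (3.0, 'c')]
```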
@@ -529,9 +531,22 @@ class OptimizerV2(trackable.Trackable):
     # replica context.
     # TODO(b/150507409): Do not switch to a cross-replica context once the bug
     # is fixed.
-    if grads_and_vars:
-      return distribute_ctx.get_replica_context().merge_call(
-          all_reduce_fn, args=(grads_and_vars,))
+    if filtered_grads_and_vars:
+      reduced = distribute_ctx.get_replica_context().merge_call(
+          all_reduce_fn, args=(filtered_grads_and_vars,))
+    else:
+      reduced = []
+    # Copy 'reduced' but add None gradients back in
+    reduced_with_nones = []
+    reduced_pos = 0
+    for g, _ in grads_and_vars:
+      if g is None:
+        reduced_with_nones.append(None)
+      else:
+        reduced_with_nones.append(reduced[reduced_pos])
+        reduced_pos += 1
+    assert reduced_pos == len(reduced), "Failed to add all gradients"
+    return reduced_with_nones
 
   def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""