Fix: model.add_loss(symbolic_tensor) should work in ambient eager.

PiperOrigin-RevId: 234325049

Commit 597eeaa754 (parent: bcbe9c76fe)
@@ -696,7 +696,12 @@ class Layer(trackable.Trackable):
       A list of tensors.
     """
     collected_losses = []
-    if context.executing_eagerly():
+
+    # If any eager losses are present, we assume the model to be part of an
+    # eager training loop (either a custom one or the one used when
+    # `run_eagerly=True`), and so we always return just the eager losses in that
+    # case.
+    if self._eager_losses:
       collected_losses.extend(self._eager_losses)
     else:
       collected_losses.extend(self._losses)
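For intuition, a minimal standalone sketch of the dispatch this hunk introduces (the function and names here are ours, not TensorFlow's): with eager execution on by default, `context.executing_eagerly()` is almost always true, so the presence of recorded eager losses, not the ambient mode, now selects which collection is returned.

# Toy model of the new `losses` dispatch. Under ambient eager execution,
# checking the execution mode can no longer distinguish an eager training
# loop from symbolic graph usage.
def collect_losses(eager_losses, symbolic_losses):
  # Eager losses only exist when the layer was actually called inside an
  # eager loop; their presence is the reliable signal for that case.
  if eager_losses:
    return list(eager_losses)
  return list(symbolic_losses)

print(collect_losses([], ["symbolic_loss"]))              # ['symbolic_loss']
print(collect_losses(["eager_loss"], ["symbolic_loss"]))  # ['eager_loss']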
@@ -727,6 +732,7 @@ class Layer(trackable.Trackable):
     Arguments:
       losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
         may also be zero-argument callables which create a loss tensor.
+        Other types of input are ignored.
       inputs: Ignored when executing eagerly. If anything other than None is
         passed, it signals the losses are conditional on some of the layer's
         inputs, and thus they should only be run where these inputs are
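As a usage sketch of the documented argument types, assuming the `tf.keras` API at this revision; `outputs.graph.as_default()` is our stand-in for the graph scope in which the symbolic loss op must be built:

import tensorflow as tf  # assumed: an eager-by-default build at this revision

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

with outputs.graph.as_default():  # build loss ops next to their inputs
  # A loss tensor built from the symbolic model outputs.
  model.add_loss(tf.reduce_sum(outputs))
  # A zero-argument callable that creates the loss tensor on demand.
  model.add_loss(lambda: tf.reduce_sum(outputs))
# Per the new docstring line, other input types (a plain int here) are
# silently ignored.
model.add_loss(1)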
@@ -752,10 +758,13 @@ class Layer(trackable.Trackable):
         self._callable_losses.append(
             functools.partial(_tag_unconditional, loss))
       else:
-        if context.executing_eagerly():
-          self._eager_losses.append(_tag_unconditional(loss))
-        else:
+        if not tensor_util.is_tensor(loss):
+          # Ignoring constant values as this does not affect the gradients.
+          return
+        if tf_utils.is_symbolic_tensor(loss):
           self._losses.append(_tag_unconditional(loss))
+        else:
+          self._eager_losses.append(_tag_unconditional(loss))

   @doc_controls.for_subclass_implementers
   def add_metric(self, value, aggregation=None, name=None):
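This is the core of the fix named in the commit title. A hedged sketch of the user-visible effect, assuming an eager-by-default build at this revision: functional-API tensors are symbolic even though eager execution is ambient, and `add_loss` now routes them by the tensor's own nature instead of by `context.executing_eagerly()`.

import tensorflow as tf  # assumed: an eager-by-default build at this revision

inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# `outputs` is a symbolic tensor despite ambient eager execution. The old
# `context.executing_eagerly()` branch misfiled such a loss into the eager
# list; with this change it lands in the symbolic collection instead.
with outputs.graph.as_default():
  model.add_loss(tf.reduce_mean(outputs))
assert model.losses  # visible to compile()/fit() as a symbolic loss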
@@ -531,7 +531,12 @@ class Network(base_layer.Layer):
   @property
   def _unfiltered_losses(self):
     losses = []
-    if context.executing_eagerly():
+
+    # If any eager losses are present, we assume the model to be part of an
+    # eager training loop (either a custom one or the one used when
+    # `run_eagerly=True`), and so we always return just the eager losses in that
+    # case.
+    if self._eager_losses:
       losses.extend(self._eager_losses)
     else:
       losses.extend(self._losses)
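The Network property mirrors the Layer change above. For the eager half of the split, a hedged sketch of a subclassed layer in a custom eager loop (`ActivityRegularized` is our own example name, not part of the commit):

import tensorflow as tf  # assumed: an eager-by-default build at this revision

class ActivityRegularized(tf.keras.layers.Layer):

  def call(self, inputs):
    # Under ambient eager execution this creates an EagerTensor loss per
    # call; the new dispatch files it under the eager-loss collection.
    self.add_loss(tf.reduce_sum(tf.abs(inputs)))
    return inputs

layer = ActivityRegularized()
_ = layer(tf.ones((2, 3)))
# Because eager losses are present, `layer.losses` returns just them,
# which is what a custom eager training loop should consume.
print(layer.losses)  # e.g. [<tf.Tensor: ... numpy=6.0>]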
@@ -956,6 +956,44 @@ class TrainingTest(keras_parameterized.TestCase):
         callbacks=[val_counter])
     self.assertEqual(val_counter.val_runs, expected_runs)

+  @keras_parameterized.run_all_keras_modes
+  def test_add_loss_correctness(self):
+    if testing_utils.should_run_eagerly():
+      self.skipTest('b/124303407')
+
+    class Bias(keras.layers.Layer):
+
+      def build(self, input_shape):
+        self.bias = self.add_variable('bias', (1,), initializer='zeros')
+
+      def call(self, inputs):
+        return inputs + self.bias
+
+    inputs = keras.Input(shape=(1,))
+    outputs = Bias()(inputs)
+    model = keras.Model(inputs, outputs)
+    targets = keras.Input(shape=(1,))
+
+    model.add_loss(
+        math_ops.reduce_mean(
+            keras.losses.mean_absolute_error(targets, outputs)))
+
+    # If we want to use the loss class instance as shown below, we will need to
+    # add graph scope as the reduction logic involves some eager mode checks.
+    with keras.backend.get_graph().as_default():
+      model.add_loss(keras.losses.MeanAbsoluteError()(targets, outputs))
+
+    model.compile(
+        keras.optimizer_v2.gradient_descent.SGD(0.033333),
+        loss=keras.losses.MeanAbsoluteError(),
+        target_tensors=[targets],
+        run_eagerly=testing_utils.should_run_eagerly())
+
+    x = np.array([[0.], [1.], [2.]])
+    y = np.array([[0.5], [2.], [3.5]])
+    history = model.fit(x, y, batch_size=3, epochs=5)
+    self.assertAllClose(history.history['loss'], [3., 2.7, 2.4, 2.1, 1.8], 1e-3)
+
+
 class TestExceptionsAndWarnings(keras_parameterized.TestCase):

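Why the expected history is [3.0, 2.7, 2.4, 2.1, 1.8]: the model minimizes three identical MAE terms (the two added losses plus the compile loss), each equal to mean(|y - x - b|) with the bias b starting at zero; one full-batch SGD step per epoch moves b by 3 * 0.033333 ≈ 0.1, so the total loss falls by 0.3 per epoch. A quick NumPy check of that arithmetic, independent of TensorFlow:

import numpy as np

x = np.array([0., 1., 2.])
y = np.array([0.5, 2., 3.5])
b = 0.0
for _ in range(5):
  residual = y - (x + b)                  # per-sample error of the Bias layer
  loss = 3 * np.mean(np.abs(residual))    # three identical MAE terms
  print(round(loss, 1))                   # 3.0, 2.7, 2.4, 2.1, 1.8
  grad = 3 * np.mean(-np.sign(residual))  # d(loss)/db = -3 while b < 0.5
  b -= 0.033333 * grad                    # SGD step: b grows by ~0.1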
@@ -558,10 +558,15 @@ class BidirectionalTest(test.TestCase):
     assert len(layer.losses) == 4
     assert len(layer.get_losses_for(None)) == 4
     assert not layer.get_losses_for(x)
+
+    # Create a constant tensor that is not conditional on the inputs.
+    with keras.backend.get_graph().as_default():
+      const_tensor = constant_op.constant(1)
+
     layer.forward_layer.add_loss(x_reachable_loss, inputs=x)
-    layer.forward_layer.add_loss(1, inputs=None)
+    layer.forward_layer.add_loss(const_tensor, inputs=None)
     layer.backward_layer.add_loss(x_reachable_loss, inputs=x)
-    layer.backward_layer.add_loss(1, inputs=None)
+    layer.backward_layer.add_loss(const_tensor, inputs=None)
     assert len(layer.losses) == 8
     assert len(layer.get_losses_for(None)) == 6
     assert len(layer.get_losses_for(x)) == 2
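The switch from `add_loss(1, ...)` to a real constant tensor follows from the new `tensor_util.is_tensor` guard: plain Python numbers are now dropped silently, so the test must register actual tensors to keep its loss counts at 8, 6 and 2. A hedged illustration, assuming the eager-by-default build at this revision:

import tensorflow as tf  # assumed: an eager-by-default build at this revision

layer = tf.keras.layers.Dense(1)

layer.add_loss(1)                # plain Python int: silently ignored now
assert len(layer.losses) == 0

layer.add_loss(tf.constant(1.))  # an actual (eager) tensor: registered
assert len(layer.losses) == 1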