diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 2fc669be2c8..46c913256a5 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2591,6 +2591,9 @@ class AddLoss(Layer): """ def __init__(self, unconditional, **kwargs): + # Pass autocast=False, as there is no reason to cast loss to a different + # dtype. + kwargs['autocast'] = False super(AddLoss, self).__init__(**kwargs) self.unconditional = unconditional diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index a64e0f68149..784d7b304dd 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -764,6 +764,17 @@ class KerasModelTest(keras_parameterized.TestCase): 'optimizer" must be an instance of '): model.compile(optimizers.SGD(1.), 'mse') + @test_util.run_in_graph_and_eager_modes + @testing_utils.enable_v2_dtype_behavior + def test_functional_model_loss_dtype(self): + with policy.policy_scope('float16'): + x = layers.Input(shape=(1,)) + y = AddLayer()(x) + model = models.Model(x, y) + model.add_loss(math_ops.cast(y, 'float32')) + # The loss should not be cast to the policy's dtype. + self.assertEqual(model.losses[0].dtype, 'float32') + @parameterized.named_parameters( { 'testcase_name': 'base',