Do not cast inputs to AddLoss layers
This means tensors passed to Model.add_loss will no longer be cast to floatx.

PiperOrigin-RevId: 264287945
parent 75a9d99941
commit a3a9fd34e9
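In practice, the change looks like the sketch below: a loss tensor handed to Model.add_loss keeps the dtype it was given, even when the surrounding policy computes in float16. This is a minimal sketch, assuming a recent tf.keras (the mixed-precision policy API has moved between TF releases); the Dense model here is purely illustrative.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision, models

# Layers compute in float16 under this policy (variables stay float32).
mixed_precision.set_global_policy('mixed_float16')

inputs = layers.Input(shape=(4,))
outputs = layers.Dense(2)(inputs)   # output dtype: float16
model = models.Model(inputs, outputs)

# The loss is deliberately computed in float32. Before this change it
# would have been cast to the AddLoss layer's compute dtype (float16
# here); now add_loss leaves its dtype alone.
model.add_loss(tf.reduce_mean(tf.cast(outputs, 'float32')))

model(tf.ones((1, 4)))              # run once so the loss is computed
print(model.losses[0].dtype)        # float32
```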
```diff
@@ -2591,6 +2591,9 @@ class AddLoss(Layer):
   """

   def __init__(self, unconditional, **kwargs):
+    # Pass autocast=False, as there is no reason to cast loss to a different
+    # dtype.
+    kwargs['autocast'] = False
     super(AddLoss, self).__init__(**kwargs)
     self.unconditional = unconditional

```
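For context on the one-line fix: autocast is a flag on the base Layer. When it is True (the default), Layer.__call__ casts floating-point input tensors to the layer's compute dtype before call() runs; passing autocast=False opts out of that cast, which is what AddLoss now requests. A small illustrative sketch, assuming TF 2.x dtype behavior (the Identity layer is made up for the demo, not part of the change):

```python
import tensorflow as tf

class Identity(tf.keras.layers.Layer):
    """Returns its input unchanged, so any dtype change comes from autocast."""

    def call(self, inputs):
        return inputs

x = tf.ones((2, 2), dtype='float64')

# Default: floating-point inputs are cast to the layer's compute dtype.
print(Identity(dtype='float32')(x).dtype)                  # float32

# autocast=False skips the input cast, so the dtype passes through.
print(Identity(dtype='float32', autocast=False)(x).dtype)  # float64
```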
```diff
@@ -764,6 +764,17 @@ class KerasModelTest(keras_parameterized.TestCase):
         'optimizer" must be an instance of '):
       model.compile(optimizers.SGD(1.), 'mse')

+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_functional_model_loss_dtype(self):
+    with policy.policy_scope('float16'):
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      model.add_loss(math_ops.cast(y, 'float32'))
+      # The loss should not be cast to the policy's dtype.
+      self.assertEqual(model.losses[0].dtype, 'float32')
+
   @parameterized.named_parameters(
       {
           'testcase_name': 'base',
```
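The new test pins the dtype policy to float16, explicitly casts the symbolic loss tensor to float32, and asserts that Model.add_loss preserves that dtype rather than casting it back. (policy.policy_scope, test_util, testing_utils, and AddLayer are TensorFlow-internal test helpers from the surrounding test file.)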