Follow-up to 0e1f3de50a. This change enables the KerasTensor tests related to updates and losses.

- `updates` are not relevant in V2. The original `KerasTensor` change returns an empty list for `updates`. This change either restricts the tests that check `updates` to run only in v1 mode or adjusts their test logic as required.
- Some of the losses/add_loss tests were failing with KerasTensor because we were trying to convert a KerasTensor to a Tensor. This change modifies or moves those conversions as required.

PiperOrigin-RevId: 315438845
Change-Id: Ic2a5341cc5f2684649e2efc006e34a33e7da31ee
This commit is contained in:
Pavithra Vijay 2020-06-09 01:09:01 -07:00 committed by TensorFlower Gardener
parent 8398b862ff
commit 8dffa4de1b
9 changed files with 53 additions and 38 deletions

View File

@ -107,7 +107,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
network.add_update(state_ops.assign_add(layer.b, x4), inputs=True) network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
self.assertEqual(len(network.updates), 7) self.assertEqual(len(network.updates), 7)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph']))
def test_get_updates_bn(self): def test_get_updates_bn(self):
x1 = input_layer_lib.Input(shape=(1,)) x1 = input_layer_lib.Input(shape=(1,))
layer = layers.BatchNormalization() layer = layers.BatchNormalization()
@ -1593,9 +1593,9 @@ class GraphUtilsTest(test.TestCase):
tf_utils.get_reachable_from_inputs([x_3]), {x_3, x_5, x_5.op}) tf_utils.get_reachable_from_inputs([x_3]), {x_3, x_5, x_5.op})
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class NestedNetworkTest(keras_parameterized.TestCase): class NestedNetworkTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nested_inputs_network(self): def test_nested_inputs_network(self):
inputs = { inputs = {
'x1': input_layer_lib.Input(shape=(1,)), 'x1': input_layer_lib.Input(shape=(1,)),
@ -1620,6 +1620,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
}) })
self.assertListEqual(output_shape.as_list(), [None, 1]) self.assertListEqual(output_shape.as_list(), [None, 1])
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nested_outputs_network(self): def test_nested_outputs_network(self):
inputs = input_layer_lib.Input(shape=(1,)) inputs = input_layer_lib.Input(shape=(1,))
outputs = { outputs = {
@ -1640,6 +1641,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
self.assertListEqual(output_shape['x+x'].as_list(), [None, 1]) self.assertListEqual(output_shape['x+x'].as_list(), [None, 1])
self.assertListEqual(output_shape['x*x'].as_list(), [None, 1]) self.assertListEqual(output_shape['x*x'].as_list(), [None, 1])
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nested_network_inside_network(self): def test_nested_network_inside_network(self):
inner_inputs = { inner_inputs = {
'x1': input_layer_lib.Input(shape=(1,)), 'x1': input_layer_lib.Input(shape=(1,)),
@ -1672,6 +1674,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
output_shape = network.compute_output_shape([(None, 1), (None, 1)]) output_shape = network.compute_output_shape([(None, 1), (None, 1)])
self.assertListEqual(output_shape.as_list(), [None, 1]) self.assertListEqual(output_shape.as_list(), [None, 1])
@combinations.generate(combinations.combine(mode=['graph']))
def test_updates_with_direct_call(self): def test_updates_with_direct_call(self):
inputs = input_layer_lib.Input(shape=(10,)) inputs = input_layer_lib.Input(shape=(10,))
x = layers.BatchNormalization()(inputs) x = layers.BatchNormalization()(inputs)
@ -1683,6 +1686,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
self.assertLen(model.updates, 4) self.assertLen(model.updates, 4)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_dict_mapping_input(self): def test_dict_mapping_input(self):
class ReturnFirst(layers.Layer): class ReturnFirst(layers.Layer):
@ -1708,6 +1712,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
res = reversed_model({'a': a_val, 'b': b_val}) res = reversed_model({'a': a_val, 'b': b_val})
self.assertAllClose(self.evaluate(res), self.evaluate(b_val)) self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_dict_mapping_single_input(self): def test_dict_mapping_single_input(self):
b = input_layer_lib.Input(shape=(1,), name='b') b = input_layer_lib.Input(shape=(1,), name='b')
outputs = b * 2 outputs = b * 2

View File

@ -690,7 +690,7 @@ class TrainingTest(keras_parameterized.TestCase):
metrics=['accuracy'], metrics=['accuracy'],
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_that_trainable_disables_updates(self): def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4)) val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4)) val_out = np.random.random((10, 4))
@ -701,13 +701,15 @@ class TrainingTest(keras_parameterized.TestCase):
model = training_module.Model(a, b) model = training_module.Model(a, b)
model.trainable = False model.trainable = False
assert not model.updates if not ops.executing_eagerly_outside_functions():
self.assertEmpty(model.updates)
model.compile( model.compile(
'sgd', 'sgd',
'mse', 'mse',
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
assert not model.updates if not ops.executing_eagerly_outside_functions():
self.assertEmpty(model.updates)
x1 = model.predict(val_a) x1 = model.predict(val_a)
model.train_on_batch(val_a, val_out) model.train_on_batch(val_a, val_out)
@ -719,7 +721,8 @@ class TrainingTest(keras_parameterized.TestCase):
'sgd', 'sgd',
'mse', 'mse',
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
assert model.updates if not ops.executing_eagerly_outside_functions():
self.assertAllGreater(len(model.updates), 0)
model.train_on_batch(val_a, val_out) model.train_on_batch(val_a, val_out)
x2 = model.predict(val_a) x2 = model.predict(val_a)
@ -730,7 +733,8 @@ class TrainingTest(keras_parameterized.TestCase):
'sgd', 'sgd',
'mse', 'mse',
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
assert not model.updates if not ops.executing_eagerly_outside_functions():
self.assertEmpty(model.updates)
x1 = model.predict(val_a) x1 = model.predict(val_a)
model.train_on_batch(val_a, val_out) model.train_on_batch(val_a, val_out)

View File

@ -311,18 +311,17 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
norm(inp) norm(inp)
def test_updates_in_wrap_function(self): def test_updates_in_wrap_function(self):
with context.eager_mode(): layer = normalization.BatchNormalization()
layer = keras.layers.BatchNormalization()
def my_func(): def my_func():
x = array_ops.ones((10, 1)) x = array_ops.ones((10, 1))
return layer(x, training=True) return layer(x, training=True)
wrapped_fn = wrap_function.wrap_function(my_func, []) wrapped_fn = wrap_function.wrap_function(my_func, [])
wrapped_fn() wrapped_fn()
# Updates should be tracked in a `wrap_function`. # Updates should be tracked in a `wrap_function`.
self.assertLen(layer.updates, 2) self.assertLen(layer.updates, 2)
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self): def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):

View File

@ -79,7 +79,7 @@ def _get_model(input_shape=(4,)):
class TestModelCloning(keras_parameterized.TestCase): class TestModelCloning(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'has_input_layer', {'testcase_name': 'has_input_layer',
'input_shape': (4,), 'input_shape': (4,),
@ -122,7 +122,7 @@ class TestModelCloning(keras_parameterized.TestCase):
isinstance(new_model._layers[0], keras.layers.InputLayer), isinstance(new_model._layers[0], keras.layers.InputLayer),
add_input_layer) add_input_layer)
self.assertEqual(new_model._is_graph_network, model._is_graph_network) self.assertEqual(new_model._is_graph_network, model._is_graph_network)
if input_shape: if input_shape and not ops.executing_eagerly_outside_functions():
# update ops from batch norm needs to be included # update ops from batch norm needs to be included
self.assertGreaterEqual(len(new_model.updates), 2) self.assertGreaterEqual(len(new_model.updates), 2)
@ -142,7 +142,7 @@ class TestModelCloning(keras_parameterized.TestCase):
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer) self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
self.assertTrue(new_model._is_graph_network) self.assertTrue(new_model._is_graph_network)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'clone_weights', 'share_weights': False}, {'testcase_name': 'clone_weights', 'share_weights': False},
{'testcase_name': 'share_weights', 'share_weights': True}, {'testcase_name': 'share_weights', 'share_weights': True},
@ -173,7 +173,8 @@ class TestModelCloning(keras_parameterized.TestCase):
# With placeholder creation # With placeholder creation
new_model = clone_fn(model) new_model = clone_fn(model)
self.assertGreaterEqual(len(new_model.updates), 2) if not ops.executing_eagerly_outside_functions():
self.assertGreaterEqual(len(new_model.updates), 2)
new_model.compile( new_model.compile(
testing_utils.get_v2_optimizer('rmsprop'), testing_utils.get_v2_optimizer('rmsprop'),
'mse', 'mse',
@ -185,7 +186,8 @@ class TestModelCloning(keras_parameterized.TestCase):
input_b = keras.Input(shape=(4,), name='b') input_b = keras.Input(shape=(4,), name='b')
new_model = keras.models.clone_model( new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b]) model, input_tensors=[input_a, input_b])
self.assertLen(new_model.updates, 2) if not ops.executing_eagerly_outside_functions():
self.assertLen(new_model.updates, 2)
new_model.compile( new_model.compile(
testing_utils.get_v2_optimizer('rmsprop'), testing_utils.get_v2_optimizer('rmsprop'),
'mse', 'mse',

View File

@ -69,7 +69,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.y = np.array([[0.5], [2.], [3.5]], dtype='float32') self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32') self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_loss_on_model_fit(self): def test_loss_on_model_fit(self):
inputs = Input(shape=(1,)) inputs = Input(shape=(1,))
targets = Input(shape=(1,)) targets = Input(shape=(1,))
@ -85,8 +85,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3) self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential']) @keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True, @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
always_skip_v1=True)
def test_loss_callable_on_model_fit(self): def test_loss_callable_on_model_fit(self):
model = testing_utils.get_model_from_layers([testing_utils.Bias()], model = testing_utils.get_model_from_layers([testing_utils.Bias()],
input_shape=(1,)) input_shape=(1,))
@ -145,7 +144,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = [train_step(self.x, self.y) for _ in range(5)] loss = [train_step(self.x, self.y) for _ in range(5)]
self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3) self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_loss_with_sample_weight_on_model_fit(self): def test_loss_with_sample_weight_on_model_fit(self):
inputs = Input(shape=(1,)) inputs = Input(shape=(1,))
targets = Input(shape=(1,)) targets = Input(shape=(1,))
@ -182,7 +181,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = [train_step(self.x, self.y, self.w) for _ in range(5)] loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3) self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_loss_with_sample_weight_in_model_call(self): def test_loss_with_sample_weight_in_model_call(self):
class MyModel(Model): class MyModel(Model):
@ -210,7 +209,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
eval_out = model.evaluate([self.x, self.y, self.w]) eval_out = model.evaluate([self.x, self.y, self.w])
self.assertAlmostEqual(eval_out, 1.0, 3) self.assertAlmostEqual(eval_out, 1.0, 3)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_loss_with_sample_weight_in_layer_call(self): def test_loss_with_sample_weight_in_layer_call(self):
class MyLayer(layers.Layer): class MyLayer(layers.Layer):
@ -245,7 +244,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
output = model.test_on_batch([self.x, self.y, self.w]) output = model.test_on_batch([self.x, self.y, self.w])
self.assertAlmostEqual(output, 1.0, 3) self.assertAlmostEqual(output, 1.0, 3)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_loss_on_layer(self): def test_loss_on_layer(self):
class MyLayer(layers.Layer): class MyLayer(layers.Layer):
@ -266,7 +265,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(loss, 2 * 3) self.assertEqual(loss, 2 * 3)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
def test_activity_regularizer(self): def test_activity_regularizer(self):
loss = {} loss = {}
@ -300,7 +299,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss[reg] = model.evaluate(x, y) loss[reg] = model.evaluate(x, y)
self.assertLess(loss[None], loss['l2']) self.assertLess(loss[None], loss['l2'])
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
def test_activity_regularizer_loss_value(self): def test_activity_regularizer_loss_value(self):
layer = layers.Dense( layer = layers.Dense(
@ -319,7 +318,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = model.test_on_batch(x) loss = model.test_on_batch(x)
self.assertAlmostEqual(0.01, loss, places=4) self.assertAlmostEqual(0.01, loss, places=4)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_activity_regularizer_batch_independent(self): def test_activity_regularizer_batch_independent(self):
inputs = layers.Input(shape=(10,)) inputs = layers.Input(shape=(10,))
x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs) x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
@ -335,7 +334,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32')) loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4) self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_with_shared_layer(self): def test_with_shared_layer(self):
class LayerWithLoss(layers.Layer): class LayerWithLoss(layers.Layer):
@ -352,7 +351,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(m2.losses), 2) self.assertEqual(len(m2.losses), 2)
self.assertAllClose(m2.losses, [6, 12]) self.assertAllClose(m2.losses, [6, 12])
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_with_shared_nested_layer(self): def test_with_shared_nested_layer(self):
class LayerWithLoss(layers.Layer): class LayerWithLoss(layers.Layer):
@ -378,7 +377,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(m2.losses), 2) self.assertEqual(len(m2.losses), 2)
self.assertAllClose(m2.losses, [6, 12]) self.assertAllClose(m2.losses, [6, 12])
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_clear_losses(self): def test_clear_losses(self):
class LayerWithSharedNestedLossLayer(layers.Layer): class LayerWithSharedNestedLossLayer(layers.Layer):
@ -429,7 +428,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(model.get_losses_for(x4)), 2) self.assertEqual(len(model.get_losses_for(x4)), 2)
self.assertEqual(len(model.get_losses_for(None)), 1) self.assertEqual(len(model.get_losses_for(None)), 1)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_invalid_constant_input(self): def test_invalid_constant_input(self):
with context.eager_mode(): with context.eager_mode():
inputs = Input(shape=(1,)) inputs = Input(shape=(1,))
@ -440,7 +439,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
'Expected a symbolic Tensors or a callable for the loss value'): 'Expected a symbolic Tensors or a callable for the loss value'):
model.add_loss(1.) model.add_loss(1.)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True) @keras_parameterized.run_all_keras_modes
def test_invalid_variable_input(self): def test_invalid_variable_input(self):
with context.eager_mode(): with context.eager_mode():
inputs = Input(shape=(1,)) inputs = Input(shape=(1,))

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import loss_reduction from tensorflow.python.ops.losses import loss_reduction
@ -101,8 +102,12 @@ def compute_weighted_loss(losses,
# to multiple replicas. Used only for estimator + v1 optimizer flow. # to multiple replicas. Used only for estimator + v1 optimizer flow.
ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
losses = ops.convert_to_tensor_v2(losses) if not isinstance(losses, keras_tensor.KerasTensor):
losses = ops.convert_to_tensor_v2(losses)
input_dtype = losses.dtype input_dtype = losses.dtype
if not isinstance(sample_weight, keras_tensor.KerasTensor):
sample_weight = ops.convert_to_tensor_v2(sample_weight)
weighted_losses = tf_losses_utils.scale_losses_by_sample_weight( weighted_losses = tf_losses_utils.scale_losses_by_sample_weight(
losses, sample_weight) losses, sample_weight)
# Apply reduction function to the individual weighted losses. # Apply reduction function to the individual weighted losses.

View File

@ -346,6 +346,7 @@ def update_confusion_matrix_variables(variables_to_update,
y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions( y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
y_pred, y_true) y_pred, y_true)
else: else:
sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype)
y_pred, y_true, sample_weight = ( y_pred, y_true, sample_weight = (
tf_losses_utils.squeeze_or_expand_dimensions( tf_losses_utils.squeeze_or_expand_dimensions(
y_pred, y_true, sample_weight=sample_weight)) y_pred, y_true, sample_weight=sample_weight))

View File

@ -84,7 +84,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
if sample_weight is None: if sample_weight is None:
return y_pred, y_true return y_pred, y_true
sample_weight = ops.convert_to_tensor(sample_weight)
weights_shape = sample_weight.shape weights_shape = sample_weight.shape
weights_rank = weights_shape.ndims weights_rank = weights_shape.ndims
if weights_rank == 0: # If weights is scalar, do nothing. if weights_rank == 0: # If weights is scalar, do nothing.

View File

@ -427,6 +427,7 @@ def compute_average_loss(per_example_loss,
with losses_util.check_per_example_loss_rank(per_example_loss): with losses_util.check_per_example_loss_rank(per_example_loss):
if sample_weight is not None: if sample_weight is not None:
sample_weight = ops.convert_to_tensor(sample_weight)
per_example_loss = losses_util.scale_losses_by_sample_weight( per_example_loss = losses_util.scale_losses_by_sample_weight(
per_example_loss, sample_weight) per_example_loss, sample_weight)
per_example_loss = math_ops.cast(per_example_loss, input_dtype) per_example_loss = math_ops.cast(per_example_loss, input_dtype)