Follow up to 0e1f3de50a
. This change enables keras tensor tests related to updates/losses.
- `updates` are not relevant in V2. The original `KerasTensor` change returns an empty list for updates. This change modifies the tests checking for updates to run only in v1 mode or updates the test logic as required. - Some of the losses/add_loss tests were failing with KerasTensor because we were trying to convert KerasTensor to Tensor. This code changes/moves the conversions as required. PiperOrigin-RevId: 315438845 Change-Id: Ic2a5341cc5f2684649e2efc006e34a33e7da31ee
This commit is contained in:
parent
8398b862ff
commit
8dffa4de1b
@ -107,7 +107,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||||||
network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
|
network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
|
||||||
self.assertEqual(len(network.updates), 7)
|
self.assertEqual(len(network.updates), 7)
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph']))
|
||||||
def test_get_updates_bn(self):
|
def test_get_updates_bn(self):
|
||||||
x1 = input_layer_lib.Input(shape=(1,))
|
x1 = input_layer_lib.Input(shape=(1,))
|
||||||
layer = layers.BatchNormalization()
|
layer = layers.BatchNormalization()
|
||||||
@ -1593,9 +1593,9 @@ class GraphUtilsTest(test.TestCase):
|
|||||||
tf_utils.get_reachable_from_inputs([x_3]), {x_3, x_5, x_5.op})
|
tf_utils.get_reachable_from_inputs([x_3]), {x_3, x_5, x_5.op})
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
|
||||||
class NestedNetworkTest(keras_parameterized.TestCase):
|
class NestedNetworkTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_nested_inputs_network(self):
|
def test_nested_inputs_network(self):
|
||||||
inputs = {
|
inputs = {
|
||||||
'x1': input_layer_lib.Input(shape=(1,)),
|
'x1': input_layer_lib.Input(shape=(1,)),
|
||||||
@ -1620,6 +1620,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
})
|
})
|
||||||
self.assertListEqual(output_shape.as_list(), [None, 1])
|
self.assertListEqual(output_shape.as_list(), [None, 1])
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_nested_outputs_network(self):
|
def test_nested_outputs_network(self):
|
||||||
inputs = input_layer_lib.Input(shape=(1,))
|
inputs = input_layer_lib.Input(shape=(1,))
|
||||||
outputs = {
|
outputs = {
|
||||||
@ -1640,6 +1641,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
self.assertListEqual(output_shape['x+x'].as_list(), [None, 1])
|
self.assertListEqual(output_shape['x+x'].as_list(), [None, 1])
|
||||||
self.assertListEqual(output_shape['x*x'].as_list(), [None, 1])
|
self.assertListEqual(output_shape['x*x'].as_list(), [None, 1])
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_nested_network_inside_network(self):
|
def test_nested_network_inside_network(self):
|
||||||
inner_inputs = {
|
inner_inputs = {
|
||||||
'x1': input_layer_lib.Input(shape=(1,)),
|
'x1': input_layer_lib.Input(shape=(1,)),
|
||||||
@ -1672,6 +1674,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
output_shape = network.compute_output_shape([(None, 1), (None, 1)])
|
output_shape = network.compute_output_shape([(None, 1), (None, 1)])
|
||||||
self.assertListEqual(output_shape.as_list(), [None, 1])
|
self.assertListEqual(output_shape.as_list(), [None, 1])
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph']))
|
||||||
def test_updates_with_direct_call(self):
|
def test_updates_with_direct_call(self):
|
||||||
inputs = input_layer_lib.Input(shape=(10,))
|
inputs = input_layer_lib.Input(shape=(10,))
|
||||||
x = layers.BatchNormalization()(inputs)
|
x = layers.BatchNormalization()(inputs)
|
||||||
@ -1683,6 +1686,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertLen(model.updates, 4)
|
self.assertLen(model.updates, 4)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_dict_mapping_input(self):
|
def test_dict_mapping_input(self):
|
||||||
|
|
||||||
class ReturnFirst(layers.Layer):
|
class ReturnFirst(layers.Layer):
|
||||||
@ -1708,6 +1712,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
|||||||
res = reversed_model({'a': a_val, 'b': b_val})
|
res = reversed_model({'a': a_val, 'b': b_val})
|
||||||
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
|
self.assertAllClose(self.evaluate(res), self.evaluate(b_val))
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_dict_mapping_single_input(self):
|
def test_dict_mapping_single_input(self):
|
||||||
b = input_layer_lib.Input(shape=(1,), name='b')
|
b = input_layer_lib.Input(shape=(1,), name='b')
|
||||||
outputs = b * 2
|
outputs = b * 2
|
||||||
|
@ -690,7 +690,7 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
metrics=['accuracy'],
|
metrics=['accuracy'],
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_that_trainable_disables_updates(self):
|
def test_that_trainable_disables_updates(self):
|
||||||
val_a = np.random.random((10, 4))
|
val_a = np.random.random((10, 4))
|
||||||
val_out = np.random.random((10, 4))
|
val_out = np.random.random((10, 4))
|
||||||
@ -701,13 +701,15 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
model = training_module.Model(a, b)
|
model = training_module.Model(a, b)
|
||||||
|
|
||||||
model.trainable = False
|
model.trainable = False
|
||||||
assert not model.updates
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
self.assertEmpty(model.updates)
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
'sgd',
|
'sgd',
|
||||||
'mse',
|
'mse',
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
assert not model.updates
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
self.assertEmpty(model.updates)
|
||||||
|
|
||||||
x1 = model.predict(val_a)
|
x1 = model.predict(val_a)
|
||||||
model.train_on_batch(val_a, val_out)
|
model.train_on_batch(val_a, val_out)
|
||||||
@ -719,7 +721,8 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
'sgd',
|
'sgd',
|
||||||
'mse',
|
'mse',
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
assert model.updates
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
self.assertAllGreater(len(model.updates), 0)
|
||||||
|
|
||||||
model.train_on_batch(val_a, val_out)
|
model.train_on_batch(val_a, val_out)
|
||||||
x2 = model.predict(val_a)
|
x2 = model.predict(val_a)
|
||||||
@ -730,7 +733,8 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
'sgd',
|
'sgd',
|
||||||
'mse',
|
'mse',
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
assert not model.updates
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
self.assertEmpty(model.updates)
|
||||||
|
|
||||||
x1 = model.predict(val_a)
|
x1 = model.predict(val_a)
|
||||||
model.train_on_batch(val_a, val_out)
|
model.train_on_batch(val_a, val_out)
|
||||||
|
@ -311,8 +311,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
|
|||||||
norm(inp)
|
norm(inp)
|
||||||
|
|
||||||
def test_updates_in_wrap_function(self):
|
def test_updates_in_wrap_function(self):
|
||||||
with context.eager_mode():
|
layer = normalization.BatchNormalization()
|
||||||
layer = keras.layers.BatchNormalization()
|
|
||||||
|
|
||||||
def my_func():
|
def my_func():
|
||||||
x = array_ops.ones((10, 1))
|
x = array_ops.ones((10, 1))
|
||||||
|
@ -79,7 +79,7 @@ def _get_model(input_shape=(4,)):
|
|||||||
|
|
||||||
class TestModelCloning(keras_parameterized.TestCase):
|
class TestModelCloning(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
@parameterized.named_parameters([
|
@parameterized.named_parameters([
|
||||||
{'testcase_name': 'has_input_layer',
|
{'testcase_name': 'has_input_layer',
|
||||||
'input_shape': (4,),
|
'input_shape': (4,),
|
||||||
@ -122,7 +122,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
|||||||
isinstance(new_model._layers[0], keras.layers.InputLayer),
|
isinstance(new_model._layers[0], keras.layers.InputLayer),
|
||||||
add_input_layer)
|
add_input_layer)
|
||||||
self.assertEqual(new_model._is_graph_network, model._is_graph_network)
|
self.assertEqual(new_model._is_graph_network, model._is_graph_network)
|
||||||
if input_shape:
|
if input_shape and not ops.executing_eagerly_outside_functions():
|
||||||
# update ops from batch norm needs to be included
|
# update ops from batch norm needs to be included
|
||||||
self.assertGreaterEqual(len(new_model.updates), 2)
|
self.assertGreaterEqual(len(new_model.updates), 2)
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
|||||||
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
|
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
|
||||||
self.assertTrue(new_model._is_graph_network)
|
self.assertTrue(new_model._is_graph_network)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
@parameterized.named_parameters([
|
@parameterized.named_parameters([
|
||||||
{'testcase_name': 'clone_weights', 'share_weights': False},
|
{'testcase_name': 'clone_weights', 'share_weights': False},
|
||||||
{'testcase_name': 'share_weights', 'share_weights': True},
|
{'testcase_name': 'share_weights', 'share_weights': True},
|
||||||
@ -173,6 +173,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
# With placeholder creation
|
# With placeholder creation
|
||||||
new_model = clone_fn(model)
|
new_model = clone_fn(model)
|
||||||
|
if not ops.executing_eagerly_outside_functions():
|
||||||
self.assertGreaterEqual(len(new_model.updates), 2)
|
self.assertGreaterEqual(len(new_model.updates), 2)
|
||||||
new_model.compile(
|
new_model.compile(
|
||||||
testing_utils.get_v2_optimizer('rmsprop'),
|
testing_utils.get_v2_optimizer('rmsprop'),
|
||||||
@ -185,6 +186,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
|||||||
input_b = keras.Input(shape=(4,), name='b')
|
input_b = keras.Input(shape=(4,), name='b')
|
||||||
new_model = keras.models.clone_model(
|
new_model = keras.models.clone_model(
|
||||||
model, input_tensors=[input_a, input_b])
|
model, input_tensors=[input_a, input_b])
|
||||||
|
if not ops.executing_eagerly_outside_functions():
|
||||||
self.assertLen(new_model.updates, 2)
|
self.assertLen(new_model.updates, 2)
|
||||||
new_model.compile(
|
new_model.compile(
|
||||||
testing_utils.get_v2_optimizer('rmsprop'),
|
testing_utils.get_v2_optimizer('rmsprop'),
|
||||||
|
@ -69,7 +69,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
|
self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
|
||||||
self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
|
self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_loss_on_model_fit(self):
|
def test_loss_on_model_fit(self):
|
||||||
inputs = Input(shape=(1,))
|
inputs = Input(shape=(1,))
|
||||||
targets = Input(shape=(1,))
|
targets = Input(shape=(1,))
|
||||||
@ -85,8 +85,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
|
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True,
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
always_skip_v1=True)
|
|
||||||
def test_loss_callable_on_model_fit(self):
|
def test_loss_callable_on_model_fit(self):
|
||||||
model = testing_utils.get_model_from_layers([testing_utils.Bias()],
|
model = testing_utils.get_model_from_layers([testing_utils.Bias()],
|
||||||
input_shape=(1,))
|
input_shape=(1,))
|
||||||
@ -145,7 +144,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss = [train_step(self.x, self.y) for _ in range(5)]
|
loss = [train_step(self.x, self.y) for _ in range(5)]
|
||||||
self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
|
self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_loss_with_sample_weight_on_model_fit(self):
|
def test_loss_with_sample_weight_on_model_fit(self):
|
||||||
inputs = Input(shape=(1,))
|
inputs = Input(shape=(1,))
|
||||||
targets = Input(shape=(1,))
|
targets = Input(shape=(1,))
|
||||||
@ -182,7 +181,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
|
loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
|
||||||
self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_loss_with_sample_weight_in_model_call(self):
|
def test_loss_with_sample_weight_in_model_call(self):
|
||||||
|
|
||||||
class MyModel(Model):
|
class MyModel(Model):
|
||||||
@ -210,7 +209,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
eval_out = model.evaluate([self.x, self.y, self.w])
|
eval_out = model.evaluate([self.x, self.y, self.w])
|
||||||
self.assertAlmostEqual(eval_out, 1.0, 3)
|
self.assertAlmostEqual(eval_out, 1.0, 3)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_loss_with_sample_weight_in_layer_call(self):
|
def test_loss_with_sample_weight_in_layer_call(self):
|
||||||
|
|
||||||
class MyLayer(layers.Layer):
|
class MyLayer(layers.Layer):
|
||||||
@ -245,7 +244,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
output = model.test_on_batch([self.x, self.y, self.w])
|
output = model.test_on_batch([self.x, self.y, self.w])
|
||||||
self.assertAlmostEqual(output, 1.0, 3)
|
self.assertAlmostEqual(output, 1.0, 3)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_loss_on_layer(self):
|
def test_loss_on_layer(self):
|
||||||
|
|
||||||
class MyLayer(layers.Layer):
|
class MyLayer(layers.Layer):
|
||||||
@ -266,7 +265,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(loss, 2 * 3)
|
self.assertEqual(loss, 2 * 3)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
def test_activity_regularizer(self):
|
def test_activity_regularizer(self):
|
||||||
loss = {}
|
loss = {}
|
||||||
@ -300,7 +299,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss[reg] = model.evaluate(x, y)
|
loss[reg] = model.evaluate(x, y)
|
||||||
self.assertLess(loss[None], loss['l2'])
|
self.assertLess(loss[None], loss['l2'])
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
def test_activity_regularizer_loss_value(self):
|
def test_activity_regularizer_loss_value(self):
|
||||||
layer = layers.Dense(
|
layer = layers.Dense(
|
||||||
@ -319,7 +318,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss = model.test_on_batch(x)
|
loss = model.test_on_batch(x)
|
||||||
self.assertAlmostEqual(0.01, loss, places=4)
|
self.assertAlmostEqual(0.01, loss, places=4)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_activity_regularizer_batch_independent(self):
|
def test_activity_regularizer_batch_independent(self):
|
||||||
inputs = layers.Input(shape=(10,))
|
inputs = layers.Input(shape=(10,))
|
||||||
x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
|
x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
|
||||||
@ -335,7 +334,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
|
loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
|
||||||
self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
|
self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_with_shared_layer(self):
|
def test_with_shared_layer(self):
|
||||||
|
|
||||||
class LayerWithLoss(layers.Layer):
|
class LayerWithLoss(layers.Layer):
|
||||||
@ -352,7 +351,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(len(m2.losses), 2)
|
self.assertEqual(len(m2.losses), 2)
|
||||||
self.assertAllClose(m2.losses, [6, 12])
|
self.assertAllClose(m2.losses, [6, 12])
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_with_shared_nested_layer(self):
|
def test_with_shared_nested_layer(self):
|
||||||
|
|
||||||
class LayerWithLoss(layers.Layer):
|
class LayerWithLoss(layers.Layer):
|
||||||
@ -378,7 +377,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(len(m2.losses), 2)
|
self.assertEqual(len(m2.losses), 2)
|
||||||
self.assertAllClose(m2.losses, [6, 12])
|
self.assertAllClose(m2.losses, [6, 12])
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_clear_losses(self):
|
def test_clear_losses(self):
|
||||||
|
|
||||||
class LayerWithSharedNestedLossLayer(layers.Layer):
|
class LayerWithSharedNestedLossLayer(layers.Layer):
|
||||||
@ -429,7 +428,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(len(model.get_losses_for(x4)), 2)
|
self.assertEqual(len(model.get_losses_for(x4)), 2)
|
||||||
self.assertEqual(len(model.get_losses_for(None)), 1)
|
self.assertEqual(len(model.get_losses_for(None)), 1)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_invalid_constant_input(self):
|
def test_invalid_constant_input(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
inputs = Input(shape=(1,))
|
inputs = Input(shape=(1,))
|
||||||
@ -440,7 +439,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
|||||||
'Expected a symbolic Tensors or a callable for the loss value'):
|
'Expected a symbolic Tensors or a callable for the loss value'):
|
||||||
model.add_loss(1.)
|
model.add_loss(1.)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_invalid_variable_input(self):
|
def test_invalid_variable_input(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
inputs = Input(shape=(1,))
|
inputs = Input(shape=(1,))
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.losses import loss_reduction
|
from tensorflow.python.ops.losses import loss_reduction
|
||||||
@ -101,8 +102,12 @@ def compute_weighted_loss(losses,
|
|||||||
# to multiple replicas. Used only for estimator + v1 optimizer flow.
|
# to multiple replicas. Used only for estimator + v1 optimizer flow.
|
||||||
ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
|
ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
|
||||||
|
|
||||||
|
if not isinstance(losses, keras_tensor.KerasTensor):
|
||||||
losses = ops.convert_to_tensor_v2(losses)
|
losses = ops.convert_to_tensor_v2(losses)
|
||||||
input_dtype = losses.dtype
|
input_dtype = losses.dtype
|
||||||
|
|
||||||
|
if not isinstance(sample_weight, keras_tensor.KerasTensor):
|
||||||
|
sample_weight = ops.convert_to_tensor_v2(sample_weight)
|
||||||
weighted_losses = tf_losses_utils.scale_losses_by_sample_weight(
|
weighted_losses = tf_losses_utils.scale_losses_by_sample_weight(
|
||||||
losses, sample_weight)
|
losses, sample_weight)
|
||||||
# Apply reduction function to the individual weighted losses.
|
# Apply reduction function to the individual weighted losses.
|
||||||
|
@ -346,6 +346,7 @@ def update_confusion_matrix_variables(variables_to_update,
|
|||||||
y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
|
y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
|
||||||
y_pred, y_true)
|
y_pred, y_true)
|
||||||
else:
|
else:
|
||||||
|
sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype)
|
||||||
y_pred, y_true, sample_weight = (
|
y_pred, y_true, sample_weight = (
|
||||||
tf_losses_utils.squeeze_or_expand_dimensions(
|
tf_losses_utils.squeeze_or_expand_dimensions(
|
||||||
y_pred, y_true, sample_weight=sample_weight))
|
y_pred, y_true, sample_weight=sample_weight))
|
||||||
|
@ -84,7 +84,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
|
|||||||
if sample_weight is None:
|
if sample_weight is None:
|
||||||
return y_pred, y_true
|
return y_pred, y_true
|
||||||
|
|
||||||
sample_weight = ops.convert_to_tensor(sample_weight)
|
|
||||||
weights_shape = sample_weight.shape
|
weights_shape = sample_weight.shape
|
||||||
weights_rank = weights_shape.ndims
|
weights_rank = weights_shape.ndims
|
||||||
if weights_rank == 0: # If weights is scalar, do nothing.
|
if weights_rank == 0: # If weights is scalar, do nothing.
|
||||||
|
@ -427,6 +427,7 @@ def compute_average_loss(per_example_loss,
|
|||||||
|
|
||||||
with losses_util.check_per_example_loss_rank(per_example_loss):
|
with losses_util.check_per_example_loss_rank(per_example_loss):
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
|
sample_weight = ops.convert_to_tensor(sample_weight)
|
||||||
per_example_loss = losses_util.scale_losses_by_sample_weight(
|
per_example_loss = losses_util.scale_losses_by_sample_weight(
|
||||||
per_example_loss, sample_weight)
|
per_example_loss, sample_weight)
|
||||||
per_example_loss = math_ops.cast(per_example_loss, input_dtype)
|
per_example_loss = math_ops.cast(per_example_loss, input_dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user