Throw error if there are unexpected keys in sample_weight / class_weight dictionaries.
PiperOrigin-RevId: 241421418
This commit is contained in:
parent
92c63053f4
commit
3cade69a60
@ -1406,6 +1406,39 @@ class LossWeightingTest(keras_parameterized.TestCase):
|
||||
temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0)
|
||||
self.assertLess(score[0], ref_score[0])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models='sequential')
|
||||
def test_fit_with_incorrect_weights(self):
|
||||
input_a = keras.layers.Input(shape=(3,), name='input_a')
|
||||
input_b = keras.layers.Input(shape=(3,), name='input_b')
|
||||
|
||||
dense = keras.layers.Dense(2, name='output_1')
|
||||
dropout = keras.layers.Dropout(0.5, name='output_2')
|
||||
branch_a = [input_a, dense]
|
||||
branch_b = [input_b, dense, dropout]
|
||||
|
||||
model = testing_utils.get_multi_io_model(branch_a, branch_b)
|
||||
model.compile(
|
||||
optimizer='adam',
|
||||
loss='mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
x = np.random.random((10, 3))
|
||||
y = np.random.random((10, 2))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r'Unknown entries in sample_weight dictionary: \[\'unknown\'\]. '
|
||||
r'Only expected following keys: \[\'output_1\', \'output_2\'\]'):
|
||||
model.fit([x, x], [y, y],
|
||||
epochs=1,
|
||||
sample_weight={'unknown': 'something'})
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r'Unknown entries in class_weight dictionary: \[\'unknown\'\]. '
|
||||
r'Only expected following keys: \[\'output_1\', \'output_2\'\]'):
|
||||
model.fit([x, x], [y, y], epochs=1, class_weight={'unknown': 'something'})
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_class_weight_invalid_use_case(self):
|
||||
num_classes = 5
|
||||
@ -2288,9 +2321,10 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase):
|
||||
output_b_np = np.random.random((10, 3))
|
||||
|
||||
_ = model.train_on_batch([input_a_np, input_b_np],
|
||||
[output_a_np, output_b_np],
|
||||
{y: np.random.random((10, 4)),
|
||||
y1: np.random.random((10, 3))})
|
||||
[output_a_np, output_b_np], {
|
||||
'dense_1': np.random.random((10,)),
|
||||
'dropout': np.random.random((10,))
|
||||
})
|
||||
# test dictionary of target_tensors
|
||||
with self.assertRaises(ValueError):
|
||||
model.compile(optimizer, loss,
|
||||
@ -2305,9 +2339,10 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase):
|
||||
sample_weight_mode=None,
|
||||
target_tensors={'dense_1': y, 'dropout': y1})
|
||||
_ = model.train_on_batch([input_a_np, input_b_np],
|
||||
[output_a_np, output_b_np],
|
||||
{y: np.random.random((10, 4)),
|
||||
y1: np.random.random((10, 3))})
|
||||
[output_a_np, output_b_np], {
|
||||
'dense_1': np.random.random((10,)),
|
||||
'dropout': np.random.random((10,))
|
||||
})
|
||||
|
||||
# test with custom TF placeholder as target
|
||||
pl_target_a = keras.backend.array_ops.placeholder('float32',
|
||||
|
@ -380,7 +380,8 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
||||
'You should provide one `' + weight_type + '`'
|
||||
'array per model output.')
|
||||
return x_weight
|
||||
if isinstance(x_weight, dict):
|
||||
if isinstance(x_weight, collections.Mapping):
|
||||
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
|
||||
x_weights = []
|
||||
for name in output_names:
|
||||
x_weights.append(x_weight.get(name))
|
||||
|
Loading…
Reference in New Issue
Block a user