Throw error if there are unexpected keys in sample_weight / class_weight dictionaries.
PiperOrigin-RevId: 241421418
This commit is contained in:
parent
92c63053f4
commit
3cade69a60
@ -1406,6 +1406,39 @@ class LossWeightingTest(keras_parameterized.TestCase):
|
|||||||
temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0)
|
temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0)
|
||||||
self.assertLess(score[0], ref_score[0])
|
self.assertLess(score[0], ref_score[0])
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
@keras_parameterized.run_with_all_model_types(exclude_models='sequential')
|
||||||
|
def test_fit_with_incorrect_weights(self):
|
||||||
|
input_a = keras.layers.Input(shape=(3,), name='input_a')
|
||||||
|
input_b = keras.layers.Input(shape=(3,), name='input_b')
|
||||||
|
|
||||||
|
dense = keras.layers.Dense(2, name='output_1')
|
||||||
|
dropout = keras.layers.Dropout(0.5, name='output_2')
|
||||||
|
branch_a = [input_a, dense]
|
||||||
|
branch_b = [input_b, dense, dropout]
|
||||||
|
|
||||||
|
model = testing_utils.get_multi_io_model(branch_a, branch_b)
|
||||||
|
model.compile(
|
||||||
|
optimizer='adam',
|
||||||
|
loss='mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
x = np.random.random((10, 3))
|
||||||
|
y = np.random.random((10, 2))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
r'Unknown entries in sample_weight dictionary: \[\'unknown\'\]. '
|
||||||
|
r'Only expected following keys: \[\'output_1\', \'output_2\'\]'):
|
||||||
|
model.fit([x, x], [y, y],
|
||||||
|
epochs=1,
|
||||||
|
sample_weight={'unknown': 'something'})
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
r'Unknown entries in class_weight dictionary: \[\'unknown\'\]. '
|
||||||
|
r'Only expected following keys: \[\'output_1\', \'output_2\'\]'):
|
||||||
|
model.fit([x, x], [y, y], epochs=1, class_weight={'unknown': 'something'})
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_class_weight_invalid_use_case(self):
|
def test_class_weight_invalid_use_case(self):
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
@ -2288,9 +2321,10 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase):
|
|||||||
output_b_np = np.random.random((10, 3))
|
output_b_np = np.random.random((10, 3))
|
||||||
|
|
||||||
_ = model.train_on_batch([input_a_np, input_b_np],
|
_ = model.train_on_batch([input_a_np, input_b_np],
|
||||||
[output_a_np, output_b_np],
|
[output_a_np, output_b_np], {
|
||||||
{y: np.random.random((10, 4)),
|
'dense_1': np.random.random((10,)),
|
||||||
y1: np.random.random((10, 3))})
|
'dropout': np.random.random((10,))
|
||||||
|
})
|
||||||
# test dictionary of target_tensors
|
# test dictionary of target_tensors
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model.compile(optimizer, loss,
|
model.compile(optimizer, loss,
|
||||||
@ -2305,9 +2339,10 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase):
|
|||||||
sample_weight_mode=None,
|
sample_weight_mode=None,
|
||||||
target_tensors={'dense_1': y, 'dropout': y1})
|
target_tensors={'dense_1': y, 'dropout': y1})
|
||||||
_ = model.train_on_batch([input_a_np, input_b_np],
|
_ = model.train_on_batch([input_a_np, input_b_np],
|
||||||
[output_a_np, output_b_np],
|
[output_a_np, output_b_np], {
|
||||||
{y: np.random.random((10, 4)),
|
'dense_1': np.random.random((10,)),
|
||||||
y1: np.random.random((10, 3))})
|
'dropout': np.random.random((10,))
|
||||||
|
})
|
||||||
|
|
||||||
# test with custom TF placeholder as target
|
# test with custom TF placeholder as target
|
||||||
pl_target_a = keras.backend.array_ops.placeholder('float32',
|
pl_target_a = keras.backend.array_ops.placeholder('float32',
|
||||||
|
@ -380,7 +380,8 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
|||||||
'You should provide one `' + weight_type + '`'
|
'You should provide one `' + weight_type + '`'
|
||||||
'array per model output.')
|
'array per model output.')
|
||||||
return x_weight
|
return x_weight
|
||||||
if isinstance(x_weight, dict):
|
if isinstance(x_weight, collections.Mapping):
|
||||||
|
generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
|
||||||
x_weights = []
|
x_weights = []
|
||||||
for name in output_names:
|
for name in output_names:
|
||||||
x_weights.append(x_weight.get(name))
|
x_weights.append(x_weight.get(name))
|
||||||
|
Loading…
Reference in New Issue
Block a user