Fix issue with return value of evaluate() in models that add custom metrics via overriding train_step.
PiperOrigin-RevId: 330008361 Change-Id: I8942e176972a3c080c97e25dfb8ff42641b54372
This commit is contained in:
parent
1657a92beb
commit
332f2338ce
@ -1366,7 +1366,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
if return_dict:
|
||||
return logs
|
||||
else:
|
||||
results = [logs.get(name, None) for name in self.metrics_names]
|
||||
results = []
|
||||
for name in self.metrics_names:
|
||||
if name in logs:
|
||||
results.append(logs[name])
|
||||
for key in sorted(logs.keys()):
|
||||
if key not in self.metrics_names:
|
||||
results.append(logs[key])
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
return results
|
||||
|
@ -1618,6 +1618,64 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
model.evaluate(x, batch_size=batch_size)
|
||||
model.predict(x, batch_size=batch_size)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(
|
||||
always_skip_v1=True)
|
||||
@parameterized.named_parameters(
|
||||
('custom_metrics', False, True),
|
||||
('compiled_metrics', True, False),
|
||||
('both_compiled_and_custom_metrics', True, True))
|
||||
def test_evaluate_with_custom_test_step(
|
||||
self, use_compiled_metrics, use_custom_metrics):
|
||||
|
||||
class MyModel(training_module.Model):
|
||||
|
||||
def test_step(self, data):
|
||||
x, y = data
|
||||
pred = self(x)
|
||||
metrics = {}
|
||||
if use_compiled_metrics:
|
||||
self.compiled_metrics.update_state(y, pred)
|
||||
self.compiled_loss(y, pred)
|
||||
for metric in self.metrics:
|
||||
metrics[metric.name] = metric.result()
|
||||
if use_custom_metrics:
|
||||
custom_metrics = {
|
||||
'mean': math_ops.reduce_mean(pred),
|
||||
'sum': math_ops.reduce_sum(pred)
|
||||
}
|
||||
metrics.update(custom_metrics)
|
||||
return metrics
|
||||
|
||||
inputs = layers_module.Input((2,))
|
||||
outputs = layers_module.Dense(3)(inputs)
|
||||
model = MyModel(inputs, outputs)
|
||||
if use_compiled_metrics:
|
||||
model.compile('adam', 'mse', metrics=['mae', 'mape'],
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
else:
|
||||
model.compile('adam', 'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
x = np.random.random((4, 2))
|
||||
y = np.random.random((4, 3))
|
||||
results_list = model.evaluate(x, y)
|
||||
results_dict = model.evaluate(x, y, return_dict=True)
|
||||
self.assertLen(results_list, len(results_dict))
|
||||
if use_compiled_metrics and use_custom_metrics:
|
||||
self.assertLen(results_list, 5)
|
||||
self.assertEqual(results_list,
|
||||
[results_dict['loss'],
|
||||
results_dict['mae'], results_dict['mape'],
|
||||
results_dict['mean'], results_dict['sum']])
|
||||
if use_compiled_metrics and not use_custom_metrics:
|
||||
self.assertLen(results_list, 3)
|
||||
self.assertEqual(results_list,
|
||||
[results_dict['loss'],
|
||||
results_dict['mae'], results_dict['mape']])
|
||||
if not use_compiled_metrics and use_custom_metrics:
|
||||
self.assertLen(results_list, 2)
|
||||
self.assertEqual(results_list,
|
||||
[results_dict['mean'], results_dict['sum']])
|
||||
|
||||
|
||||
class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user