Fix issue with return value of evaluate() in models that add custom metrics via overriding train_step.

PiperOrigin-RevId: 330008361
Change-Id: I8942e176972a3c080c97e25dfb8ff42641b54372
Author: Francois Chollet
Date:   2020-09-03 15:45:53 -07:00
Committed by: TensorFlower Gardener
Parent: 1657a92beb
Commit: 332f2338ce

2 changed files with 65 additions and 1 deletion
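A note on the failure mode this addresses: when a subclassed model's test_step
returns metric entries that compile() never registered, the old list path in
evaluate() was built purely from self.metrics_names, so it dropped the custom
entries and filled any compiled name missing from the logs with None (the
return_dict=True path was already correct). A minimal sketch of the affected
pattern, written against the public tf.keras API rather than the internal
modules used in the test below; the model name is illustrative, not from this
patch:

    import numpy as np
    import tensorflow as tf

    class CustomEvalModel(tf.keras.Model):
      # Same pattern as the new test: test_step reports extra metric
      # names that compile() never saw.
      def test_step(self, data):
        x, y = data
        pred = self(x)
        return {'mean': tf.reduce_mean(pred), 'sum': tf.reduce_sum(pred)}

    inputs = tf.keras.Input((2,))
    outputs = tf.keras.layers.Dense(3)(inputs)
    model = CustomEvalModel(inputs, outputs)
    model.compile('adam', 'mse')

    x = np.random.random((4, 2))
    y = np.random.random((4, 3))
    # The dict form already surfaced the custom keys; with this change the
    # list form returns the same values, e.g. [mean, sum] here.
    print(model.evaluate(x, y, return_dict=True))
    print(model.evaluate(x, y))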


@@ -1366,7 +1366,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
       if return_dict:
         return logs
       else:
-        results = [logs.get(name, None) for name in self.metrics_names]
+        results = []
+        for name in self.metrics_names:
+          if name in logs:
+            results.append(logs[name])
+        for key in sorted(logs.keys()):
+          if key not in self.metrics_names:
+            results.append(logs[key])
         if len(results) == 1:
           return results[0]
         return results
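Net effect of the new branch: the list form of evaluate() keeps the compiled
metrics first, in self.metrics_names order, and then appends any extra keys
reported by a custom test_step in sorted key order, so the list and dict forms
carry the same values. A standalone sketch of that ordering in plain Python
(the helper name is made up for illustration):

    def flatten_logs(logs, metrics_names):
      # Mirrors the branch above: compiled metric names first, in
      # metrics_names order, then the remaining keys sorted alphabetically.
      results = [logs[name] for name in metrics_names if name in logs]
      results += [logs[key] for key in sorted(logs) if key not in metrics_names]
      return results[0] if len(results) == 1 else results

    # With compiled ['loss', 'mae'] plus custom 'mean' and 'sum' entries:
    print(flatten_logs({'sum': 4.0, 'loss': 0.1, 'mae': 0.2, 'mean': 1.0},
                       ['loss', 'mae']))
    # -> [0.1, 0.2, 1.0, 4.0]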


@@ -1618,6 +1618,64 @@ class TrainingTest(keras_parameterized.TestCase):
     model.evaluate(x, batch_size=batch_size)
     model.predict(x, batch_size=batch_size)
 
+  @keras_parameterized.run_all_keras_modes(
+      always_skip_v1=True)
+  @parameterized.named_parameters(
+      ('custom_metrics', False, True),
+      ('compiled_metrics', True, False),
+      ('both_compiled_and_custom_metrics', True, True))
+  def test_evaluate_with_custom_test_step(
+      self, use_compiled_metrics, use_custom_metrics):
+
+    class MyModel(training_module.Model):
+
+      def test_step(self, data):
+        x, y = data
+        pred = self(x)
+        metrics = {}
+        if use_compiled_metrics:
+          self.compiled_metrics.update_state(y, pred)
+          self.compiled_loss(y, pred)
+          for metric in self.metrics:
+            metrics[metric.name] = metric.result()
+        if use_custom_metrics:
+          custom_metrics = {
+              'mean': math_ops.reduce_mean(pred),
+              'sum': math_ops.reduce_sum(pred)
+          }
+          metrics.update(custom_metrics)
+
+        return metrics
+
+    inputs = layers_module.Input((2,))
+    outputs = layers_module.Dense(3)(inputs)
+    model = MyModel(inputs, outputs)
+    if use_compiled_metrics:
+      model.compile('adam', 'mse', metrics=['mae', 'mape'],
+                    run_eagerly=testing_utils.should_run_eagerly())
+    else:
+      model.compile('adam', 'mse',
+                    run_eagerly=testing_utils.should_run_eagerly())
+    x = np.random.random((4, 2))
+    y = np.random.random((4, 3))
+    results_list = model.evaluate(x, y)
+    results_dict = model.evaluate(x, y, return_dict=True)
+    self.assertLen(results_list, len(results_dict))
+    if use_compiled_metrics and use_custom_metrics:
+      self.assertLen(results_list, 5)
+      self.assertEqual(results_list,
+                       [results_dict['loss'],
+                        results_dict['mae'], results_dict['mape'],
+                        results_dict['mean'], results_dict['sum']])
+    if use_compiled_metrics and not use_custom_metrics:
+      self.assertLen(results_list, 3)
+      self.assertEqual(results_list,
+                       [results_dict['loss'],
+                        results_dict['mae'], results_dict['mape']])
+    if not use_compiled_metrics and use_custom_metrics:
+      self.assertLen(results_list, 2)
+      self.assertEqual(results_list,
+                       [results_dict['mean'], results_dict['sum']])
 
 
 class TestExceptionsAndWarnings(keras_parameterized.TestCase):