STT-tensorflow/tensorflow/python/keras/metrics_functional_test.py
Scott Zhu ab3601d340 Remove the usage of eager context in test code, and replace them with test combinations.
PiperOrigin-RevId: 329947449
Change-Id: Ie513d6547c6ba458515be31e4df84976893cfd3d
2020-09-03 11:05:50 -07:00

146 lines
6.4 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras metrics functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import combinations
from tensorflow.python.keras import metrics
from tensorflow.python.platform import test
class KerasFunctionalMetricsTest(test.TestCase, parameterized.TestCase):
  """Exercises the function-style metrics in `metrics` on small inputs.

  Most tests wrap their inputs in backend variables and read the results back
  with `K.eval` inside `self.cached_session()`. The two `*_eager` tests are
  generated for `mode=['eager']` only and pass NumPy arrays straight to the
  metric function.
  """

  def test_metrics(self):
    """Binary and categorical accuracy on (6, 7) inputs evaluate to (6,)."""
    with self.cached_session():
      lhs = K.variable(np.random.random((6, 7)))
      rhs = K.variable(np.random.random((6, 7)))
      for metric_fn in (metrics.binary_accuracy, metrics.categorical_accuracy):
        per_sample = K.eval(metric_fn(lhs, rhs))
        self.assertEqual(per_sample.shape, (6,))

  def test_sparse_categorical_accuracy_int(self):
    """Sparse categorical accuracy for several label/prediction layouts."""
    accuracy = metrics.sparse_categorical_accuracy
    with self.cached_session():
      # Shape-only check: random class ids in [0, 7) against random scores.
      labels = K.variable(np.random.randint(0, 7, (6,)))
      scores = K.variable(np.random.random((6, 7)))
      self.assertEqual(K.eval(accuracy(labels, scores)).shape, (6,))

      # Column 0 holds the larger score in every row below, so only the rows
      # labelled 0 are expected to count as hits.
      two_class_rows = [[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]
      expected_hits = [0., 1., 1., 1.]

      # Labels given with shape (num_samples,).
      labels = K.variable([1., 0., 0., 0.])
      scores = K.variable(two_class_rows)
      self.assertAllEqual(K.eval(accuracy(labels, scores)), expected_hits)

      # Labels given with shape (num_samples, 1).
      labels = K.variable([[1.], [0.], [0.], [0.]])
      scores = K.variable(two_class_rows)
      self.assertAllEqual(K.eval(accuracy(labels, scores)), expected_hits)

      # Labels of shape (batch_size, seq_length) with predictions of shape
      # (batch_size, seq_length, num_classes).
      sequence_scores = np.array([[[0.2, 0.3, 0.1], [0.1, 0.2, 0.7]],
                                  [[0.3, 0.2, 0.1], [0.7, 0.2, 0.1]]])
      scores = K.variable(sequence_scores)
      labels = K.variable(np.array([[1, 0], [1, 0]]))
      observed = K.eval(accuracy(labels, scores))
      self.assertAllEqual(observed, [[1., 0.], [0., 1.]])

  def test_sparse_categorical_accuracy_float(self):
    """Random float labels of shape (6,) still yield a (6,) result."""
    accuracy = metrics.sparse_categorical_accuracy
    with self.cached_session():
      labels = K.variable(np.random.random((6,)))
      scores = K.variable(np.random.random((6, 7)))
      observed = K.eval(accuracy(labels, scores))
      self.assertEqual(observed.shape, (6,))

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_sparse_categorical_accuracy_eager(self):
    """Integer NumPy inputs produce results in eager mode. See b/113504761."""
    labels = np.arange(6).reshape([6, 1])
    # Each row of arange(36).reshape(6, 6) is increasing, so the last column
    # (index 5) is the largest entry; only the row labelled 5 should match.
    scores = np.arange(36).reshape([6, 6])
    observed = metrics.sparse_categorical_accuracy(labels, scores)
    self.assertAllEqual(observed, [0., 0., 0., 0., 0., 1.])

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_sparse_categorical_accuracy_float_eager(self):
    """float32 NumPy labels produce results in eager mode. See b/113504761."""
    labels = np.arange(6, dtype=np.float32).reshape([6, 1])
    # Same increasing-row score matrix as the integer eager test.
    scores = np.arange(36).reshape([6, 6])
    observed = metrics.sparse_categorical_accuracy(labels, scores)
    self.assertAllEqual(observed, [0., 0., 0., 0., 0., 1.])

  def test_sparse_top_k_categorical_accuracy(self):
    """Mean sparse top-k accuracy for k = 3, 2, 1 over three label layouts."""
    top_k_accuracy = metrics.sparse_top_k_categorical_accuracy
    # (k, expected mean accuracy) pairs; every layout below is built so that
    # the same three expectations hold.
    expected_mean_by_k = ((3, 1), (2, 0.5), (1, 0.))

    with self.cached_session():
      # Labels given with shape (num_samples, 1).
      scores = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
      labels = K.variable(np.array([[1], [0]]))
      for k, expected_mean in expected_mean_by_k:
        hits = K.eval(top_k_accuracy(labels, scores, k=k))
        self.assertEqual(np.mean(hits), expected_mean)

      # Labels given with shape (num_samples,).
      scores = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
      labels = K.variable(np.array([1, 0]))
      for k, expected_mean in expected_mean_by_k:
        hits = K.eval(top_k_accuracy(labels, scores, k=k))
        self.assertEqual(np.mean(hits), expected_mean)

      # Labels of shape (batch_size, seq_length) with predictions of shape
      # (batch_size, seq_length, num_classes).
      sequence_scores = np.array(
          [[[0.3, 0.2, 0.1], [0.1, 0.2, 0.7], [0.1, 0.2, 0.7]],
           [[0.3, 0.2, 0.1], [0.1, 0.2, 0.7], [0.3, 0.2, 0.1]]])
      scores = K.variable(sequence_scores)
      labels = K.variable(np.array([[1, 0, 0], [1, 0, 1]]))
      for k, expected_mean in expected_mean_by_k:
        hits = K.eval(top_k_accuracy(labels, scores, k=k))
        self.assertEqual(np.mean(hits), expected_mean)

  def test_top_k_categorical_accuracy(self):
    """Mean top-k accuracy for k = 3, 2, 1 with one-hot labels."""
    top_k_accuracy = metrics.top_k_categorical_accuracy
    with self.cached_session():
      scores = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
      one_hot_labels = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
      for k, expected_mean in ((3, 1), (2, 0.5), (1, 0.)):
        hits = K.eval(top_k_accuracy(one_hot_labels, scores, k=k))
        self.assertEqual(np.mean(hits), expected_mean)
if __name__ == '__main__':
  # Run the tests in this module only when the file is executed as a script,
  # not when it is imported.
  test.main()