Update tests under Keras to use combinations.

1. Change all test_util.run_all_in_graph_and_eager_modes to combinations.
2. Replace `from tensorflow.python import keras` with explicit module imports.
3. Update the BUILD file to not rely on the overall Keras target.

PiperOrigin-RevId: 299403937
Change-Id: Ic798fd0cc2602b6447aaf28922a4a418514d7131
Parent: c4d9a3a647
Commit: c8823160e0
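The core migration pattern, shown as a minimal before/after sketch (the class and test names below are illustrative, not taken from the change; the sketch assumes the tensorflow.python.keras.combinations module introduced for this migration):

# Before: a single test_util decorator reran every test method once in
# graph mode and once in eager mode.
#
#   @test_util.run_all_in_graph_and_eager_modes
#   class MyKerasTest(test.TestCase):
#     ...
#
# After: the Keras combinations API generates the same two runs as
# parameterized test cases, so the class also mixes in
# parameterized.TestCase.
from absl.testing import parameterized

from tensorflow.python.keras import combinations
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MyKerasTest(test.TestCase, parameterized.TestCase):  # illustrative name

  def test_something(self):
    self.assertEqual(1 + 1, 2)  # placeholder body


if __name__ == '__main__':
  test.main()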
@@ -94,6 +94,7 @@ py_library(
    name = "backend_config",
    srcs = ["backend_config.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow/python:util"],
)

# TODO(scottzhu): Cleanup this target and point all the user to keras/engine.

@@ -322,9 +323,14 @@ tf_py_test(
    srcs = ["activations_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        ":activations",
        ":backend",
        ":combinations",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/keras/layers:advanced_activations",
        "//tensorflow/python/keras/layers:core",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],

@@ -352,10 +358,11 @@ tf_py_test(
    srcs = ["constraints_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        ":backend",
        ":combinations",
        ":constraints",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -365,11 +372,17 @@ tf_py_test(
    srcs = ["initializers_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        ":backend",
        ":combinations",
        ":initializers",
        ":models",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:tf2",
        "//tensorflow/python/keras/engine",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -407,10 +420,17 @@ tf_py_test(
    srcs = ["losses_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        ":backend",
        ":combinations",
        ":losses",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/keras/utils:engine_utils",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -433,10 +453,27 @@ tf_py_test(
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":combinations",
        ":keras",
        ":metrics",
        ":testing_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python:weights_broadcast_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
        "//tensorflow/python/training/tracking:util",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -447,8 +484,17 @@ tf_py_test(
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        ":combinations",
        ":metrics",
        ":models",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/keras/utils:metrics_utils",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],

@@ -527,9 +573,22 @@ tf_py_test(
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        ":backend",
        ":combinations",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:config",
        "//tensorflow/python:errors",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:nn",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],

@@ -541,10 +600,10 @@ tf_py_test(
    srcs = ["backend_config_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        ":backend",
        ":backend_config",
        ":combinations",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
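Every BUILD edit above follows one pattern: the umbrella ":keras" dependency is dropped in favor of the exact modules the test imports, plus the new ":combinations" target. A minimal sketch with a hypothetical test target (the name "foo_test" and the particular deps are illustrative):

tf_py_test(
    name = "foo_test",  # hypothetical target name
    srcs = ["foo_test.py"],
    python_version = "PY3",
    deps = [
        # Fine-grained deps replace the catch-all ":keras" target.
        ":backend",
        ":combinations",  # needed by the new decorator
        "//tensorflow/python:client_testlib",
    ],
)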
@@ -18,10 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import serialization
from tensorflow.python.ops import nn_ops as nn
from tensorflow.python.platform import test

@@ -32,34 +37,34 @@ def _ref_softmax(values):
  return e / np.sum(e)


@test_util.run_all_in_graph_and_eager_modes
class KerasActivationsTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasActivationsTest(test.TestCase, parameterized.TestCase):

  def test_serialization(self):
    all_activations = ['softmax', 'relu', 'elu', 'tanh',
                       'sigmoid', 'hard_sigmoid', 'linear',
                       'softplus', 'softsign', 'selu']
    for name in all_activations:
      fn = keras.activations.get(name)
      ref_fn = getattr(keras.activations, name)
      fn = activations.get(name)
      ref_fn = getattr(activations, name)
      assert fn == ref_fn
      config = keras.activations.serialize(fn)
      fn = keras.activations.deserialize(config)
      config = activations.serialize(fn)
      fn = activations.deserialize(config)
      assert fn == ref_fn

  def test_serialization_v2(self):
    activation_map = {nn.softmax_v2: 'softmax'}
    for fn_v2_key in activation_map:
      fn_v2 = keras.activations.get(fn_v2_key)
      config = keras.activations.serialize(fn_v2)
      fn = keras.activations.deserialize(config)
      fn_v2 = activations.get(fn_v2_key)
      config = activations.serialize(fn_v2)
      fn = activations.deserialize(config)
      assert fn.__name__ == activation_map[fn_v2_key]

  def test_serialization_with_layers(self):
    activation = keras.layers.LeakyReLU(alpha=0.1)
    layer = keras.layers.Dense(3, activation=activation)
    config = keras.layers.serialize(layer)
    deserialized_layer = keras.layers.deserialize(
    activation = advanced_activations.LeakyReLU(alpha=0.1)
    layer = core.Dense(3, activation=activation)
    config = serialization.serialize(layer)
    deserialized_layer = serialization.deserialize(
        config, custom_objects={'LeakyReLU': activation})
    self.assertEqual(deserialized_layer.__class__.__name__,
                     layer.__class__.__name__)

@@ -67,8 +72,8 @@ class KerasActivationsTest(test.TestCase):
                     activation.__class__.__name__)

  def test_softmax(self):
    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.softmax(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.softmax(x)])
    test_values = np.random.random((2, 5))

    result = f([test_values])[0]

@@ -76,28 +81,28 @@ class KerasActivationsTest(test.TestCase):
    self.assertAllClose(result[0], expected, rtol=1e-05)

    with self.assertRaises(ValueError):
      x = keras.backend.placeholder(ndim=1)
      keras.activations.softmax(x)
      x = backend.placeholder(ndim=1)
      activations.softmax(x)

  def test_temporal_softmax(self):
    x = keras.backend.placeholder(shape=(2, 2, 3))
    f = keras.backend.function([x], [keras.activations.softmax(x)])
    x = backend.placeholder(shape=(2, 2, 3))
    f = backend.function([x], [activations.softmax(x)])
    test_values = np.random.random((2, 2, 3)) * 10
    result = f([test_values])[0]
    expected = _ref_softmax(test_values[0, 0])
    self.assertAllClose(result[0, 0], expected, rtol=1e-05)

  def test_selu(self):
    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.selu(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.selu(x)])
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    positive_values = np.array([[1, 2]], dtype=keras.backend.floatx())
    positive_values = np.array([[1, 2]], dtype=backend.floatx())
    result = f([positive_values])[0]
    self.assertAllClose(result, positive_values * scale, rtol=1e-05)

    negative_values = np.array([[-1, -2]], dtype=keras.backend.floatx())
    negative_values = np.array([[-1, -2]], dtype=backend.floatx())
    result = f([negative_values])[0]
    true_result = (np.exp(negative_values) - 1) * scale * alpha
    self.assertAllClose(result, true_result)

@@ -106,8 +111,8 @@ class KerasActivationsTest(test.TestCase):
    def softplus(x):
      return np.log(np.ones_like(x) + np.exp(x))

    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.softplus(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.softplus(x)])
    test_values = np.random.random((2, 5))
    result = f([test_values])[0]
    expected = softplus(test_values)

@@ -117,8 +122,8 @@ class KerasActivationsTest(test.TestCase):
    def softsign(x):
      return np.divide(x, np.ones_like(x) + np.absolute(x))

    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.softsign(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.softsign(x)])
    test_values = np.random.random((2, 5))
    result = f([test_values])[0]
    expected = softsign(test_values)

@@ -133,8 +138,8 @@ class KerasActivationsTest(test.TestCase):
      return z / (1 + z)
    sigmoid = np.vectorize(ref_sigmoid)

    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.sigmoid(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.sigmoid(x)])
    test_values = np.random.random((2, 5))
    result = f([test_values])[0]
    expected = sigmoid(test_values)

@@ -146,16 +151,16 @@ class KerasActivationsTest(test.TestCase):
      z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
      return z
    hard_sigmoid = np.vectorize(ref_hard_sigmoid)
    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.hard_sigmoid(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.hard_sigmoid(x)])
    test_values = np.random.random((2, 5))
    result = f([test_values])[0]
    expected = hard_sigmoid(test_values)
    self.assertAllClose(result, expected, rtol=1e-05)

  def test_relu(self):
    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.relu(x)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.relu(x)])
    positive_values = np.random.random((2, 5))
    result = f([positive_values])[0]
    self.assertAllClose(result, positive_values, rtol=1e-05)

@@ -166,44 +171,45 @@ class KerasActivationsTest(test.TestCase):
    self.assertAllClose(result, expected, rtol=1e-05)

  def test_elu(self):
    x = keras.backend.placeholder(ndim=2)
    f = keras.backend.function([x], [keras.activations.elu(x, 0.5)])
    x = backend.placeholder(ndim=2)
    f = backend.function([x], [activations.elu(x, 0.5)])
    test_values = np.random.random((2, 5))
    result = f([test_values])[0]
    self.assertAllClose(result, test_values, rtol=1e-05)
    negative_values = np.array([[-1, -2]], dtype=keras.backend.floatx())
    negative_values = np.array([[-1, -2]], dtype=backend.floatx())
    result = f([negative_values])[0]
    true_result = (np.exp(negative_values) - 1) / 2
    self.assertAllClose(result, true_result)

  def test_tanh(self):
    test_values = np.random.random((2, 5))
    x = keras.backend.placeholder(ndim=2)
    exp = keras.activations.tanh(x)
    f = keras.backend.function([x], [exp])
    x = backend.placeholder(ndim=2)
    exp = activations.tanh(x)
    f = backend.function([x], [exp])
    result = f([test_values])[0]
    expected = np.tanh(test_values)
    self.assertAllClose(result, expected, rtol=1e-05)

  def test_exponential(self):
    test_values = np.random.random((2, 5))
    x = keras.backend.placeholder(ndim=2)
    exp = keras.activations.exponential(x)
    f = keras.backend.function([x], [exp])
    x = backend.placeholder(ndim=2)
    exp = activations.exponential(x)
    f = backend.function([x], [exp])
    result = f([test_values])[0]
    expected = np.exp(test_values)
    self.assertAllClose(result, expected, rtol=1e-05)

  def test_linear(self):
    x = np.random.random((10, 5))
    self.assertAllClose(x, keras.activations.linear(x))
    self.assertAllClose(x, activations.linear(x))

  def test_invalid_usage(self):
    with self.assertRaises(ValueError):
      keras.activations.get('unknown')
      activations.get('unknown')

    # The following should be possible but should raise a warning:
    keras.activations.get(keras.layers.LeakyReLU())
    activations.get(advanced_activations.LeakyReLU())


if __name__ == '__main__':
  test.main()
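The import change above is mechanical; each `keras.<module>` attribute access becomes a direct module import, as in this sketch drawn from the diff:

# Before: the umbrella package import.
#   from tensorflow.python import keras
#   fn = keras.activations.get('relu')
#
# After: an explicit module import.
from tensorflow.python.keras import activations

fn = activations.get('relu')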
@@ -17,38 +17,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import backend_config
from tensorflow.python.keras import combinations
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendConfigTest(test.TestCase):

  def test_backend(self):
    self.assertEqual(keras.backend.backend(), 'tensorflow')
    self.assertEqual(backend.backend(), 'tensorflow')

  def test_epsilon(self):
    epsilon = 1e-2
    keras.backend_config.set_epsilon(epsilon)
    self.assertEqual(keras.backend_config.epsilon(), epsilon)
    keras.backend_config.set_epsilon(1e-7)
    self.assertEqual(keras.backend_config.epsilon(), 1e-7)
    backend_config.set_epsilon(epsilon)
    self.assertEqual(backend_config.epsilon(), epsilon)
    backend_config.set_epsilon(1e-7)
    self.assertEqual(backend_config.epsilon(), 1e-7)

  def test_floatx(self):
    floatx = 'float64'
    keras.backend_config.set_floatx(floatx)
    self.assertEqual(keras.backend_config.floatx(), floatx)
    keras.backend_config.set_floatx('float32')
    self.assertEqual(keras.backend_config.floatx(), 'float32')
    backend_config.set_floatx(floatx)
    self.assertEqual(backend_config.floatx(), floatx)
    backend_config.set_floatx('float32')
    self.assertEqual(backend_config.floatx(), 'float32')

  def test_image_data_format(self):
    image_data_format = 'channels_first'
    keras.backend_config.set_image_data_format(image_data_format)
    self.assertEqual(keras.backend_config.image_data_format(),
                     image_data_format)
    keras.backend_config.set_image_data_format('channels_last')
    self.assertEqual(keras.backend_config.image_data_format(), 'channels_last')
    backend_config.set_image_data_format(image_data_format)
    self.assertEqual(backend_config.image_data_format(), image_data_format)
    backend_config.set_image_data_format('channels_last')
    self.assertEqual(backend_config.image_data_format(), 'channels_last')


if __name__ == '__main__':
(File diff suppressed because it is too large.)
@@ -22,8 +22,9 @@ import math

import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import constraints
from tensorflow.python.platform import test


@@ -44,47 +45,45 @@ def get_example_kernel(width):
  return example_array


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasConstraintsTest(test.TestCase):

  def test_serialization(self):
    all_activations = ['max_norm', 'non_neg',
                       'unit_norm', 'min_max_norm']
    for name in all_activations:
      fn = keras.constraints.get(name)
      ref_fn = getattr(keras.constraints, name)()
      fn = constraints.get(name)
      ref_fn = getattr(constraints, name)()
      assert fn.__class__ == ref_fn.__class__
      config = keras.constraints.serialize(fn)
      fn = keras.constraints.deserialize(config)
      config = constraints.serialize(fn)
      fn = constraints.deserialize(config)
      assert fn.__class__ == ref_fn.__class__

  def test_max_norm(self):
    array = get_example_array()
    for m in get_test_values():
      norm_instance = keras.constraints.max_norm(m)
      normed = norm_instance(keras.backend.variable(array))
      assert np.all(keras.backend.eval(normed) < m)
      norm_instance = constraints.max_norm(m)
      normed = norm_instance(backend.variable(array))
      assert np.all(backend.eval(normed) < m)

    # a more explicit example
    norm_instance = keras.constraints.max_norm(2.0)
    norm_instance = constraints.max_norm(2.0)
    x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T
    x_normed_target = np.array(
        [[0, 0, 0], [1.0, 0, 0], [2.0, 0, 0],
         [2. / np.sqrt(3), 2. / np.sqrt(3), 2. / np.sqrt(3)]]).T
    x_normed_actual = keras.backend.eval(
        norm_instance(keras.backend.variable(x)))
    x_normed_actual = backend.eval(norm_instance(backend.variable(x)))
    self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05)

  def test_non_neg(self):
    non_neg_instance = keras.constraints.non_neg()
    normed = non_neg_instance(keras.backend.variable(get_example_array()))
    assert np.all(np.min(keras.backend.eval(normed), axis=1) == 0.)
    non_neg_instance = constraints.non_neg()
    normed = non_neg_instance(backend.variable(get_example_array()))
    assert np.all(np.min(backend.eval(normed), axis=1) == 0.)

  def test_unit_norm(self):
    unit_norm_instance = keras.constraints.unit_norm()
    normalized = unit_norm_instance(keras.backend.variable(get_example_array()))
    norm_of_normalized = np.sqrt(
        np.sum(keras.backend.eval(normalized)**2, axis=0))
    unit_norm_instance = constraints.unit_norm()
    normalized = unit_norm_instance(backend.variable(get_example_array()))
    norm_of_normalized = np.sqrt(np.sum(backend.eval(normalized)**2, axis=0))
    # In the unit norm constraint, it should be equal to 1.
    difference = norm_of_normalized - 1.
    largest_difference = np.max(np.abs(difference))

@@ -93,10 +92,9 @@ class KerasConstraintsTest(test.TestCase):
  def test_min_max_norm(self):
    array = get_example_array()
    for m in get_test_values():
      norm_instance = keras.constraints.min_max_norm(
          min_value=m, max_value=m * 2)
      normed = norm_instance(keras.backend.variable(array))
      value = keras.backend.eval(normed)
      norm_instance = constraints.min_max_norm(min_value=m, max_value=m * 2)
      normed = norm_instance(backend.variable(array))
      value = backend.eval(normed)
      l2 = np.sqrt(np.sum(np.square(value), axis=0))
      assert not l2[l2 < m]
      assert not l2[l2 > m * 2 + 1e-5]

@@ -104,9 +102,9 @@ class KerasConstraintsTest(test.TestCase):
  def test_conv2d_radial_constraint(self):
    for width in (3, 4, 5, 6):
      array = get_example_kernel(width)
      norm_instance = keras.constraints.radial_constraint()
      normed = norm_instance(keras.backend.variable(array))
      value = keras.backend.eval(normed)
      norm_instance = constraints.radial_constraint()
      normed = norm_instance(backend.variable(array))
      value = backend.eval(normed)
      assert np.all(value.shape == array.shape)
      assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0])
      assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2)
@@ -20,33 +20,38 @@ from __future__ import print_function

import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasInitializersTest(test.TestCase):

  def _runner(self, init, shape, target_mean=None, target_std=None,
              target_max=None, target_min=None):
    variable = keras.backend.variable(init(shape))
    output = keras.backend.get_value(variable)
    variable = backend.variable(init(shape))
    output = backend.get_value(variable)
    # Test serialization (assumes deterministic behavior).
    config = init.get_config()
    reconstructed_init = init.__class__.from_config(config)
    variable = keras.backend.variable(reconstructed_init(shape))
    output_2 = keras.backend.get_value(variable)
    variable = backend.variable(reconstructed_init(shape))
    output_2 = backend.get_value(variable)
    self.assertAllClose(output, output_2, atol=1e-4)

  def test_uniform(self):
    tensor_shape = (9, 6, 7)
    with self.cached_session():
      self._runner(
          keras.initializers.RandomUniformV2(minval=-1, maxval=1, seed=124),
          initializers.RandomUniformV2(minval=-1, maxval=1, seed=124),
          tensor_shape,
          target_mean=0.,
          target_max=1,

@@ -56,7 +61,7 @@ class KerasInitializersTest(test.TestCase):
    tensor_shape = (8, 12, 99)
    with self.cached_session():
      self._runner(
          keras.initializers.RandomNormalV2(mean=0, stddev=1, seed=153),
          initializers.RandomNormalV2(mean=0, stddev=1, seed=153),
          tensor_shape,
          target_mean=0.,
          target_std=1)

@@ -65,7 +70,7 @@ class KerasInitializersTest(test.TestCase):
    tensor_shape = (12, 99, 7)
    with self.cached_session():
      self._runner(
          keras.initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126),
          initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126),
          tensor_shape,
          target_mean=0.,
          target_max=2,

@@ -75,7 +80,7 @@ class KerasInitializersTest(test.TestCase):
    tensor_shape = (5, 6, 4)
    with self.cached_session():
      self._runner(
          keras.initializers.ConstantV2(2.),
          initializers.ConstantV2(2.),
          tensor_shape,
          target_mean=2,
          target_max=2,

@@ -87,7 +92,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, _ = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(1. / fan_in)
      self._runner(
          keras.initializers.lecun_uniformV2(seed=123),
          initializers.lecun_uniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -98,7 +103,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(2. / (fan_in + fan_out))
      self._runner(
          keras.initializers.GlorotUniformV2(seed=123),
          initializers.GlorotUniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -109,7 +114,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, _ = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(2. / fan_in)
      self._runner(
          keras.initializers.he_uniformV2(seed=123),
          initializers.he_uniformV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -120,7 +125,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, _ = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(1. / fan_in)
      self._runner(
          keras.initializers.lecun_normalV2(seed=123),
          initializers.lecun_normalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -131,7 +136,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(2. / (fan_in + fan_out))
      self._runner(
          keras.initializers.GlorotNormalV2(seed=123),
          initializers.GlorotNormalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -142,7 +147,7 @@ class KerasInitializersTest(test.TestCase):
      fan_in, _ = init_ops._compute_fans(tensor_shape)
      std = np.sqrt(2. / fan_in)
      self._runner(
          keras.initializers.he_normalV2(seed=123),
          initializers.he_normalV2(seed=123),
          tensor_shape,
          target_mean=0.,
          target_std=std)

@@ -151,23 +156,21 @@ class KerasInitializersTest(test.TestCase):
    tensor_shape = (20, 20)
    with self.cached_session():
      self._runner(
          keras.initializers.OrthogonalV2(seed=123),
          tensor_shape,
          target_mean=0.)
          initializers.OrthogonalV2(seed=123), tensor_shape, target_mean=0.)

  def test_identity(self):
    with self.cached_session():
      tensor_shape = (3, 4, 5)
      with self.assertRaises(ValueError):
        self._runner(
            keras.initializers.IdentityV2(),
            initializers.IdentityV2(),
            tensor_shape,
            target_mean=1. / tensor_shape[0],
            target_max=1.)

      tensor_shape = (3, 3)
      self._runner(
          keras.initializers.IdentityV2(),
          initializers.IdentityV2(),
          tensor_shape,
          target_mean=1. / tensor_shape[0],
          target_max=1.)

@@ -176,32 +179,26 @@ class KerasInitializersTest(test.TestCase):
    tensor_shape = (4, 5)
    with self.cached_session():
      self._runner(
          keras.initializers.ZerosV2(),
          tensor_shape,
          target_mean=0.,
          target_max=0.)
          initializers.ZerosV2(), tensor_shape, target_mean=0., target_max=0.)

  def test_one(self):
    tensor_shape = (4, 5)
    with self.cached_session():
      self._runner(
          keras.initializers.OnesV2(),
          tensor_shape,
          target_mean=1.,
          target_max=1.)
          initializers.OnesV2(), tensor_shape, target_mean=1., target_max=1.)

  def test_default_random_uniform(self):
    ru = keras.initializers.get('uniform')
    ru = initializers.get('uniform')
    self.assertEqual(ru.minval, -0.05)
    self.assertEqual(ru.maxval, 0.05)

  def test_default_random_normal(self):
    rn = keras.initializers.get('normal')
    rn = initializers.get('normal')
    self.assertEqual(rn.mean, 0.0)
    self.assertEqual(rn.stddev, 0.05)

  def test_default_truncated_normal(self):
    tn = keras.initializers.get('truncated_normal')
    tn = initializers.get('truncated_normal')
    self.assertEqual(tn.mean, 0.0)
    self.assertEqual(tn.stddev, 0.05)

@@ -209,7 +206,7 @@ class KerasInitializersTest(test.TestCase):
    tf2_force_enabled = tf2._force_enable  # pylint: disable=protected-access
    try:
      tf2.enable()
      rn = keras.initializers.get('random_normal')
      rn = initializers.get('random_normal')
      self.assertIn('init_ops_v2', rn.__class__.__module__)
    finally:
      tf2._force_enable = tf2_force_enabled  # pylint: disable=protected-access

@@ -219,9 +216,9 @@ class KerasInitializersTest(test.TestCase):
    def my_initializer(shape, dtype=None):
      return array_ops.ones(shape, dtype=dtype)

    inputs = keras.Input((10,))
    outputs = keras.layers.Dense(1, kernel_initializer=my_initializer)(inputs)
    model = keras.Model(inputs, outputs)
    inputs = input_layer.Input((10,))
    outputs = core.Dense(1, kernel_initializer=my_initializer)(inputs)
    model = models.Model(inputs, outputs)
    model2 = model.from_config(
        model.get_config(), custom_objects={'my_initializer': my_initializer})
    self.assertEqual(model2.layers[1].kernel_initializer, my_initializer)

@@ -237,7 +234,7 @@ class KerasInitializersTest(test.TestCase):
            'seed': None
        }
    }
    initializer = keras.initializers.deserialize(external_serialized_json)
    initializer = initializers.deserialize(external_serialized_json)
    self.assertEqual(initializer.distribution, 'truncated_normal')
(File diff suppressed because it is too large.)
@@ -26,7 +26,7 @@ from scipy.special import expit

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics
from tensorflow.python.keras import models

@@ -37,19 +37,19 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class FalsePositivesTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FalsePositivesTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    fp_obj = metrics.FalsePositives(name='my_fp', thresholds=[0.4, 0.9])
    self.assertEqual(fp_obj.name, 'my_fp')
    self.assertEqual(len(fp_obj.variables), 1)
    self.assertLen(fp_obj.variables, 1)
    self.assertEqual(fp_obj.thresholds, [0.4, 0.9])

    # Check save and restore config
    fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config())
    self.assertEqual(fp_obj2.name, 'my_fp')
    self.assertEqual(len(fp_obj2.variables), 1)
    self.assertLen(fp_obj2.variables, 1)
    self.assertEqual(fp_obj2.thresholds, [0.4, 0.9])

  def test_unweighted(self):

@@ -117,19 +117,19 @@ class FalsePositivesTest(test.TestCase):
      metrics.FalsePositives(thresholds=[None])


@test_util.run_all_in_graph_and_eager_modes
class FalseNegativesTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FalseNegativesTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    fn_obj = metrics.FalseNegatives(name='my_fn', thresholds=[0.4, 0.9])
    self.assertEqual(fn_obj.name, 'my_fn')
    self.assertEqual(len(fn_obj.variables), 1)
    self.assertLen(fn_obj.variables, 1)
    self.assertEqual(fn_obj.thresholds, [0.4, 0.9])

    # Check save and restore config
    fn_obj2 = metrics.FalseNegatives.from_config(fn_obj.get_config())
    self.assertEqual(fn_obj2.name, 'my_fn')
    self.assertEqual(len(fn_obj2.variables), 1)
    self.assertLen(fn_obj2.variables, 1)
    self.assertEqual(fn_obj2.thresholds, [0.4, 0.9])

  def test_unweighted(self):

@@ -185,19 +185,19 @@ class FalseNegativesTest(test.TestCase):
    self.assertAllClose([4., 16., 23.], self.evaluate(result))


@test_util.run_all_in_graph_and_eager_modes
class TrueNegativesTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TrueNegativesTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    tn_obj = metrics.TrueNegatives(name='my_tn', thresholds=[0.4, 0.9])
    self.assertEqual(tn_obj.name, 'my_tn')
    self.assertEqual(len(tn_obj.variables), 1)
    self.assertLen(tn_obj.variables, 1)
    self.assertEqual(tn_obj.thresholds, [0.4, 0.9])

    # Check save and restore config
    tn_obj2 = metrics.TrueNegatives.from_config(tn_obj.get_config())
    self.assertEqual(tn_obj2.name, 'my_tn')
    self.assertEqual(len(tn_obj2.variables), 1)
    self.assertLen(tn_obj2.variables, 1)
    self.assertEqual(tn_obj2.thresholds, [0.4, 0.9])

  def test_unweighted(self):

@@ -253,19 +253,19 @@ class TrueNegativesTest(test.TestCase):
    self.assertAllClose([5., 15., 23.], self.evaluate(result))


@test_util.run_all_in_graph_and_eager_modes
class TruePositivesTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TruePositivesTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    tp_obj = metrics.TruePositives(name='my_tp', thresholds=[0.4, 0.9])
    self.assertEqual(tp_obj.name, 'my_tp')
    self.assertEqual(len(tp_obj.variables), 1)
    self.assertLen(tp_obj.variables, 1)
    self.assertEqual(tp_obj.thresholds, [0.4, 0.9])

    # Check save and restore config
    tp_obj2 = metrics.TruePositives.from_config(tp_obj.get_config())
    self.assertEqual(tp_obj2.name, 'my_tp')
    self.assertEqual(len(tp_obj2.variables), 1)
    self.assertLen(tp_obj2.variables, 1)
    self.assertEqual(tp_obj2.thresholds, [0.4, 0.9])

  def test_unweighted(self):

@@ -320,14 +320,14 @@ class TruePositivesTest(test.TestCase):
    self.assertAllClose([222., 111., 37.], self.evaluate(result))


@test_util.run_all_in_graph_and_eager_modes
class PrecisionTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class PrecisionTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    p_obj = metrics.Precision(
        name='my_precision', thresholds=[0.4, 0.9], top_k=15, class_id=12)
    self.assertEqual(p_obj.name, 'my_precision')
    self.assertEqual(len(p_obj.variables), 2)
    self.assertLen(p_obj.variables, 2)
    self.assertEqual([v.name for v in p_obj.variables],
                     ['true_positives:0', 'false_positives:0'])
    self.assertEqual(p_obj.thresholds, [0.4, 0.9])

@@ -337,7 +337,7 @@ class PrecisionTest(test.TestCase):
    # Check save and restore config
    p_obj2 = metrics.Precision.from_config(p_obj.get_config())
    self.assertEqual(p_obj2.name, 'my_precision')
    self.assertEqual(len(p_obj2.variables), 2)
    self.assertLen(p_obj2.variables, 2)
    self.assertEqual(p_obj2.thresholds, [0.4, 0.9])
    self.assertEqual(p_obj2.top_k, 15)
    self.assertEqual(p_obj2.class_id, 12)

@@ -525,14 +525,14 @@ class PrecisionTest(test.TestCase):
    self.assertAlmostEqual(0, self.evaluate(p_obj.false_positives))


@test_util.run_all_in_graph_and_eager_modes
class RecallTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RecallTest(test.TestCase, parameterized.TestCase):

  def test_config(self):
    r_obj = metrics.Recall(
        name='my_recall', thresholds=[0.4, 0.9], top_k=15, class_id=12)
    self.assertEqual(r_obj.name, 'my_recall')
    self.assertEqual(len(r_obj.variables), 2)
    self.assertLen(r_obj.variables, 2)
    self.assertEqual([v.name for v in r_obj.variables],
                     ['true_positives:0', 'false_negatives:0'])
    self.assertEqual(r_obj.thresholds, [0.4, 0.9])

@@ -542,7 +542,7 @@ class RecallTest(test.TestCase):
    # Check save and restore config
    r_obj2 = metrics.Recall.from_config(r_obj.get_config())
    self.assertEqual(r_obj2.name, 'my_recall')
    self.assertEqual(len(r_obj2.variables), 2)
    self.assertLen(r_obj2.variables, 2)
    self.assertEqual(r_obj2.thresholds, [0.4, 0.9])
    self.assertEqual(r_obj2.top_k, 15)
    self.assertEqual(r_obj2.class_id, 12)

@@ -729,7 +729,7 @@ class RecallTest(test.TestCase):
    self.assertAlmostEqual(3, self.evaluate(r_obj.false_negatives))


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):

  def test_config(self):

@@ -771,13 +771,14 @@ class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
        1e-3)

  def test_unweighted_all_correct(self):
    s_obj = metrics.SensitivityAtSpecificity(0.7)
    inputs = np.random.randint(0, 2, size=(100, 1))
    y_pred = constant_op.constant(inputs, dtype=dtypes.float32)
    y_true = constant_op.constant(inputs)
    self.evaluate(variables.variables_initializer(s_obj.variables))
    result = s_obj(y_true, y_pred)
    self.assertAlmostEqual(1, self.evaluate(result))
    with self.test_session():
      s_obj = metrics.SensitivityAtSpecificity(0.7)
      inputs = np.random.randint(0, 2, size=(100, 1))
      y_pred = constant_op.constant(inputs, dtype=dtypes.float32)
      y_true = constant_op.constant(inputs)
      self.evaluate(variables.variables_initializer(s_obj.variables))
      result = s_obj(y_true, y_pred)
      self.assertAlmostEqual(1, self.evaluate(result))

  def test_unweighted_high_specificity(self):
    s_obj = metrics.SensitivityAtSpecificity(0.8)

@@ -825,7 +826,7 @@ class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
      metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):

  def test_config(self):

@@ -921,7 +922,7 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
      metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):

  def test_config(self):

@@ -1018,7 +1019,7 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
      metrics.PrecisionAtRecall(0.4, num_thresholds=-1)


@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RecallAtPrecisionTest(test.TestCase, parameterized.TestCase):

  def test_config(self):

@@ -1133,8 +1134,8 @@ class RecallAtPrecisionTest(test.TestCase, parameterized.TestCase):
      metrics.RecallAtPrecision(0.4, num_thresholds=-1)


@test_util.run_all_in_graph_and_eager_modes
class AUCTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AUCTest(test.TestCase, parameterized.TestCase):

  def setup(self):
    self.num_thresholds = 3

@@ -1172,7 +1173,7 @@ class AUCTest(test.TestCase):
        name='auc_1')
    auc_obj.update_state(self.y_true, self.y_pred)
    self.assertEqual(auc_obj.name, 'auc_1')
    self.assertEqual(len(auc_obj.variables), 4)
    self.assertLen(auc_obj.variables, 4)
    self.assertEqual(auc_obj.num_thresholds, 100)
    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
    self.assertEqual(auc_obj.summation_method,

@@ -1184,7 +1185,7 @@ class AUCTest(test.TestCase):
    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
    auc_obj2.update_state(self.y_true, self.y_pred)
    self.assertEqual(auc_obj2.name, 'auc_1')
    self.assertEqual(len(auc_obj2.variables), 4)
    self.assertLen(auc_obj2.variables, 4)
    self.assertEqual(auc_obj2.num_thresholds, 100)
    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
    self.assertEqual(auc_obj2.summation_method,

@@ -1203,7 +1204,7 @@ class AUCTest(test.TestCase):
        thresholds=[0.3, 0.5])
    auc_obj.update_state(self.y_true, self.y_pred)
    self.assertEqual(auc_obj.name, 'auc_1')
    self.assertEqual(len(auc_obj.variables), 4)
    self.assertLen(auc_obj.variables, 4)
    self.assertEqual(auc_obj.num_thresholds, 4)
    self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)

@@ -1216,7 +1217,7 @@ class AUCTest(test.TestCase):
    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
    auc_obj2.update_state(self.y_true, self.y_pred)
    self.assertEqual(auc_obj2.name, 'auc_1')
    self.assertEqual(len(auc_obj2.variables), 4)
    self.assertLen(auc_obj2.variables, 4)
    self.assertEqual(auc_obj2.num_thresholds, 4)
    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
    self.assertEqual(auc_obj2.summation_method,

@@ -1407,8 +1408,8 @@ class AUCTest(test.TestCase):
    self.assertEqual(self.evaluate(result), 0.5)


@test_util.run_all_in_graph_and_eager_modes
class MultiAUCTest(test.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MultiAUCTest(test.TestCase, parameterized.TestCase):

  def setup(self):
    self.num_thresholds = 5

@@ -1457,26 +1458,28 @@ class MultiAUCTest(test.TestCase):
    # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]

  def test_value_is_idempotent(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=5, multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=5, multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))

    # Run several updates.
    update_op = auc_obj.update_state(self.y_true_good, self.y_pred)
    for _ in range(10):
      self.evaluate(update_op)
      # Run several updates.
      update_op = auc_obj.update_state(self.y_true_good, self.y_pred)
      for _ in range(10):
        self.evaluate(update_op)

    # Then verify idempotency.
    initial_auc = self.evaluate(auc_obj.result())
    for _ in range(10):
      self.assertAllClose(initial_auc, self.evaluate(auc_obj.result()), 1e-3)
      # Then verify idempotency.
      initial_auc = self.evaluate(auc_obj.result())
      for _ in range(10):
        self.assertAllClose(initial_auc, self.evaluate(auc_obj.result()), 1e-3)

  def test_unweighted_all_correct(self):
    self.setup()
    auc_obj = metrics.AUC(multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true_good, self.y_true_good)
    self.assertEqual(self.evaluate(result), 1)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      result = auc_obj(self.y_true_good, self.y_true_good)
      self.assertEqual(self.evaluate(result), 1)

  def test_unweighted_all_correct_flat(self):
    self.setup()

@@ -1486,15 +1489,17 @@ class MultiAUCTest(test.TestCase):
    self.assertEqual(self.evaluate(result), 1)

  def test_unweighted(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true_good, self.y_pred)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
                            multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      result = auc_obj(self.y_true_good, self.y_pred)

    # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
    # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
    expected_result = (0.875 + 1.0) / 2.0
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
      # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
      # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
      expected_result = (0.875 + 1.0) / 2.0
      self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

  def test_sample_weight_flat(self):
    self.setup()

@@ -1521,18 +1526,19 @@ class MultiAUCTest(test.TestCase):
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

  def test_label_weights(self):
    self.setup()
    auc_obj = metrics.AUC(
        num_thresholds=self.num_thresholds,
        multi_label=True,
        label_weights=[0.75, 0.25])
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true_good, self.y_pred)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(
          num_thresholds=self.num_thresholds,
          multi_label=True,
          label_weights=[0.75, 0.25])
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      result = auc_obj(self.y_true_good, self.y_pred)

    # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
    # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
    expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
      # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
      # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
      expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)
      self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

  def test_label_weights_flat(self):
    self.setup()

@@ -1565,65 +1571,72 @@ class MultiAUCTest(test.TestCase):
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

  def test_manual_thresholds(self):
    self.setup()
    # Verify that when specified, thresholds are used instead of num_thresholds.
    auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5], multi_label=True)
    self.assertEqual(auc_obj.num_thresholds, 3)
    self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true_good, self.y_pred)
    with self.test_session():
      self.setup()
      # Verify that when specified, thresholds are used instead of
      # num_thresholds.
      auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5],
                            multi_label=True)
      self.assertEqual(auc_obj.num_thresholds, 3)
      self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      result = auc_obj(self.y_true_good, self.y_pred)

    # tp = [[2, 1, 0], [2, 0, 0]]
    # fp = [2, 0, 0], [2, 0, 0]]
    # fn = [[0, 1, 2], [0, 2, 2]]
    # tn = [[0, 2, 2], [0, 2, 2]]
      # tp = [[2, 1, 0], [2, 0, 0]]
      # fp = [2, 0, 0], [2, 0, 0]]
      # fn = [[0, 1, 2], [0, 2, 2]]
      # tn = [[0, 2, 2], [0, 2, 2]]

    # tpr = [[1, 0.5, 0], [1, 0, 0]]
    # fpr = [[1, 0, 0], [1, 0, 0]]
      # tpr = [[1, 0.5, 0], [1, 0, 0]]
      # fpr = [[1, 0, 0], [1, 0, 0]]

    # auc by slice = [0.75, 0.5]
    expected_result = (0.75 + 0.5) / 2.0
      # auc by slice = [0.75, 0.5]
      expected_result = (0.75 + 0.5) / 2.0

    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
      self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

  def test_weighted_roc_interpolation(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(
        self.y_true_good, self.y_pred, sample_weight=self.sample_weight)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
                            multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      result = auc_obj(
          self.y_true_good, self.y_pred, sample_weight=self.sample_weight)

    # tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
    # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
    expected_result = 1.0 - 0.5 * 0.43 * 0.67
    self.assertAllClose(self.evaluate(result), expected_result, 1e-1)
      # tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
      # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
      expected_result = 1.0 - 0.5 * 0.43 * 0.67
      self.assertAllClose(self.evaluate(result), expected_result, 1e-1)

  def test_pr_interpolation_unweighted(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
                          multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    good_result = auc_obj(self.y_true_good, self.y_pred)
    with self.subTest(name='good'):
      # PR AUCs are 0.917 and 1.0 respectively
      self.assertAllClose(self.evaluate(good_result), (0.91667 + 1.0) / 2.0,
                          1e-1)
    bad_result = auc_obj(self.y_true_bad, self.y_pred)
    with self.subTest(name='bad'):
      # PR AUCs are 0.917 and 0.5 respectively
      self.assertAllClose(self.evaluate(bad_result), (0.91667 + 0.5) / 2.0,
                          1e-1)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
                            multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      good_result = auc_obj(self.y_true_good, self.y_pred)
      with self.subTest(name='good'):
        # PR AUCs are 0.917 and 1.0 respectively
        self.assertAllClose(self.evaluate(good_result), (0.91667 + 1.0) / 2.0,
                            1e-1)
      bad_result = auc_obj(self.y_true_bad, self.y_pred)
      with self.subTest(name='bad'):
        # PR AUCs are 0.917 and 0.5 respectively
        self.assertAllClose(self.evaluate(bad_result), (0.91667 + 0.5) / 2.0,
                            1e-1)

  def test_pr_interpolation(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
                          multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    good_result = auc_obj(self.y_true_good, self.y_pred,
                          sample_weight=self.sample_weight)
    # PR AUCs are 0.939 and 1.0 respectively
    self.assertAllClose(self.evaluate(good_result), (0.939 + 1.0) / 2.0,
                        1e-1)
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
                            multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      good_result = auc_obj(self.y_true_good, self.y_pred,
                            sample_weight=self.sample_weight)
      # PR AUCs are 0.939 and 1.0 respectively
      self.assertAllClose(self.evaluate(good_result), (0.939 + 1.0) / 2.0,
                          1e-1)

  def test_keras_model_compiles(self):
    inputs = layers.Input(shape=(10,))

@@ -1635,12 +1648,14 @@ class MultiAUCTest(test.TestCase):
    )

  def test_reset_states(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    auc_obj(self.y_true_good, self.y_pred)
    auc_obj.reset_states()
    self.assertAllEqual(auc_obj.true_positives, np.zeros((5, 2)))
    with self.test_session():
      self.setup()
      auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
                            multi_label=True)
      self.evaluate(variables.variables_initializer(auc_obj.variables))
      auc_obj(self.y_true_good, self.y_pred)
      auc_obj.reset_states()
      self.assertAllEqual(auc_obj.true_positives, np.zeros((5, 2)))


if __name__ == '__main__':
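Note the other recurring edit in the metrics diffs: test bodies gain a `with self.test_session():` wrapper. The old decorator entered a test session for the graph-mode run, while a combinations-generated graph-mode test gets no implicit session, so tests that call self.evaluate() now open one explicitly (and the wrapper is effectively a no-op under eager execution, which is why the same body runs in both modes). A minimal sketch, with the metric choice purely illustrative:

  def test_metric(self):
    with self.test_session():  # session scope in graph mode, benign in eager
      m = metrics.Sum()
      self.evaluate(variables.variables_initializer(m.variables))
      self.assertEqual(self.evaluate(m(100)), 100)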
@ -22,6 +22,7 @@ import json
|
||||
import math
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -32,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import metrics
|
||||
@@ -46,35 +48,36 @@ from tensorflow.python.platform import test
 from tensorflow.python.training.tracking import util as trackable_utils
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class KerasSumTest(test.TestCase):
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
+class KerasSumTest(test.TestCase, parameterized.TestCase):
 
   def test_sum(self):
-    m = metrics.Sum(name='my_sum')
+    with self.test_session():
+      m = metrics.Sum(name='my_sum')
 
-    # check config
-    self.assertEqual(m.name, 'my_sum')
-    self.assertTrue(m.stateful)
-    self.assertEqual(m.dtype, dtypes.float32)
-    self.assertEqual(len(m.variables), 1)
-    self.evaluate(variables.variables_initializer(m.variables))
+      # check config
+      self.assertEqual(m.name, 'my_sum')
+      self.assertTrue(m.stateful)
+      self.assertEqual(m.dtype, dtypes.float32)
+      self.assertLen(m.variables, 1)
+      self.evaluate(variables.variables_initializer(m.variables))
 
-    # check initial state
-    self.assertEqual(self.evaluate(m.total), 0)
+      # check initial state
+      self.assertEqual(self.evaluate(m.total), 0)
 
-    # check __call__()
-    self.assertEqual(self.evaluate(m(100)), 100)
-    self.assertEqual(self.evaluate(m.total), 100)
+      # check __call__()
+      self.assertEqual(self.evaluate(m(100)), 100)
+      self.assertEqual(self.evaluate(m.total), 100)
 
-    # check update_state() and result() + state accumulation + tensor input
-    update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
-    self.evaluate(update_op)
-    self.assertAlmostEqual(self.evaluate(m.result()), 106)
-    self.assertEqual(self.evaluate(m.total), 106)  # 100 + 1 + 5
+      # check update_state() and result() + state accumulation + tensor input
+      update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
+      self.evaluate(update_op)
+      self.assertAlmostEqual(self.evaluate(m.result()), 106)
+      self.assertEqual(self.evaluate(m.total), 106)  # 100 + 1 + 5
 
-    # check reset_states()
-    m.reset_states()
-    self.assertEqual(self.evaluate(m.total), 0)
+      # check reset_states()
+      m.reset_states()
+      self.assertEqual(self.evaluate(m.total), 0)
 
   def test_sum_with_sample_weight(self):
     m = metrics.Sum(dtype=dtypes.float64)
@@ -133,33 +136,34 @@ class KerasSumTest(test.TestCase):
     self.assertAlmostEqual(self.evaluate(m.total), 52., 2)
 
   def test_save_restore(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
-    m = metrics.Sum()
-    checkpoint = trackable_utils.Checkpoint(sum=m)
-    self.evaluate(variables.variables_initializer(m.variables))
+    with self.test_session():
+      checkpoint_directory = self.get_temp_dir()
+      checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+      m = metrics.Sum()
+      checkpoint = trackable_utils.Checkpoint(sum=m)
+      self.evaluate(variables.variables_initializer(m.variables))
 
-    # update state
-    self.evaluate(m(100.))
-    self.evaluate(m(200.))
+      # update state
+      self.evaluate(m(100.))
+      self.evaluate(m(200.))
 
-    # save checkpoint and then add an update
-    save_path = checkpoint.save(checkpoint_prefix)
-    self.evaluate(m(1000.))
+      # save checkpoint and then add an update
+      save_path = checkpoint.save(checkpoint_prefix)
+      self.evaluate(m(1000.))
 
-    # restore to the same checkpoint sum object (= 300)
-    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
-    self.evaluate(m(300.))
-    self.assertEqual(600., self.evaluate(m.result()))
+      # restore to the same checkpoint sum object (= 300)
+      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+      self.evaluate(m(300.))
+      self.assertEqual(600., self.evaluate(m.result()))
 
-    # restore to a different checkpoint sum object
-    restore_sum = metrics.Sum()
-    restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum)
-    status = restore_checkpoint.restore(save_path)
-    restore_update = restore_sum(300.)
-    status.assert_consumed().run_restore_ops()
-    self.evaluate(restore_update)
-    self.assertEqual(600., self.evaluate(restore_sum.result()))
+      # restore to a different checkpoint sum object
+      restore_sum = metrics.Sum()
+      restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum)
+      status = restore_checkpoint.restore(save_path)
+      restore_update = restore_sum(300.)
+      status.assert_consumed().run_restore_ops()
+      self.evaluate(restore_update)
+      self.assertEqual(600., self.evaluate(restore_sum.result()))
 
 
 class MeanTest(keras_parameterized.TestCase):
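
The save/restore test above relies on a Keras metric being a trackable object, so its variables round-trip through `trackable_utils.Checkpoint` like any other checkpointed state. A standalone eager sketch of the same round-trip; the directory and the numeric values are illustrative, not from this commit:

# Standalone eager sketch of the checkpoint round-trip exercised above.
import os
import tempfile

from tensorflow.python.eager import context
from tensorflow.python.keras import metrics
from tensorflow.python.training.tracking import util as trackable_utils

with context.eager_mode():
  m = metrics.Sum()
  m(100.)
  m(200.)
  prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')  # illustrative path
  save_path = trackable_utils.Checkpoint(sum=m).save(prefix)

  # In eager mode the restore happens as soon as the matching variables exist.
  restored = metrics.Sum()
  trackable_utils.Checkpoint(sum=restored).restore(save_path).assert_consumed()
  restored(300.)
  assert float(restored.result()) == 600.0  # 300 restored + 300 new
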
@@ -354,7 +358,7 @@ class MeanTest(keras_parameterized.TestCase):
     self.assertEqual(self.evaluate(m.count), 1)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class KerasAccuracyTest(test.TestCase):
 
   def test_accuracy(self):
@@ -598,7 +602,7 @@ class KerasAccuracyTest(test.TestCase):
     self.assertEqual(acc_fn, metrics.accuracy)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CosineSimilarityTest(test.TestCase):
 
   def l2_norm(self, x, axis):
@@ -659,7 +663,7 @@ class CosineSimilarityTest(test.TestCase):
     self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanAbsoluteErrorTest(test.TestCase):
 
   def test_config(self):
@@ -697,7 +701,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
     self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanAbsolutePercentageErrorTest(test.TestCase):
 
   def test_config(self):
@@ -737,7 +741,7 @@ class MeanAbsolutePercentageErrorTest(test.TestCase):
     self.assertAllClose(40e7, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanSquaredErrorTest(test.TestCase):
 
   def test_config(self):
@@ -775,7 +779,7 @@ class MeanSquaredErrorTest(test.TestCase):
     self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanSquaredLogarithmicErrorTest(test.TestCase):
 
   def test_config(self):
@@ -815,7 +819,7 @@ class MeanSquaredLogarithmicErrorTest(test.TestCase):
     self.assertAllClose(0.26082, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class HingeTest(test.TestCase):
 
   def test_config(self):
@@ -870,7 +874,7 @@ class HingeTest(test.TestCase):
     self.assertAllClose(0.493, self.evaluate(result), atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class SquaredHingeTest(test.TestCase):
 
   def test_config(self):
@@ -931,7 +935,7 @@ class SquaredHingeTest(test.TestCase):
     self.assertAllClose(0.347, self.evaluate(result), atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CategoricalHingeTest(test.TestCase):
 
   def test_config(self):
@@ -971,7 +975,7 @@ class CategoricalHingeTest(test.TestCase):
     self.assertAllClose(0.5, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class RootMeanSquaredErrorTest(test.TestCase):
 
   def test_config(self):
@@ -1005,7 +1009,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
     self.assertAllClose(math.sqrt(13), self.evaluate(result), atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class TopKCategoricalAccuracyTest(test.TestCase):
 
   def test_config(self):
@@ -1052,7 +1056,7 @@ class TopKCategoricalAccuracyTest(test.TestCase):
     self.assertAllClose(1.0, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class SparseTopKCategoricalAccuracyTest(test.TestCase):
 
   def test_config(self):
@@ -1099,7 +1103,7 @@ class SparseTopKCategoricalAccuracyTest(test.TestCase):
     self.assertAllClose(1.0, self.evaluate(result), atol=1e-5)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class LogCoshErrorTest(test.TestCase):
 
   def setup(self):
@@ -1142,7 +1146,7 @@ class LogCoshErrorTest(test.TestCase):
     self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class PoissonTest(test.TestCase):
 
   def setup(self):
@@ -1188,7 +1192,7 @@ class PoissonTest(test.TestCase):
     self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class KLDivergenceTest(test.TestCase):
 
   def setup(self):
@@ -1235,7 +1239,7 @@ class KLDivergenceTest(test.TestCase):
     self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanRelativeErrorTest(test.TestCase):
 
   def test_config(self):
@@ -1291,7 +1295,7 @@ class MeanRelativeErrorTest(test.TestCase):
     self.assertEqual(self.evaluate(result), 0)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanIoUTest(test.TestCase):
 
   def test_config(self):
@@ -1374,90 +1378,93 @@ class MeanIoUTest(test.TestCase):
     self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
 
 
-class MeanTensorTest(test.TestCase):
+class MeanTensorTest(test.TestCase, parameterized.TestCase):
 
-  @test_util.run_in_graph_and_eager_modes
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_config(self):
-    m = metrics.MeanTensor(name='mean_by_element')
+    with self.test_session():
+      m = metrics.MeanTensor(name='mean_by_element')
 
-    # check config
-    self.assertEqual(m.name, 'mean_by_element')
-    self.assertTrue(m.stateful)
-    self.assertEqual(m.dtype, dtypes.float32)
-    self.assertEqual(len(m.variables), 0)
+      # check config
+      self.assertEqual(m.name, 'mean_by_element')
+      self.assertTrue(m.stateful)
+      self.assertEqual(m.dtype, dtypes.float32)
+      self.assertEmpty(m.variables)
 
-    with self.assertRaisesRegexp(ValueError, 'does not have any result yet'):
-      m.result()
+      with self.assertRaisesRegexp(ValueError, 'does not have any result yet'):
+        m.result()
 
-    self.evaluate(m([[3], [5], [3]]))
-    self.assertAllEqual(m._shape, [3, 1])
+      self.evaluate(m([[3], [5], [3]]))
+      self.assertAllEqual(m._shape, [3, 1])
 
-    m2 = metrics.MeanTensor.from_config(m.get_config())
-    self.assertEqual(m2.name, 'mean_by_element')
-    self.assertTrue(m2.stateful)
-    self.assertEqual(m2.dtype, dtypes.float32)
-    self.assertEqual(len(m2.variables), 0)
+      m2 = metrics.MeanTensor.from_config(m.get_config())
+      self.assertEqual(m2.name, 'mean_by_element')
+      self.assertTrue(m2.stateful)
+      self.assertEqual(m2.dtype, dtypes.float32)
+      self.assertEmpty(m2.variables)
 
-  @test_util.run_in_graph_and_eager_modes
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_unweighted(self):
-    m = metrics.MeanTensor(dtype=dtypes.float64)
+    with self.test_session():
+      m = metrics.MeanTensor(dtype=dtypes.float64)
 
-    # check __call__()
-    self.assertAllClose(self.evaluate(m([100, 40])), [100, 40])
-    self.assertAllClose(self.evaluate(m.total), [100, 40])
-    self.assertAllClose(self.evaluate(m.count), [1, 1])
+      # check __call__()
+      self.assertAllClose(self.evaluate(m([100, 40])), [100, 40])
+      self.assertAllClose(self.evaluate(m.total), [100, 40])
+      self.assertAllClose(self.evaluate(m.count), [1, 1])
 
-    # check update_state() and result() + state accumulation + tensor input
-    update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
-    self.evaluate(update_op)
-    self.assertAllClose(self.evaluate(m.result()), [50.5, 22.5])
-    self.assertAllClose(self.evaluate(m.total), [101, 45])
-    self.assertAllClose(self.evaluate(m.count), [2, 2])
+      # check update_state() and result() + state accumulation + tensor input
+      update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
+      self.evaluate(update_op)
+      self.assertAllClose(self.evaluate(m.result()), [50.5, 22.5])
+      self.assertAllClose(self.evaluate(m.total), [101, 45])
+      self.assertAllClose(self.evaluate(m.count), [2, 2])
 
-    # check reset_states()
-    m.reset_states()
-    self.assertAllClose(self.evaluate(m.total), [0, 0])
-    self.assertAllClose(self.evaluate(m.count), [0, 0])
+      # check reset_states()
+      m.reset_states()
+      self.assertAllClose(self.evaluate(m.total), [0, 0])
+      self.assertAllClose(self.evaluate(m.count), [0, 0])
 
-  @test_util.run_in_graph_and_eager_modes
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_weighted(self):
-    m = metrics.MeanTensor(dtype=dtypes.float64)
-    self.assertEqual(m.dtype, dtypes.float64)
+    with self.test_session():
+      m = metrics.MeanTensor(dtype=dtypes.float64)
+      self.assertEqual(m.dtype, dtypes.float64)
 
-    # check scalar weight
-    result_t = m([100, 30], sample_weight=0.5)
-    self.assertAllClose(self.evaluate(result_t), [100, 30])
-    self.assertAllClose(self.evaluate(m.total), [50, 15])
-    self.assertAllClose(self.evaluate(m.count), [0.5, 0.5])
+      # check scalar weight
+      result_t = m([100, 30], sample_weight=0.5)
+      self.assertAllClose(self.evaluate(result_t), [100, 30])
+      self.assertAllClose(self.evaluate(m.total), [50, 15])
+      self.assertAllClose(self.evaluate(m.count), [0.5, 0.5])
 
-    # check weights not scalar and weights rank matches values rank
-    result_t = m([1, 5], sample_weight=[1, 0.2])
-    result = self.evaluate(result_t)
-    self.assertAllClose(result, [51 / 1.5, 16 / 0.7], 2)
-    self.assertAllClose(self.evaluate(m.total), [51, 16])
-    self.assertAllClose(self.evaluate(m.count), [1.5, 0.7])
+      # check weights not scalar and weights rank matches values rank
+      result_t = m([1, 5], sample_weight=[1, 0.2])
+      result = self.evaluate(result_t)
+      self.assertAllClose(result, [51 / 1.5, 16 / 0.7], 2)
+      self.assertAllClose(self.evaluate(m.total), [51, 16])
+      self.assertAllClose(self.evaluate(m.count), [1.5, 0.7])
 
-    # check weights broadcast
-    result_t = m([1, 2], sample_weight=0.5)
-    self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2])
-    self.assertAllClose(self.evaluate(m.total), [51.5, 17])
-    self.assertAllClose(self.evaluate(m.count), [2, 1.2])
+      # check weights broadcast
+      result_t = m([1, 2], sample_weight=0.5)
+      self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2])
+      self.assertAllClose(self.evaluate(m.total), [51.5, 17])
+      self.assertAllClose(self.evaluate(m.count), [2, 1.2])
 
-    # check weights squeeze
-    result_t = m([1, 5], sample_weight=[[1], [0.2]])
-    self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4])
-    self.assertAllClose(self.evaluate(m.total), [52.5, 18])
-    self.assertAllClose(self.evaluate(m.count), [3, 1.4])
+      # check weights squeeze
+      result_t = m([1, 5], sample_weight=[[1], [0.2]])
+      self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4])
+      self.assertAllClose(self.evaluate(m.total), [52.5, 18])
+      self.assertAllClose(self.evaluate(m.count), [3, 1.4])
 
-    # check weights expand
-    m = metrics.MeanTensor(dtype=dtypes.float64)
-    self.evaluate(variables.variables_initializer(m.variables))
-    result_t = m([[1], [5]], sample_weight=[1, 0.2])
-    self.assertAllClose(self.evaluate(result_t), [[1], [5]])
-    self.assertAllClose(self.evaluate(m.total), [[1], [1]])
-    self.assertAllClose(self.evaluate(m.count), [[1], [0.2]])
+      # check weights expand
+      m = metrics.MeanTensor(dtype=dtypes.float64)
+      self.evaluate(variables.variables_initializer(m.variables))
+      result_t = m([[1], [5]], sample_weight=[1, 0.2])
+      self.assertAllClose(self.evaluate(result_t), [[1], [5]])
+      self.assertAllClose(self.evaluate(m.total), [[1], [1]])
+      self.assertAllClose(self.evaluate(m.count), [[1], [0.2]])
 
-  @test_util.run_in_graph_and_eager_modes
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_invalid_value_shape(self):
     m = metrics.MeanTensor(dtype=dtypes.float64)
     m([1])
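
The expected values in `test_weighted` follow from `MeanTensor` keeping a per-element running `total` and a per-element `count` of weights, with `result = total / count`. A plain-Python replay of the first two updates makes the arithmetic explicit; this is not TensorFlow code, just the bookkeeping:

# Plain-Python replay of the elementwise bookkeeping behind the expected
# values in test_weighted above.
import math

total = [0.0, 0.0]
count = [0.0, 0.0]

def update(values, weights):
  # MeanTensor-style update: total += value * weight, count += weight.
  for i, (v, w) in enumerate(zip(values, weights)):
    total[i] += v * w
    count[i] += w
  return [t / c for t, c in zip(total, count)]

assert update([100, 30], [0.5, 0.5]) == [100.0, 30.0]  # total=[50,15], count=[.5,.5]
r = update([1, 5], [1.0, 0.2])                         # total=[51,16], count=[1.5,.7]
assert all(math.isclose(a, b) for a, b in zip(r, [51 / 1.5, 16 / 0.7]))
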
@@ -1465,7 +1472,7 @@ class MeanTensorTest(test.TestCase):
         ValueError, 'MeanTensor input values must always have the same shape'):
       m([1, 5])
 
-  @test_util.run_in_graph_and_eager_modes
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_build_in_tf_function(self):
     """Ensure that variables are created correctly in a tf function."""
     m = metrics.MeanTensor(dtype=dtypes.float64)
@@ -1474,10 +1481,11 @@ class MeanTensorTest(test.TestCase):
     def call_metric(x):
       return m(x)
 
-    self.assertAllClose(self.evaluate(call_metric([100, 40])), [100, 40])
-    self.assertAllClose(self.evaluate(m.total), [100, 40])
-    self.assertAllClose(self.evaluate(m.count), [1, 1])
-    self.assertAllClose(self.evaluate(call_metric([20, 2])), [60, 21])
+    with self.test_session():
+      self.assertAllClose(self.evaluate(call_metric([100, 40])), [100, 40])
+      self.assertAllClose(self.evaluate(m.total), [100, 40])
+      self.assertAllClose(self.evaluate(m.count), [1, 1])
+      self.assertAllClose(self.evaluate(call_metric([20, 2])), [60, 21])
 
   def test_in_keras_model(self):
     with context.eager_mode():
@@ -1522,7 +1530,7 @@ class MeanTensorTest(test.TestCase):
                         np.full((4, 3), 4))
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class BinaryCrossentropyTest(test.TestCase):
 
   def test_config(self):
@@ -1642,7 +1650,7 @@ class BinaryCrossentropyTest(test.TestCase):
     self.assertAllClose(expected_value, self.evaluate(result), atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CategoricalCrossentropyTest(test.TestCase):
 
   def test_config(self):
@@ -1768,7 +1776,7 @@ class CategoricalCrossentropyTest(test.TestCase):
     self.assertAllClose(self.evaluate(loss), 3.667, atol=1e-3)
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class SparseCategoricalCrossentropyTest(test.TestCase):
 
   def test_config(self):
@@ -1943,7 +1951,7 @@ class BinaryTruePositives(metrics.Metric):
     return self.true_positives
 
 
-@test_util.run_all_in_graph_and_eager_modes
+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CustomMetricsTest(test.TestCase):
 
   def test_config(self):
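
One placement detail worth noting: most classes above take the decorator at class level, while `MeanTensorTest` applies it per method, since `test_in_keras_model` pins its own execution mode with `context.eager_mode()` and must not be parameterized. A sketch contrasting the two placements; the class and its methods are illustrative, not part of this commit:

# Sketch contrasting class-level vs. method-level decoration.
from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import combinations
from tensorflow.python.platform import test


class PlacementDemoTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_both_modes(self):
    # Parameterized: runs twice, once building a graph and once eagerly.
    with self.test_session():
      self.assertEqual(self.evaluate(constant_op.constant(1)), 1)

  def test_eager_only(self):
    # Undecorated, like MeanTensorTest.test_in_keras_model: it manages its
    # own mode and runs exactly once.
    with context.eager_mode():
      self.assertEqual(constant_op.constant(2).numpy(), 2)
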