Update tests under Keras to use combination.

1. Change all test_util.run_all_in_graph_and_eager_modes to combination.
2. Replace import tensorflow.python.keras with explicit module import.
3. Update BUILD file to not rely on the overall Keras target.

PiperOrigin-RevId: 299403937
Change-Id: Ic798fd0cc2602b6447aaf28922a4a418514d7131
This commit is contained in:
Scott Zhu 2020-03-06 11:52:17 -08:00 committed by TensorFlower Gardener
parent c4d9a3a647
commit c8823160e0
9 changed files with 1355 additions and 1297 deletions

View File

@ -94,6 +94,7 @@ py_library(
name = "backend_config", name = "backend_config",
srcs = ["backend_config.py"], srcs = ["backend_config.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = ["//tensorflow/python:util"],
) )
# TODO(scottzhu): Cleanup this target and point all the user to keras/engine. # TODO(scottzhu): Cleanup this target and point all the user to keras/engine.
@ -322,9 +323,14 @@ tf_py_test(
srcs = ["activations_test.py"], srcs = ["activations_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":keras", ":activations",
":backend",
":combinations",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python/keras/layers",
"//tensorflow/python/keras/layers:advanced_activations",
"//tensorflow/python/keras/layers:core",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -352,10 +358,11 @@ tf_py_test(
srcs = ["constraints_test.py"], srcs = ["constraints_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":keras", ":backend",
":combinations",
":constraints",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
], ],
) )
@ -365,11 +372,17 @@ tf_py_test(
srcs = ["initializers_test.py"], srcs = ["initializers_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":keras", ":backend",
":combinations",
":initializers",
":models",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops", "//tensorflow/python:init_ops",
"//tensorflow/python:tf2",
"//tensorflow/python/keras/engine",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
], ],
) )
@ -407,10 +420,17 @@ tf_py_test(
srcs = ["losses_test.py"], srcs = ["losses_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":keras", ":backend",
":combinations",
":losses",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/keras/utils:engine_utils",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
], ],
) )
@ -433,10 +453,27 @@ tf_py_test(
python_version = "PY3", python_version = "PY3",
shard_count = 4, shard_count = 4,
deps = [ deps = [
":combinations",
":keras", ":keras",
":metrics",
":testing_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
"//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/keras/layers",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/training/tracking:util",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
], ],
) )
@ -447,8 +484,17 @@ tf_py_test(
python_version = "PY3", python_version = "PY3",
shard_count = 4, shard_count = 4,
deps = [ deps = [
":keras", ":combinations",
":metrics",
":models",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
"//tensorflow/python/keras/layers",
"//tensorflow/python/keras/utils:metrics_utils",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -527,9 +573,22 @@ tf_py_test(
python_version = "PY3", python_version = "PY3",
shard_count = 4, shard_count = 4,
deps = [ deps = [
":keras", ":backend",
":combinations",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:config",
"//tensorflow/python:errors",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -541,10 +600,10 @@ tf_py_test(
srcs = ["backend_config_test.py"], srcs = ["backend_config_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":keras", ":backend",
":backend_config",
":combinations",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:util",
"//third_party/py/numpy",
], ],
) )

View File

@ -18,10 +18,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python.keras import activations
from tensorflow.python.framework import test_util from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import serialization
from tensorflow.python.ops import nn_ops as nn from tensorflow.python.ops import nn_ops as nn
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -32,34 +37,34 @@ def _ref_softmax(values):
return e / np.sum(e) return e / np.sum(e)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasActivationsTest(test.TestCase): class KerasActivationsTest(test.TestCase, parameterized.TestCase):
def test_serialization(self): def test_serialization(self):
all_activations = ['softmax', 'relu', 'elu', 'tanh', all_activations = ['softmax', 'relu', 'elu', 'tanh',
'sigmoid', 'hard_sigmoid', 'linear', 'sigmoid', 'hard_sigmoid', 'linear',
'softplus', 'softsign', 'selu'] 'softplus', 'softsign', 'selu']
for name in all_activations: for name in all_activations:
fn = keras.activations.get(name) fn = activations.get(name)
ref_fn = getattr(keras.activations, name) ref_fn = getattr(activations, name)
assert fn == ref_fn assert fn == ref_fn
config = keras.activations.serialize(fn) config = activations.serialize(fn)
fn = keras.activations.deserialize(config) fn = activations.deserialize(config)
assert fn == ref_fn assert fn == ref_fn
def test_serialization_v2(self): def test_serialization_v2(self):
activation_map = {nn.softmax_v2: 'softmax'} activation_map = {nn.softmax_v2: 'softmax'}
for fn_v2_key in activation_map: for fn_v2_key in activation_map:
fn_v2 = keras.activations.get(fn_v2_key) fn_v2 = activations.get(fn_v2_key)
config = keras.activations.serialize(fn_v2) config = activations.serialize(fn_v2)
fn = keras.activations.deserialize(config) fn = activations.deserialize(config)
assert fn.__name__ == activation_map[fn_v2_key] assert fn.__name__ == activation_map[fn_v2_key]
def test_serialization_with_layers(self): def test_serialization_with_layers(self):
activation = keras.layers.LeakyReLU(alpha=0.1) activation = advanced_activations.LeakyReLU(alpha=0.1)
layer = keras.layers.Dense(3, activation=activation) layer = core.Dense(3, activation=activation)
config = keras.layers.serialize(layer) config = serialization.serialize(layer)
deserialized_layer = keras.layers.deserialize( deserialized_layer = serialization.deserialize(
config, custom_objects={'LeakyReLU': activation}) config, custom_objects={'LeakyReLU': activation})
self.assertEqual(deserialized_layer.__class__.__name__, self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__) layer.__class__.__name__)
@ -67,8 +72,8 @@ class KerasActivationsTest(test.TestCase):
activation.__class__.__name__) activation.__class__.__name__)
def test_softmax(self): def test_softmax(self):
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softmax(x)]) f = backend.function([x], [activations.softmax(x)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
@ -76,28 +81,28 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result[0], expected, rtol=1e-05) self.assertAllClose(result[0], expected, rtol=1e-05)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
x = keras.backend.placeholder(ndim=1) x = backend.placeholder(ndim=1)
keras.activations.softmax(x) activations.softmax(x)
def test_temporal_softmax(self): def test_temporal_softmax(self):
x = keras.backend.placeholder(shape=(2, 2, 3)) x = backend.placeholder(shape=(2, 2, 3))
f = keras.backend.function([x], [keras.activations.softmax(x)]) f = backend.function([x], [activations.softmax(x)])
test_values = np.random.random((2, 2, 3)) * 10 test_values = np.random.random((2, 2, 3)) * 10
result = f([test_values])[0] result = f([test_values])[0]
expected = _ref_softmax(test_values[0, 0]) expected = _ref_softmax(test_values[0, 0])
self.assertAllClose(result[0, 0], expected, rtol=1e-05) self.assertAllClose(result[0, 0], expected, rtol=1e-05)
def test_selu(self): def test_selu(self):
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.selu(x)]) f = backend.function([x], [activations.selu(x)])
alpha = 1.6732632423543772848170429916717 alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946 scale = 1.0507009873554804934193349852946
positive_values = np.array([[1, 2]], dtype=keras.backend.floatx()) positive_values = np.array([[1, 2]], dtype=backend.floatx())
result = f([positive_values])[0] result = f([positive_values])[0]
self.assertAllClose(result, positive_values * scale, rtol=1e-05) self.assertAllClose(result, positive_values * scale, rtol=1e-05)
negative_values = np.array([[-1, -2]], dtype=keras.backend.floatx()) negative_values = np.array([[-1, -2]], dtype=backend.floatx())
result = f([negative_values])[0] result = f([negative_values])[0]
true_result = (np.exp(negative_values) - 1) * scale * alpha true_result = (np.exp(negative_values) - 1) * scale * alpha
self.assertAllClose(result, true_result) self.assertAllClose(result, true_result)
@ -106,8 +111,8 @@ class KerasActivationsTest(test.TestCase):
def softplus(x): def softplus(x):
return np.log(np.ones_like(x) + np.exp(x)) return np.log(np.ones_like(x) + np.exp(x))
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softplus(x)]) f = backend.function([x], [activations.softplus(x)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
expected = softplus(test_values) expected = softplus(test_values)
@ -117,8 +122,8 @@ class KerasActivationsTest(test.TestCase):
def softsign(x): def softsign(x):
return np.divide(x, np.ones_like(x) + np.absolute(x)) return np.divide(x, np.ones_like(x) + np.absolute(x))
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softsign(x)]) f = backend.function([x], [activations.softsign(x)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
expected = softsign(test_values) expected = softsign(test_values)
@ -133,8 +138,8 @@ class KerasActivationsTest(test.TestCase):
return z / (1 + z) return z / (1 + z)
sigmoid = np.vectorize(ref_sigmoid) sigmoid = np.vectorize(ref_sigmoid)
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.sigmoid(x)]) f = backend.function([x], [activations.sigmoid(x)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
expected = sigmoid(test_values) expected = sigmoid(test_values)
@ -146,16 +151,16 @@ class KerasActivationsTest(test.TestCase):
z = 0.0 if x <= 0 else (1.0 if x >= 1 else x) z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
return z return z
hard_sigmoid = np.vectorize(ref_hard_sigmoid) hard_sigmoid = np.vectorize(ref_hard_sigmoid)
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.hard_sigmoid(x)]) f = backend.function([x], [activations.hard_sigmoid(x)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
expected = hard_sigmoid(test_values) expected = hard_sigmoid(test_values)
self.assertAllClose(result, expected, rtol=1e-05) self.assertAllClose(result, expected, rtol=1e-05)
def test_relu(self): def test_relu(self):
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.relu(x)]) f = backend.function([x], [activations.relu(x)])
positive_values = np.random.random((2, 5)) positive_values = np.random.random((2, 5))
result = f([positive_values])[0] result = f([positive_values])[0]
self.assertAllClose(result, positive_values, rtol=1e-05) self.assertAllClose(result, positive_values, rtol=1e-05)
@ -166,44 +171,45 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-05) self.assertAllClose(result, expected, rtol=1e-05)
def test_elu(self): def test_elu(self):
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.elu(x, 0.5)]) f = backend.function([x], [activations.elu(x, 0.5)])
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
result = f([test_values])[0] result = f([test_values])[0]
self.assertAllClose(result, test_values, rtol=1e-05) self.assertAllClose(result, test_values, rtol=1e-05)
negative_values = np.array([[-1, -2]], dtype=keras.backend.floatx()) negative_values = np.array([[-1, -2]], dtype=backend.floatx())
result = f([negative_values])[0] result = f([negative_values])[0]
true_result = (np.exp(negative_values) - 1) / 2 true_result = (np.exp(negative_values) - 1) / 2
self.assertAllClose(result, true_result) self.assertAllClose(result, true_result)
def test_tanh(self): def test_tanh(self):
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
exp = keras.activations.tanh(x) exp = activations.tanh(x)
f = keras.backend.function([x], [exp]) f = backend.function([x], [exp])
result = f([test_values])[0] result = f([test_values])[0]
expected = np.tanh(test_values) expected = np.tanh(test_values)
self.assertAllClose(result, expected, rtol=1e-05) self.assertAllClose(result, expected, rtol=1e-05)
def test_exponential(self): def test_exponential(self):
test_values = np.random.random((2, 5)) test_values = np.random.random((2, 5))
x = keras.backend.placeholder(ndim=2) x = backend.placeholder(ndim=2)
exp = keras.activations.exponential(x) exp = activations.exponential(x)
f = keras.backend.function([x], [exp]) f = backend.function([x], [exp])
result = f([test_values])[0] result = f([test_values])[0]
expected = np.exp(test_values) expected = np.exp(test_values)
self.assertAllClose(result, expected, rtol=1e-05) self.assertAllClose(result, expected, rtol=1e-05)
def test_linear(self): def test_linear(self):
x = np.random.random((10, 5)) x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x)) self.assertAllClose(x, activations.linear(x))
def test_invalid_usage(self): def test_invalid_usage(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
keras.activations.get('unknown') activations.get('unknown')
# The following should be possible but should raise a warning: # The following should be possible but should raise a warning:
keras.activations.get(keras.layers.LeakyReLU()) activations.get(advanced_activations.LeakyReLU())
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -17,38 +17,38 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import keras from tensorflow.python.keras import backend
from tensorflow.python.framework import test_util from tensorflow.python.keras import backend_config
from tensorflow.python.keras import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendConfigTest(test.TestCase): class BackendConfigTest(test.TestCase):
def test_backend(self): def test_backend(self):
self.assertEqual(keras.backend.backend(), 'tensorflow') self.assertEqual(backend.backend(), 'tensorflow')
def test_epsilon(self): def test_epsilon(self):
epsilon = 1e-2 epsilon = 1e-2
keras.backend_config.set_epsilon(epsilon) backend_config.set_epsilon(epsilon)
self.assertEqual(keras.backend_config.epsilon(), epsilon) self.assertEqual(backend_config.epsilon(), epsilon)
keras.backend_config.set_epsilon(1e-7) backend_config.set_epsilon(1e-7)
self.assertEqual(keras.backend_config.epsilon(), 1e-7) self.assertEqual(backend_config.epsilon(), 1e-7)
def test_floatx(self): def test_floatx(self):
floatx = 'float64' floatx = 'float64'
keras.backend_config.set_floatx(floatx) backend_config.set_floatx(floatx)
self.assertEqual(keras.backend_config.floatx(), floatx) self.assertEqual(backend_config.floatx(), floatx)
keras.backend_config.set_floatx('float32') backend_config.set_floatx('float32')
self.assertEqual(keras.backend_config.floatx(), 'float32') self.assertEqual(backend_config.floatx(), 'float32')
def test_image_data_format(self): def test_image_data_format(self):
image_data_format = 'channels_first' image_data_format = 'channels_first'
keras.backend_config.set_image_data_format(image_data_format) backend_config.set_image_data_format(image_data_format)
self.assertEqual(keras.backend_config.image_data_format(), self.assertEqual(backend_config.image_data_format(), image_data_format)
image_data_format) backend_config.set_image_data_format('channels_last')
keras.backend_config.set_image_data_format('channels_last') self.assertEqual(backend_config.image_data_format(), 'channels_last')
self.assertEqual(keras.backend_config.image_data_format(), 'channels_last')
if __name__ == '__main__': if __name__ == '__main__':

File diff suppressed because it is too large Load Diff

View File

@ -22,8 +22,9 @@ import math
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python.keras import backend
from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations
from tensorflow.python.keras import constraints
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -44,47 +45,45 @@ def get_example_kernel(width):
return example_array return example_array
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasConstraintsTest(test.TestCase): class KerasConstraintsTest(test.TestCase):
def test_serialization(self): def test_serialization(self):
all_activations = ['max_norm', 'non_neg', all_activations = ['max_norm', 'non_neg',
'unit_norm', 'min_max_norm'] 'unit_norm', 'min_max_norm']
for name in all_activations: for name in all_activations:
fn = keras.constraints.get(name) fn = constraints.get(name)
ref_fn = getattr(keras.constraints, name)() ref_fn = getattr(constraints, name)()
assert fn.__class__ == ref_fn.__class__ assert fn.__class__ == ref_fn.__class__
config = keras.constraints.serialize(fn) config = constraints.serialize(fn)
fn = keras.constraints.deserialize(config) fn = constraints.deserialize(config)
assert fn.__class__ == ref_fn.__class__ assert fn.__class__ == ref_fn.__class__
def test_max_norm(self): def test_max_norm(self):
array = get_example_array() array = get_example_array()
for m in get_test_values(): for m in get_test_values():
norm_instance = keras.constraints.max_norm(m) norm_instance = constraints.max_norm(m)
normed = norm_instance(keras.backend.variable(array)) normed = norm_instance(backend.variable(array))
assert np.all(keras.backend.eval(normed) < m) assert np.all(backend.eval(normed) < m)
# a more explicit example # a more explicit example
norm_instance = keras.constraints.max_norm(2.0) norm_instance = constraints.max_norm(2.0)
x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T
x_normed_target = np.array( x_normed_target = np.array(
[[0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [[0, 0, 0], [1.0, 0, 0], [2.0, 0, 0],
[2. / np.sqrt(3), 2. / np.sqrt(3), 2. / np.sqrt(3)]]).T [2. / np.sqrt(3), 2. / np.sqrt(3), 2. / np.sqrt(3)]]).T
x_normed_actual = keras.backend.eval( x_normed_actual = backend.eval(norm_instance(backend.variable(x)))
norm_instance(keras.backend.variable(x)))
self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05) self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05)
def test_non_neg(self): def test_non_neg(self):
non_neg_instance = keras.constraints.non_neg() non_neg_instance = constraints.non_neg()
normed = non_neg_instance(keras.backend.variable(get_example_array())) normed = non_neg_instance(backend.variable(get_example_array()))
assert np.all(np.min(keras.backend.eval(normed), axis=1) == 0.) assert np.all(np.min(backend.eval(normed), axis=1) == 0.)
def test_unit_norm(self): def test_unit_norm(self):
unit_norm_instance = keras.constraints.unit_norm() unit_norm_instance = constraints.unit_norm()
normalized = unit_norm_instance(keras.backend.variable(get_example_array())) normalized = unit_norm_instance(backend.variable(get_example_array()))
norm_of_normalized = np.sqrt( norm_of_normalized = np.sqrt(np.sum(backend.eval(normalized)**2, axis=0))
np.sum(keras.backend.eval(normalized)**2, axis=0))
# In the unit norm constraint, it should be equal to 1. # In the unit norm constraint, it should be equal to 1.
difference = norm_of_normalized - 1. difference = norm_of_normalized - 1.
largest_difference = np.max(np.abs(difference)) largest_difference = np.max(np.abs(difference))
@ -93,10 +92,9 @@ class KerasConstraintsTest(test.TestCase):
def test_min_max_norm(self): def test_min_max_norm(self):
array = get_example_array() array = get_example_array()
for m in get_test_values(): for m in get_test_values():
norm_instance = keras.constraints.min_max_norm( norm_instance = constraints.min_max_norm(min_value=m, max_value=m * 2)
min_value=m, max_value=m * 2) normed = norm_instance(backend.variable(array))
normed = norm_instance(keras.backend.variable(array)) value = backend.eval(normed)
value = keras.backend.eval(normed)
l2 = np.sqrt(np.sum(np.square(value), axis=0)) l2 = np.sqrt(np.sum(np.square(value), axis=0))
assert not l2[l2 < m] assert not l2[l2 < m]
assert not l2[l2 > m * 2 + 1e-5] assert not l2[l2 > m * 2 + 1e-5]
@ -104,9 +102,9 @@ class KerasConstraintsTest(test.TestCase):
def test_conv2d_radial_constraint(self): def test_conv2d_radial_constraint(self):
for width in (3, 4, 5, 6): for width in (3, 4, 5, 6):
array = get_example_kernel(width) array = get_example_kernel(width)
norm_instance = keras.constraints.radial_constraint() norm_instance = constraints.radial_constraint()
normed = norm_instance(keras.backend.variable(array)) normed = norm_instance(backend.variable(array))
value = keras.backend.eval(normed) value = backend.eval(normed)
assert np.all(value.shape == array.shape) assert np.all(value.shape == array.shape)
assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0]) assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0])
assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2) assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2)

View File

@ -20,33 +20,38 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasInitializersTest(test.TestCase): class KerasInitializersTest(test.TestCase):
def _runner(self, init, shape, target_mean=None, target_std=None, def _runner(self, init, shape, target_mean=None, target_std=None,
target_max=None, target_min=None): target_max=None, target_min=None):
variable = keras.backend.variable(init(shape)) variable = backend.variable(init(shape))
output = keras.backend.get_value(variable) output = backend.get_value(variable)
# Test serialization (assumes deterministic behavior). # Test serialization (assumes deterministic behavior).
config = init.get_config() config = init.get_config()
reconstructed_init = init.__class__.from_config(config) reconstructed_init = init.__class__.from_config(config)
variable = keras.backend.variable(reconstructed_init(shape)) variable = backend.variable(reconstructed_init(shape))
output_2 = keras.backend.get_value(variable) output_2 = backend.get_value(variable)
self.assertAllClose(output, output_2, atol=1e-4) self.assertAllClose(output, output_2, atol=1e-4)
def test_uniform(self): def test_uniform(self):
tensor_shape = (9, 6, 7) tensor_shape = (9, 6, 7)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.RandomUniformV2(minval=-1, maxval=1, seed=124), initializers.RandomUniformV2(minval=-1, maxval=1, seed=124),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_max=1, target_max=1,
@ -56,7 +61,7 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (8, 12, 99) tensor_shape = (8, 12, 99)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.RandomNormalV2(mean=0, stddev=1, seed=153), initializers.RandomNormalV2(mean=0, stddev=1, seed=153),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=1) target_std=1)
@ -65,7 +70,7 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (12, 99, 7) tensor_shape = (12, 99, 7)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126), initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_max=2, target_max=2,
@ -75,7 +80,7 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (5, 6, 4) tensor_shape = (5, 6, 4)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.ConstantV2(2.), initializers.ConstantV2(2.),
tensor_shape, tensor_shape,
target_mean=2, target_mean=2,
target_max=2, target_max=2,
@ -87,7 +92,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape) fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in) std = np.sqrt(1. / fan_in)
self._runner( self._runner(
keras.initializers.lecun_uniformV2(seed=123), initializers.lecun_uniformV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -98,7 +103,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, fan_out = init_ops._compute_fans(tensor_shape) fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out)) std = np.sqrt(2. / (fan_in + fan_out))
self._runner( self._runner(
keras.initializers.GlorotUniformV2(seed=123), initializers.GlorotUniformV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -109,7 +114,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape) fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in) std = np.sqrt(2. / fan_in)
self._runner( self._runner(
keras.initializers.he_uniformV2(seed=123), initializers.he_uniformV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -120,7 +125,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape) fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in) std = np.sqrt(1. / fan_in)
self._runner( self._runner(
keras.initializers.lecun_normalV2(seed=123), initializers.lecun_normalV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -131,7 +136,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, fan_out = init_ops._compute_fans(tensor_shape) fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out)) std = np.sqrt(2. / (fan_in + fan_out))
self._runner( self._runner(
keras.initializers.GlorotNormalV2(seed=123), initializers.GlorotNormalV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -142,7 +147,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape) fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in) std = np.sqrt(2. / fan_in)
self._runner( self._runner(
keras.initializers.he_normalV2(seed=123), initializers.he_normalV2(seed=123),
tensor_shape, tensor_shape,
target_mean=0., target_mean=0.,
target_std=std) target_std=std)
@ -151,23 +156,21 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (20, 20) tensor_shape = (20, 20)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.OrthogonalV2(seed=123), initializers.OrthogonalV2(seed=123), tensor_shape, target_mean=0.)
tensor_shape,
target_mean=0.)
def test_identity(self): def test_identity(self):
with self.cached_session(): with self.cached_session():
tensor_shape = (3, 4, 5) tensor_shape = (3, 4, 5)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self._runner( self._runner(
keras.initializers.IdentityV2(), initializers.IdentityV2(),
tensor_shape, tensor_shape,
target_mean=1. / tensor_shape[0], target_mean=1. / tensor_shape[0],
target_max=1.) target_max=1.)
tensor_shape = (3, 3) tensor_shape = (3, 3)
self._runner( self._runner(
keras.initializers.IdentityV2(), initializers.IdentityV2(),
tensor_shape, tensor_shape,
target_mean=1. / tensor_shape[0], target_mean=1. / tensor_shape[0],
target_max=1.) target_max=1.)
@ -176,32 +179,26 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (4, 5) tensor_shape = (4, 5)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.ZerosV2(), initializers.ZerosV2(), tensor_shape, target_mean=0., target_max=0.)
tensor_shape,
target_mean=0.,
target_max=0.)
def test_one(self): def test_one(self):
tensor_shape = (4, 5) tensor_shape = (4, 5)
with self.cached_session(): with self.cached_session():
self._runner( self._runner(
keras.initializers.OnesV2(), initializers.OnesV2(), tensor_shape, target_mean=1., target_max=1.)
tensor_shape,
target_mean=1.,
target_max=1.)
def test_default_random_uniform(self): def test_default_random_uniform(self):
ru = keras.initializers.get('uniform') ru = initializers.get('uniform')
self.assertEqual(ru.minval, -0.05) self.assertEqual(ru.minval, -0.05)
self.assertEqual(ru.maxval, 0.05) self.assertEqual(ru.maxval, 0.05)
def test_default_random_normal(self): def test_default_random_normal(self):
rn = keras.initializers.get('normal') rn = initializers.get('normal')
self.assertEqual(rn.mean, 0.0) self.assertEqual(rn.mean, 0.0)
self.assertEqual(rn.stddev, 0.05) self.assertEqual(rn.stddev, 0.05)
def test_default_truncated_normal(self): def test_default_truncated_normal(self):
tn = keras.initializers.get('truncated_normal') tn = initializers.get('truncated_normal')
self.assertEqual(tn.mean, 0.0) self.assertEqual(tn.mean, 0.0)
self.assertEqual(tn.stddev, 0.05) self.assertEqual(tn.stddev, 0.05)
@ -209,7 +206,7 @@ class KerasInitializersTest(test.TestCase):
tf2_force_enabled = tf2._force_enable # pylint: disable=protected-access tf2_force_enabled = tf2._force_enable # pylint: disable=protected-access
try: try:
tf2.enable() tf2.enable()
rn = keras.initializers.get('random_normal') rn = initializers.get('random_normal')
self.assertIn('init_ops_v2', rn.__class__.__module__) self.assertIn('init_ops_v2', rn.__class__.__module__)
finally: finally:
tf2._force_enable = tf2_force_enabled # pylint: disable=protected-access tf2._force_enable = tf2_force_enabled # pylint: disable=protected-access
@ -219,9 +216,9 @@ class KerasInitializersTest(test.TestCase):
def my_initializer(shape, dtype=None): def my_initializer(shape, dtype=None):
return array_ops.ones(shape, dtype=dtype) return array_ops.ones(shape, dtype=dtype)
inputs = keras.Input((10,)) inputs = input_layer.Input((10,))
outputs = keras.layers.Dense(1, kernel_initializer=my_initializer)(inputs) outputs = core.Dense(1, kernel_initializer=my_initializer)(inputs)
model = keras.Model(inputs, outputs) model = models.Model(inputs, outputs)
model2 = model.from_config( model2 = model.from_config(
model.get_config(), custom_objects={'my_initializer': my_initializer}) model.get_config(), custom_objects={'my_initializer': my_initializer})
self.assertEqual(model2.layers[1].kernel_initializer, my_initializer) self.assertEqual(model2.layers[1].kernel_initializer, my_initializer)
@ -237,7 +234,7 @@ class KerasInitializersTest(test.TestCase):
'seed': None 'seed': None
} }
} }
initializer = keras.initializers.deserialize(external_serialized_json) initializer = initializers.deserialize(external_serialized_json)
self.assertEqual(initializer.distribution, 'truncated_normal') self.assertEqual(initializer.distribution, 'truncated_normal')

File diff suppressed because it is too large Load Diff

View File

@ -26,7 +26,7 @@ from scipy.special import expit
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics from tensorflow.python.keras import metrics
from tensorflow.python.keras import models from tensorflow.python.keras import models
@ -37,19 +37,19 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FalsePositivesTest(test.TestCase): class FalsePositivesTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
fp_obj = metrics.FalsePositives(name='my_fp', thresholds=[0.4, 0.9]) fp_obj = metrics.FalsePositives(name='my_fp', thresholds=[0.4, 0.9])
self.assertEqual(fp_obj.name, 'my_fp') self.assertEqual(fp_obj.name, 'my_fp')
self.assertEqual(len(fp_obj.variables), 1) self.assertLen(fp_obj.variables, 1)
self.assertEqual(fp_obj.thresholds, [0.4, 0.9]) self.assertEqual(fp_obj.thresholds, [0.4, 0.9])
# Check save and restore config # Check save and restore config
fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config()) fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config())
self.assertEqual(fp_obj2.name, 'my_fp') self.assertEqual(fp_obj2.name, 'my_fp')
self.assertEqual(len(fp_obj2.variables), 1) self.assertLen(fp_obj2.variables, 1)
self.assertEqual(fp_obj2.thresholds, [0.4, 0.9]) self.assertEqual(fp_obj2.thresholds, [0.4, 0.9])
def test_unweighted(self): def test_unweighted(self):
@ -117,19 +117,19 @@ class FalsePositivesTest(test.TestCase):
metrics.FalsePositives(thresholds=[None]) metrics.FalsePositives(thresholds=[None])
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FalseNegativesTest(test.TestCase): class FalseNegativesTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
fn_obj = metrics.FalseNegatives(name='my_fn', thresholds=[0.4, 0.9]) fn_obj = metrics.FalseNegatives(name='my_fn', thresholds=[0.4, 0.9])
self.assertEqual(fn_obj.name, 'my_fn') self.assertEqual(fn_obj.name, 'my_fn')
self.assertEqual(len(fn_obj.variables), 1) self.assertLen(fn_obj.variables, 1)
self.assertEqual(fn_obj.thresholds, [0.4, 0.9]) self.assertEqual(fn_obj.thresholds, [0.4, 0.9])
# Check save and restore config # Check save and restore config
fn_obj2 = metrics.FalseNegatives.from_config(fn_obj.get_config()) fn_obj2 = metrics.FalseNegatives.from_config(fn_obj.get_config())
self.assertEqual(fn_obj2.name, 'my_fn') self.assertEqual(fn_obj2.name, 'my_fn')
self.assertEqual(len(fn_obj2.variables), 1) self.assertLen(fn_obj2.variables, 1)
self.assertEqual(fn_obj2.thresholds, [0.4, 0.9]) self.assertEqual(fn_obj2.thresholds, [0.4, 0.9])
def test_unweighted(self): def test_unweighted(self):
@ -185,19 +185,19 @@ class FalseNegativesTest(test.TestCase):
self.assertAllClose([4., 16., 23.], self.evaluate(result)) self.assertAllClose([4., 16., 23.], self.evaluate(result))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TrueNegativesTest(test.TestCase): class TrueNegativesTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
tn_obj = metrics.TrueNegatives(name='my_tn', thresholds=[0.4, 0.9]) tn_obj = metrics.TrueNegatives(name='my_tn', thresholds=[0.4, 0.9])
self.assertEqual(tn_obj.name, 'my_tn') self.assertEqual(tn_obj.name, 'my_tn')
self.assertEqual(len(tn_obj.variables), 1) self.assertLen(tn_obj.variables, 1)
self.assertEqual(tn_obj.thresholds, [0.4, 0.9]) self.assertEqual(tn_obj.thresholds, [0.4, 0.9])
# Check save and restore config # Check save and restore config
tn_obj2 = metrics.TrueNegatives.from_config(tn_obj.get_config()) tn_obj2 = metrics.TrueNegatives.from_config(tn_obj.get_config())
self.assertEqual(tn_obj2.name, 'my_tn') self.assertEqual(tn_obj2.name, 'my_tn')
self.assertEqual(len(tn_obj2.variables), 1) self.assertLen(tn_obj2.variables, 1)
self.assertEqual(tn_obj2.thresholds, [0.4, 0.9]) self.assertEqual(tn_obj2.thresholds, [0.4, 0.9])
def test_unweighted(self): def test_unweighted(self):
@ -253,19 +253,19 @@ class TrueNegativesTest(test.TestCase):
self.assertAllClose([5., 15., 23.], self.evaluate(result)) self.assertAllClose([5., 15., 23.], self.evaluate(result))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TruePositivesTest(test.TestCase): class TruePositivesTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
tp_obj = metrics.TruePositives(name='my_tp', thresholds=[0.4, 0.9]) tp_obj = metrics.TruePositives(name='my_tp', thresholds=[0.4, 0.9])
self.assertEqual(tp_obj.name, 'my_tp') self.assertEqual(tp_obj.name, 'my_tp')
self.assertEqual(len(tp_obj.variables), 1) self.assertLen(tp_obj.variables, 1)
self.assertEqual(tp_obj.thresholds, [0.4, 0.9]) self.assertEqual(tp_obj.thresholds, [0.4, 0.9])
# Check save and restore config # Check save and restore config
tp_obj2 = metrics.TruePositives.from_config(tp_obj.get_config()) tp_obj2 = metrics.TruePositives.from_config(tp_obj.get_config())
self.assertEqual(tp_obj2.name, 'my_tp') self.assertEqual(tp_obj2.name, 'my_tp')
self.assertEqual(len(tp_obj2.variables), 1) self.assertLen(tp_obj2.variables, 1)
self.assertEqual(tp_obj2.thresholds, [0.4, 0.9]) self.assertEqual(tp_obj2.thresholds, [0.4, 0.9])
def test_unweighted(self): def test_unweighted(self):
@ -320,14 +320,14 @@ class TruePositivesTest(test.TestCase):
self.assertAllClose([222., 111., 37.], self.evaluate(result)) self.assertAllClose([222., 111., 37.], self.evaluate(result))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class PrecisionTest(test.TestCase): class PrecisionTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
p_obj = metrics.Precision( p_obj = metrics.Precision(
name='my_precision', thresholds=[0.4, 0.9], top_k=15, class_id=12) name='my_precision', thresholds=[0.4, 0.9], top_k=15, class_id=12)
self.assertEqual(p_obj.name, 'my_precision') self.assertEqual(p_obj.name, 'my_precision')
self.assertEqual(len(p_obj.variables), 2) self.assertLen(p_obj.variables, 2)
self.assertEqual([v.name for v in p_obj.variables], self.assertEqual([v.name for v in p_obj.variables],
['true_positives:0', 'false_positives:0']) ['true_positives:0', 'false_positives:0'])
self.assertEqual(p_obj.thresholds, [0.4, 0.9]) self.assertEqual(p_obj.thresholds, [0.4, 0.9])
@ -337,7 +337,7 @@ class PrecisionTest(test.TestCase):
# Check save and restore config # Check save and restore config
p_obj2 = metrics.Precision.from_config(p_obj.get_config()) p_obj2 = metrics.Precision.from_config(p_obj.get_config())
self.assertEqual(p_obj2.name, 'my_precision') self.assertEqual(p_obj2.name, 'my_precision')
self.assertEqual(len(p_obj2.variables), 2) self.assertLen(p_obj2.variables, 2)
self.assertEqual(p_obj2.thresholds, [0.4, 0.9]) self.assertEqual(p_obj2.thresholds, [0.4, 0.9])
self.assertEqual(p_obj2.top_k, 15) self.assertEqual(p_obj2.top_k, 15)
self.assertEqual(p_obj2.class_id, 12) self.assertEqual(p_obj2.class_id, 12)
@ -525,14 +525,14 @@ class PrecisionTest(test.TestCase):
self.assertAlmostEqual(0, self.evaluate(p_obj.false_positives)) self.assertAlmostEqual(0, self.evaluate(p_obj.false_positives))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RecallTest(test.TestCase): class RecallTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
r_obj = metrics.Recall( r_obj = metrics.Recall(
name='my_recall', thresholds=[0.4, 0.9], top_k=15, class_id=12) name='my_recall', thresholds=[0.4, 0.9], top_k=15, class_id=12)
self.assertEqual(r_obj.name, 'my_recall') self.assertEqual(r_obj.name, 'my_recall')
self.assertEqual(len(r_obj.variables), 2) self.assertLen(r_obj.variables, 2)
self.assertEqual([v.name for v in r_obj.variables], self.assertEqual([v.name for v in r_obj.variables],
['true_positives:0', 'false_negatives:0']) ['true_positives:0', 'false_negatives:0'])
self.assertEqual(r_obj.thresholds, [0.4, 0.9]) self.assertEqual(r_obj.thresholds, [0.4, 0.9])
@ -542,7 +542,7 @@ class RecallTest(test.TestCase):
# Check save and restore config # Check save and restore config
r_obj2 = metrics.Recall.from_config(r_obj.get_config()) r_obj2 = metrics.Recall.from_config(r_obj.get_config())
self.assertEqual(r_obj2.name, 'my_recall') self.assertEqual(r_obj2.name, 'my_recall')
self.assertEqual(len(r_obj2.variables), 2) self.assertLen(r_obj2.variables, 2)
self.assertEqual(r_obj2.thresholds, [0.4, 0.9]) self.assertEqual(r_obj2.thresholds, [0.4, 0.9])
self.assertEqual(r_obj2.top_k, 15) self.assertEqual(r_obj2.top_k, 15)
self.assertEqual(r_obj2.class_id, 12) self.assertEqual(r_obj2.class_id, 12)
@ -729,7 +729,7 @@ class RecallTest(test.TestCase):
self.assertAlmostEqual(3, self.evaluate(r_obj.false_negatives)) self.assertAlmostEqual(3, self.evaluate(r_obj.false_negatives))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase): class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
@ -771,13 +771,14 @@ class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
1e-3) 1e-3)
def test_unweighted_all_correct(self): def test_unweighted_all_correct(self):
s_obj = metrics.SensitivityAtSpecificity(0.7) with self.test_session():
inputs = np.random.randint(0, 2, size=(100, 1)) s_obj = metrics.SensitivityAtSpecificity(0.7)
y_pred = constant_op.constant(inputs, dtype=dtypes.float32) inputs = np.random.randint(0, 2, size=(100, 1))
y_true = constant_op.constant(inputs) y_pred = constant_op.constant(inputs, dtype=dtypes.float32)
self.evaluate(variables.variables_initializer(s_obj.variables)) y_true = constant_op.constant(inputs)
result = s_obj(y_true, y_pred) self.evaluate(variables.variables_initializer(s_obj.variables))
self.assertAlmostEqual(1, self.evaluate(result)) result = s_obj(y_true, y_pred)
self.assertAlmostEqual(1, self.evaluate(result))
def test_unweighted_high_specificity(self): def test_unweighted_high_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.8) s_obj = metrics.SensitivityAtSpecificity(0.8)
@ -825,7 +826,7 @@ class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1) metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase): class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
@ -921,7 +922,7 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1) metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase): class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
@ -1018,7 +1019,7 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
metrics.PrecisionAtRecall(0.4, num_thresholds=-1) metrics.PrecisionAtRecall(0.4, num_thresholds=-1)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RecallAtPrecisionTest(test.TestCase, parameterized.TestCase): class RecallAtPrecisionTest(test.TestCase, parameterized.TestCase):
def test_config(self): def test_config(self):
@ -1133,8 +1134,8 @@ class RecallAtPrecisionTest(test.TestCase, parameterized.TestCase):
metrics.RecallAtPrecision(0.4, num_thresholds=-1) metrics.RecallAtPrecision(0.4, num_thresholds=-1)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AUCTest(test.TestCase): class AUCTest(test.TestCase, parameterized.TestCase):
def setup(self): def setup(self):
self.num_thresholds = 3 self.num_thresholds = 3
@ -1172,7 +1173,7 @@ class AUCTest(test.TestCase):
name='auc_1') name='auc_1')
auc_obj.update_state(self.y_true, self.y_pred) auc_obj.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj.name, 'auc_1') self.assertEqual(auc_obj.name, 'auc_1')
self.assertEqual(len(auc_obj.variables), 4) self.assertLen(auc_obj.variables, 4)
self.assertEqual(auc_obj.num_thresholds, 100) self.assertEqual(auc_obj.num_thresholds, 100)
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR) self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(auc_obj.summation_method, self.assertEqual(auc_obj.summation_method,
@ -1184,7 +1185,7 @@ class AUCTest(test.TestCase):
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config()) auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
auc_obj2.update_state(self.y_true, self.y_pred) auc_obj2.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj2.name, 'auc_1') self.assertEqual(auc_obj2.name, 'auc_1')
self.assertEqual(len(auc_obj2.variables), 4) self.assertLen(auc_obj2.variables, 4)
self.assertEqual(auc_obj2.num_thresholds, 100) self.assertEqual(auc_obj2.num_thresholds, 100)
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR) self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(auc_obj2.summation_method, self.assertEqual(auc_obj2.summation_method,
@ -1203,7 +1204,7 @@ class AUCTest(test.TestCase):
thresholds=[0.3, 0.5]) thresholds=[0.3, 0.5])
auc_obj.update_state(self.y_true, self.y_pred) auc_obj.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj.name, 'auc_1') self.assertEqual(auc_obj.name, 'auc_1')
self.assertEqual(len(auc_obj.variables), 4) self.assertLen(auc_obj.variables, 4)
self.assertEqual(auc_obj.num_thresholds, 4) self.assertEqual(auc_obj.num_thresholds, 4)
self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0]) self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR) self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
@ -1216,7 +1217,7 @@ class AUCTest(test.TestCase):
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config()) auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
auc_obj2.update_state(self.y_true, self.y_pred) auc_obj2.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj2.name, 'auc_1') self.assertEqual(auc_obj2.name, 'auc_1')
self.assertEqual(len(auc_obj2.variables), 4) self.assertLen(auc_obj2.variables, 4)
self.assertEqual(auc_obj2.num_thresholds, 4) self.assertEqual(auc_obj2.num_thresholds, 4)
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR) self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(auc_obj2.summation_method, self.assertEqual(auc_obj2.summation_method,
@ -1407,8 +1408,8 @@ class AUCTest(test.TestCase):
self.assertEqual(self.evaluate(result), 0.5) self.assertEqual(self.evaluate(result), 0.5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MultiAUCTest(test.TestCase): class MultiAUCTest(test.TestCase, parameterized.TestCase):
def setup(self): def setup(self):
self.num_thresholds = 5 self.num_thresholds = 5
@ -1457,26 +1458,28 @@ class MultiAUCTest(test.TestCase):
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]] # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
def test_value_is_idempotent(self): def test_value_is_idempotent(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=5, multi_label=True) self.setup()
self.evaluate(variables.variables_initializer(auc_obj.variables)) auc_obj = metrics.AUC(num_thresholds=5, multi_label=True)
self.evaluate(variables.variables_initializer(auc_obj.variables))
# Run several updates. # Run several updates.
update_op = auc_obj.update_state(self.y_true_good, self.y_pred) update_op = auc_obj.update_state(self.y_true_good, self.y_pred)
for _ in range(10): for _ in range(10):
self.evaluate(update_op) self.evaluate(update_op)
# Then verify idempotency. # Then verify idempotency.
initial_auc = self.evaluate(auc_obj.result()) initial_auc = self.evaluate(auc_obj.result())
for _ in range(10): for _ in range(10):
self.assertAllClose(initial_auc, self.evaluate(auc_obj.result()), 1e-3) self.assertAllClose(initial_auc, self.evaluate(auc_obj.result()), 1e-3)
def test_unweighted_all_correct(self): def test_unweighted_all_correct(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(multi_label=True) self.setup()
self.evaluate(variables.variables_initializer(auc_obj.variables)) auc_obj = metrics.AUC(multi_label=True)
result = auc_obj(self.y_true_good, self.y_true_good) self.evaluate(variables.variables_initializer(auc_obj.variables))
self.assertEqual(self.evaluate(result), 1) result = auc_obj(self.y_true_good, self.y_true_good)
self.assertEqual(self.evaluate(result), 1)
def test_unweighted_all_correct_flat(self): def test_unweighted_all_correct_flat(self):
self.setup() self.setup()
@ -1486,15 +1489,17 @@ class MultiAUCTest(test.TestCase):
self.assertEqual(self.evaluate(result), 1) self.assertEqual(self.evaluate(result), 1)
def test_unweighted(self): def test_unweighted(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True) self.setup()
self.evaluate(variables.variables_initializer(auc_obj.variables)) auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
result = auc_obj(self.y_true_good, self.y_pred) multi_label=True)
self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true_good, self.y_pred)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 + 1.0) / 2.0 expected_result = (0.875 + 1.0) / 2.0
self.assertAllClose(self.evaluate(result), expected_result, 1e-3) self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_sample_weight_flat(self): def test_sample_weight_flat(self):
self.setup() self.setup()
@ -1521,18 +1526,19 @@ class MultiAUCTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, 1e-3) self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_label_weights(self): def test_label_weights(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC( self.setup()
num_thresholds=self.num_thresholds, auc_obj = metrics.AUC(
multi_label=True, num_thresholds=self.num_thresholds,
label_weights=[0.75, 0.25]) multi_label=True,
self.evaluate(variables.variables_initializer(auc_obj.variables)) label_weights=[0.75, 0.25])
result = auc_obj(self.y_true_good, self.y_pred) self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true_good, self.y_pred)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25) expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)
self.assertAllClose(self.evaluate(result), expected_result, 1e-3) self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_label_weights_flat(self): def test_label_weights_flat(self):
self.setup() self.setup()
@ -1565,65 +1571,72 @@ class MultiAUCTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, 1e-3) self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_manual_thresholds(self): def test_manual_thresholds(self):
self.setup() with self.test_session():
# Verify that when specified, thresholds are used instead of num_thresholds. self.setup()
auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5], multi_label=True) # Verify that when specified, thresholds are used instead of
self.assertEqual(auc_obj.num_thresholds, 3) # num_thresholds.
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0]) auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5],
self.evaluate(variables.variables_initializer(auc_obj.variables)) multi_label=True)
result = auc_obj(self.y_true_good, self.y_pred) self.assertEqual(auc_obj.num_thresholds, 3)
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true_good, self.y_pred)
# tp = [[2, 1, 0], [2, 0, 0]] # tp = [[2, 1, 0], [2, 0, 0]]
# fp = [2, 0, 0], [2, 0, 0]] # fp = [2, 0, 0], [2, 0, 0]]
# fn = [[0, 1, 2], [0, 2, 2]] # fn = [[0, 1, 2], [0, 2, 2]]
# tn = [[0, 2, 2], [0, 2, 2]] # tn = [[0, 2, 2], [0, 2, 2]]
# tpr = [[1, 0.5, 0], [1, 0, 0]] # tpr = [[1, 0.5, 0], [1, 0, 0]]
# fpr = [[1, 0, 0], [1, 0, 0]] # fpr = [[1, 0, 0], [1, 0, 0]]
# auc by slice = [0.75, 0.5] # auc by slice = [0.75, 0.5]
expected_result = (0.75 + 0.5) / 2.0 expected_result = (0.75 + 0.5) / 2.0
self.assertAllClose(self.evaluate(result), expected_result, 1e-3) self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_weighted_roc_interpolation(self): def test_weighted_roc_interpolation(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True) self.setup()
self.evaluate(variables.variables_initializer(auc_obj.variables)) auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
result = auc_obj( multi_label=True)
self.y_true_good, self.y_pred, sample_weight=self.sample_weight) self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(
self.y_true_good, self.y_pred, sample_weight=self.sample_weight)
# tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]] # tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]] # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = 1.0 - 0.5 * 0.43 * 0.67 expected_result = 1.0 - 0.5 * 0.43 * 0.67
self.assertAllClose(self.evaluate(result), expected_result, 1e-1) self.assertAllClose(self.evaluate(result), expected_result, 1e-1)
def test_pr_interpolation_unweighted(self): def test_pr_interpolation_unweighted(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR', self.setup()
multi_label=True) auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
self.evaluate(variables.variables_initializer(auc_obj.variables)) multi_label=True)
good_result = auc_obj(self.y_true_good, self.y_pred) self.evaluate(variables.variables_initializer(auc_obj.variables))
with self.subTest(name='good'): good_result = auc_obj(self.y_true_good, self.y_pred)
# PR AUCs are 0.917 and 1.0 respectively with self.subTest(name='good'):
self.assertAllClose(self.evaluate(good_result), (0.91667 + 1.0) / 2.0, # PR AUCs are 0.917 and 1.0 respectively
1e-1) self.assertAllClose(self.evaluate(good_result), (0.91667 + 1.0) / 2.0,
bad_result = auc_obj(self.y_true_bad, self.y_pred) 1e-1)
with self.subTest(name='bad'): bad_result = auc_obj(self.y_true_bad, self.y_pred)
# PR AUCs are 0.917 and 0.5 respectively with self.subTest(name='bad'):
self.assertAllClose(self.evaluate(bad_result), (0.91667 + 0.5) / 2.0, # PR AUCs are 0.917 and 0.5 respectively
1e-1) self.assertAllClose(self.evaluate(bad_result), (0.91667 + 0.5) / 2.0,
1e-1)
def test_pr_interpolation(self): def test_pr_interpolation(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR', self.setup()
multi_label=True) auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR',
self.evaluate(variables.variables_initializer(auc_obj.variables)) multi_label=True)
good_result = auc_obj(self.y_true_good, self.y_pred, self.evaluate(variables.variables_initializer(auc_obj.variables))
sample_weight=self.sample_weight) good_result = auc_obj(self.y_true_good, self.y_pred,
# PR AUCs are 0.939 and 1.0 respectively sample_weight=self.sample_weight)
self.assertAllClose(self.evaluate(good_result), (0.939 + 1.0) / 2.0, # PR AUCs are 0.939 and 1.0 respectively
1e-1) self.assertAllClose(self.evaluate(good_result), (0.939 + 1.0) / 2.0,
1e-1)
def test_keras_model_compiles(self): def test_keras_model_compiles(self):
inputs = layers.Input(shape=(10,)) inputs = layers.Input(shape=(10,))
@ -1635,12 +1648,14 @@ class MultiAUCTest(test.TestCase):
) )
def test_reset_states(self): def test_reset_states(self):
self.setup() with self.test_session():
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=True) self.setup()
self.evaluate(variables.variables_initializer(auc_obj.variables)) auc_obj = metrics.AUC(num_thresholds=self.num_thresholds,
auc_obj(self.y_true_good, self.y_pred) multi_label=True)
auc_obj.reset_states() self.evaluate(variables.variables_initializer(auc_obj.variables))
self.assertAllEqual(auc_obj.true_positives, np.zeros((5, 2))) auc_obj(self.y_true_good, self.y_pred)
auc_obj.reset_states()
self.assertAllEqual(auc_obj.true_positives, np.zeros((5, 2)))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -22,6 +22,7 @@ import json
import math import math
import os import os
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -32,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics from tensorflow.python.keras import metrics
@ -46,35 +48,36 @@ from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils from tensorflow.python.training.tracking import util as trackable_utils
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasSumTest(test.TestCase): class KerasSumTest(test.TestCase, parameterized.TestCase):
def test_sum(self): def test_sum(self):
m = metrics.Sum(name='my_sum') with self.test_session():
m = metrics.Sum(name='my_sum')
# check config # check config
self.assertEqual(m.name, 'my_sum') self.assertEqual(m.name, 'my_sum')
self.assertTrue(m.stateful) self.assertTrue(m.stateful)
self.assertEqual(m.dtype, dtypes.float32) self.assertEqual(m.dtype, dtypes.float32)
self.assertEqual(len(m.variables), 1) self.assertLen(m.variables, 1)
self.evaluate(variables.variables_initializer(m.variables)) self.evaluate(variables.variables_initializer(m.variables))
# check initial state # check initial state
self.assertEqual(self.evaluate(m.total), 0) self.assertEqual(self.evaluate(m.total), 0)
# check __call__() # check __call__()
self.assertEqual(self.evaluate(m(100)), 100) self.assertEqual(self.evaluate(m(100)), 100)
self.assertEqual(self.evaluate(m.total), 100) self.assertEqual(self.evaluate(m.total), 100)
# check update_state() and result() + state accumulation + tensor input # check update_state() and result() + state accumulation + tensor input
update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
self.evaluate(update_op) self.evaluate(update_op)
self.assertAlmostEqual(self.evaluate(m.result()), 106) self.assertAlmostEqual(self.evaluate(m.result()), 106)
self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5
# check reset_states() # check reset_states()
m.reset_states() m.reset_states()
self.assertEqual(self.evaluate(m.total), 0) self.assertEqual(self.evaluate(m.total), 0)
def test_sum_with_sample_weight(self): def test_sum_with_sample_weight(self):
m = metrics.Sum(dtype=dtypes.float64) m = metrics.Sum(dtype=dtypes.float64)
@ -133,33 +136,34 @@ class KerasSumTest(test.TestCase):
self.assertAlmostEqual(self.evaluate(m.total), 52., 2) self.assertAlmostEqual(self.evaluate(m.total), 52., 2)
def test_save_restore(self): def test_save_restore(self):
checkpoint_directory = self.get_temp_dir() with self.test_session():
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') checkpoint_directory = self.get_temp_dir()
m = metrics.Sum() checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
checkpoint = trackable_utils.Checkpoint(sum=m) m = metrics.Sum()
self.evaluate(variables.variables_initializer(m.variables)) checkpoint = trackable_utils.Checkpoint(sum=m)
self.evaluate(variables.variables_initializer(m.variables))
# update state # update state
self.evaluate(m(100.)) self.evaluate(m(100.))
self.evaluate(m(200.)) self.evaluate(m(200.))
# save checkpoint and then add an update # save checkpoint and then add an update
save_path = checkpoint.save(checkpoint_prefix) save_path = checkpoint.save(checkpoint_prefix)
self.evaluate(m(1000.)) self.evaluate(m(1000.))
# restore to the same checkpoint sum object (= 300) # restore to the same checkpoint sum object (= 300)
checkpoint.restore(save_path).assert_consumed().run_restore_ops() checkpoint.restore(save_path).assert_consumed().run_restore_ops()
self.evaluate(m(300.)) self.evaluate(m(300.))
self.assertEqual(600., self.evaluate(m.result())) self.assertEqual(600., self.evaluate(m.result()))
# restore to a different checkpoint sum object # restore to a different checkpoint sum object
restore_sum = metrics.Sum() restore_sum = metrics.Sum()
restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum) restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum)
status = restore_checkpoint.restore(save_path) status = restore_checkpoint.restore(save_path)
restore_update = restore_sum(300.) restore_update = restore_sum(300.)
status.assert_consumed().run_restore_ops() status.assert_consumed().run_restore_ops()
self.evaluate(restore_update) self.evaluate(restore_update)
self.assertEqual(600., self.evaluate(restore_sum.result())) self.assertEqual(600., self.evaluate(restore_sum.result()))
class MeanTest(keras_parameterized.TestCase): class MeanTest(keras_parameterized.TestCase):
@ -354,7 +358,7 @@ class MeanTest(keras_parameterized.TestCase):
self.assertEqual(self.evaluate(m.count), 1) self.assertEqual(self.evaluate(m.count), 1)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasAccuracyTest(test.TestCase): class KerasAccuracyTest(test.TestCase):
def test_accuracy(self): def test_accuracy(self):
@ -598,7 +602,7 @@ class KerasAccuracyTest(test.TestCase):
self.assertEqual(acc_fn, metrics.accuracy) self.assertEqual(acc_fn, metrics.accuracy)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CosineSimilarityTest(test.TestCase): class CosineSimilarityTest(test.TestCase):
def l2_norm(self, x, axis): def l2_norm(self, x, axis):
@ -659,7 +663,7 @@ class CosineSimilarityTest(test.TestCase):
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3) self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanAbsoluteErrorTest(test.TestCase): class MeanAbsoluteErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -697,7 +701,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5) self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanAbsolutePercentageErrorTest(test.TestCase): class MeanAbsolutePercentageErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -737,7 +741,7 @@ class MeanAbsolutePercentageErrorTest(test.TestCase):
self.assertAllClose(40e7, self.evaluate(result), atol=1e-5) self.assertAllClose(40e7, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanSquaredErrorTest(test.TestCase): class MeanSquaredErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -775,7 +779,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5) self.assertAllClose(0.54285, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanSquaredLogarithmicErrorTest(test.TestCase): class MeanSquaredLogarithmicErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -815,7 +819,7 @@ class MeanSquaredLogarithmicErrorTest(test.TestCase):
self.assertAllClose(0.26082, self.evaluate(result), atol=1e-5) self.assertAllClose(0.26082, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class HingeTest(test.TestCase): class HingeTest(test.TestCase):
def test_config(self): def test_config(self):
@ -870,7 +874,7 @@ class HingeTest(test.TestCase):
self.assertAllClose(0.493, self.evaluate(result), atol=1e-3) self.assertAllClose(0.493, self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SquaredHingeTest(test.TestCase): class SquaredHingeTest(test.TestCase):
def test_config(self): def test_config(self):
@ -931,7 +935,7 @@ class SquaredHingeTest(test.TestCase):
self.assertAllClose(0.347, self.evaluate(result), atol=1e-3) self.assertAllClose(0.347, self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CategoricalHingeTest(test.TestCase): class CategoricalHingeTest(test.TestCase):
def test_config(self): def test_config(self):
@ -971,7 +975,7 @@ class CategoricalHingeTest(test.TestCase):
self.assertAllClose(0.5, self.evaluate(result), atol=1e-5) self.assertAllClose(0.5, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RootMeanSquaredErrorTest(test.TestCase): class RootMeanSquaredErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1005,7 +1009,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertAllClose(math.sqrt(13), self.evaluate(result), atol=1e-3) self.assertAllClose(math.sqrt(13), self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TopKCategoricalAccuracyTest(test.TestCase): class TopKCategoricalAccuracyTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1052,7 +1056,7 @@ class TopKCategoricalAccuracyTest(test.TestCase):
self.assertAllClose(1.0, self.evaluate(result), atol=1e-5) self.assertAllClose(1.0, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SparseTopKCategoricalAccuracyTest(test.TestCase): class SparseTopKCategoricalAccuracyTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1099,7 +1103,7 @@ class SparseTopKCategoricalAccuracyTest(test.TestCase):
self.assertAllClose(1.0, self.evaluate(result), atol=1e-5) self.assertAllClose(1.0, self.evaluate(result), atol=1e-5)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class LogCoshErrorTest(test.TestCase): class LogCoshErrorTest(test.TestCase):
def setup(self): def setup(self):
@ -1142,7 +1146,7 @@ class LogCoshErrorTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3) self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class PoissonTest(test.TestCase): class PoissonTest(test.TestCase):
def setup(self): def setup(self):
@ -1188,7 +1192,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3) self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KLDivergenceTest(test.TestCase): class KLDivergenceTest(test.TestCase):
def setup(self): def setup(self):
@ -1235,7 +1239,7 @@ class KLDivergenceTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3) self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanRelativeErrorTest(test.TestCase): class MeanRelativeErrorTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1291,7 +1295,7 @@ class MeanRelativeErrorTest(test.TestCase):
self.assertEqual(self.evaluate(result), 0) self.assertEqual(self.evaluate(result), 0)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanIoUTest(test.TestCase): class MeanIoUTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1374,90 +1378,93 @@ class MeanIoUTest(test.TestCase):
self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3) self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
class MeanTensorTest(test.TestCase): class MeanTensorTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_config(self): def test_config(self):
m = metrics.MeanTensor(name='mean_by_element') with self.test_session():
m = metrics.MeanTensor(name='mean_by_element')
# check config # check config
self.assertEqual(m.name, 'mean_by_element') self.assertEqual(m.name, 'mean_by_element')
self.assertTrue(m.stateful) self.assertTrue(m.stateful)
self.assertEqual(m.dtype, dtypes.float32) self.assertEqual(m.dtype, dtypes.float32)
self.assertEqual(len(m.variables), 0) self.assertEmpty(m.variables)
with self.assertRaisesRegexp(ValueError, 'does not have any result yet'): with self.assertRaisesRegexp(ValueError, 'does not have any result yet'):
m.result() m.result()
self.evaluate(m([[3], [5], [3]])) self.evaluate(m([[3], [5], [3]]))
self.assertAllEqual(m._shape, [3, 1]) self.assertAllEqual(m._shape, [3, 1])
m2 = metrics.MeanTensor.from_config(m.get_config()) m2 = metrics.MeanTensor.from_config(m.get_config())
self.assertEqual(m2.name, 'mean_by_element') self.assertEqual(m2.name, 'mean_by_element')
self.assertTrue(m2.stateful) self.assertTrue(m2.stateful)
self.assertEqual(m2.dtype, dtypes.float32) self.assertEqual(m2.dtype, dtypes.float32)
self.assertEqual(len(m2.variables), 0) self.assertEmpty(m2.variables)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_unweighted(self): def test_unweighted(self):
m = metrics.MeanTensor(dtype=dtypes.float64) with self.test_session():
m = metrics.MeanTensor(dtype=dtypes.float64)
# check __call__() # check __call__()
self.assertAllClose(self.evaluate(m([100, 40])), [100, 40]) self.assertAllClose(self.evaluate(m([100, 40])), [100, 40])
self.assertAllClose(self.evaluate(m.total), [100, 40]) self.assertAllClose(self.evaluate(m.total), [100, 40])
self.assertAllClose(self.evaluate(m.count), [1, 1]) self.assertAllClose(self.evaluate(m.count), [1, 1])
# check update_state() and result() + state accumulation + tensor input # check update_state() and result() + state accumulation + tensor input
update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
self.evaluate(update_op) self.evaluate(update_op)
self.assertAllClose(self.evaluate(m.result()), [50.5, 22.5]) self.assertAllClose(self.evaluate(m.result()), [50.5, 22.5])
self.assertAllClose(self.evaluate(m.total), [101, 45]) self.assertAllClose(self.evaluate(m.total), [101, 45])
self.assertAllClose(self.evaluate(m.count), [2, 2]) self.assertAllClose(self.evaluate(m.count), [2, 2])
# check reset_states() # check reset_states()
m.reset_states() m.reset_states()
self.assertAllClose(self.evaluate(m.total), [0, 0]) self.assertAllClose(self.evaluate(m.total), [0, 0])
self.assertAllClose(self.evaluate(m.count), [0, 0]) self.assertAllClose(self.evaluate(m.count), [0, 0])
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weighted(self): def test_weighted(self):
m = metrics.MeanTensor(dtype=dtypes.float64) with self.test_session():
self.assertEqual(m.dtype, dtypes.float64) m = metrics.MeanTensor(dtype=dtypes.float64)
self.assertEqual(m.dtype, dtypes.float64)
# check scalar weight # check scalar weight
result_t = m([100, 30], sample_weight=0.5) result_t = m([100, 30], sample_weight=0.5)
self.assertAllClose(self.evaluate(result_t), [100, 30]) self.assertAllClose(self.evaluate(result_t), [100, 30])
self.assertAllClose(self.evaluate(m.total), [50, 15]) self.assertAllClose(self.evaluate(m.total), [50, 15])
self.assertAllClose(self.evaluate(m.count), [0.5, 0.5]) self.assertAllClose(self.evaluate(m.count), [0.5, 0.5])
# check weights not scalar and weights rank matches values rank # check weights not scalar and weights rank matches values rank
result_t = m([1, 5], sample_weight=[1, 0.2]) result_t = m([1, 5], sample_weight=[1, 0.2])
result = self.evaluate(result_t) result = self.evaluate(result_t)
self.assertAllClose(result, [51 / 1.5, 16 / 0.7], 2) self.assertAllClose(result, [51 / 1.5, 16 / 0.7], 2)
self.assertAllClose(self.evaluate(m.total), [51, 16]) self.assertAllClose(self.evaluate(m.total), [51, 16])
self.assertAllClose(self.evaluate(m.count), [1.5, 0.7]) self.assertAllClose(self.evaluate(m.count), [1.5, 0.7])
# check weights broadcast # check weights broadcast
result_t = m([1, 2], sample_weight=0.5) result_t = m([1, 2], sample_weight=0.5)
self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2]) self.assertAllClose(self.evaluate(result_t), [51.5 / 2, 17 / 1.2])
self.assertAllClose(self.evaluate(m.total), [51.5, 17]) self.assertAllClose(self.evaluate(m.total), [51.5, 17])
self.assertAllClose(self.evaluate(m.count), [2, 1.2]) self.assertAllClose(self.evaluate(m.count), [2, 1.2])
# check weights squeeze # check weights squeeze
result_t = m([1, 5], sample_weight=[[1], [0.2]]) result_t = m([1, 5], sample_weight=[[1], [0.2]])
self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4]) self.assertAllClose(self.evaluate(result_t), [52.5 / 3, 18 / 1.4])
self.assertAllClose(self.evaluate(m.total), [52.5, 18]) self.assertAllClose(self.evaluate(m.total), [52.5, 18])
self.assertAllClose(self.evaluate(m.count), [3, 1.4]) self.assertAllClose(self.evaluate(m.count), [3, 1.4])
# check weights expand # check weights expand
m = metrics.MeanTensor(dtype=dtypes.float64) m = metrics.MeanTensor(dtype=dtypes.float64)
self.evaluate(variables.variables_initializer(m.variables)) self.evaluate(variables.variables_initializer(m.variables))
result_t = m([[1], [5]], sample_weight=[1, 0.2]) result_t = m([[1], [5]], sample_weight=[1, 0.2])
self.assertAllClose(self.evaluate(result_t), [[1], [5]]) self.assertAllClose(self.evaluate(result_t), [[1], [5]])
self.assertAllClose(self.evaluate(m.total), [[1], [1]]) self.assertAllClose(self.evaluate(m.total), [[1], [1]])
self.assertAllClose(self.evaluate(m.count), [[1], [0.2]]) self.assertAllClose(self.evaluate(m.count), [[1], [0.2]])
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_invalid_value_shape(self): def test_invalid_value_shape(self):
m = metrics.MeanTensor(dtype=dtypes.float64) m = metrics.MeanTensor(dtype=dtypes.float64)
m([1]) m([1])
@ -1465,7 +1472,7 @@ class MeanTensorTest(test.TestCase):
ValueError, 'MeanTensor input values must always have the same shape'): ValueError, 'MeanTensor input values must always have the same shape'):
m([1, 5]) m([1, 5])
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_build_in_tf_function(self): def test_build_in_tf_function(self):
"""Ensure that variables are created correctly in a tf function.""" """Ensure that variables are created correctly in a tf function."""
m = metrics.MeanTensor(dtype=dtypes.float64) m = metrics.MeanTensor(dtype=dtypes.float64)
@ -1474,10 +1481,11 @@ class MeanTensorTest(test.TestCase):
def call_metric(x): def call_metric(x):
return m(x) return m(x)
self.assertAllClose(self.evaluate(call_metric([100, 40])), [100, 40]) with self.test_session():
self.assertAllClose(self.evaluate(m.total), [100, 40]) self.assertAllClose(self.evaluate(call_metric([100, 40])), [100, 40])
self.assertAllClose(self.evaluate(m.count), [1, 1]) self.assertAllClose(self.evaluate(m.total), [100, 40])
self.assertAllClose(self.evaluate(call_metric([20, 2])), [60, 21]) self.assertAllClose(self.evaluate(m.count), [1, 1])
self.assertAllClose(self.evaluate(call_metric([20, 2])), [60, 21])
def test_in_keras_model(self): def test_in_keras_model(self):
with context.eager_mode(): with context.eager_mode():
@ -1522,7 +1530,7 @@ class MeanTensorTest(test.TestCase):
np.full((4, 3), 4)) np.full((4, 3), 4))
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BinaryCrossentropyTest(test.TestCase): class BinaryCrossentropyTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1642,7 +1650,7 @@ class BinaryCrossentropyTest(test.TestCase):
self.assertAllClose(expected_value, self.evaluate(result), atol=1e-3) self.assertAllClose(expected_value, self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CategoricalCrossentropyTest(test.TestCase): class CategoricalCrossentropyTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1768,7 +1776,7 @@ class CategoricalCrossentropyTest(test.TestCase):
self.assertAllClose(self.evaluate(loss), 3.667, atol=1e-3) self.assertAllClose(self.evaluate(loss), 3.667, atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SparseCategoricalCrossentropyTest(test.TestCase): class SparseCategoricalCrossentropyTest(test.TestCase):
def test_config(self): def test_config(self):
@ -1943,7 +1951,7 @@ class BinaryTruePositives(metrics.Metric):
return self.true_positives return self.true_positives
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CustomMetricsTest(test.TestCase): class CustomMetricsTest(test.TestCase):
def test_config(self): def test_config(self):