Update tests under keras.saving to use combinations.

Change all test_util.run_all_in_graph_and_eager_modes to combination.

PiperOrigin-RevId: 301396996
Change-Id: I1d79695f819bb289a428b3fd97965841a873bda9
This commit is contained in:
Scott Zhu 2020-03-17 10:06:33 -07:00 committed by TensorFlower Gardener
parent 8d01f78f82
commit f7896058b2
5 changed files with 144 additions and 131 deletions

View File

@ -92,6 +92,7 @@ tf_py_test(
deps = [ deps = [
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -106,6 +107,7 @@ tf_py_test(
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/feature_column:feature_column_v2", "//tensorflow/python/feature_column:feature_column_v2",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -142,6 +144,7 @@ tf_py_test(
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -156,6 +159,7 @@ tf_py_test(
deps = [ deps = [
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],

View File

@ -30,7 +30,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
@ -51,10 +51,10 @@ except ImportError:
h5py = None h5py = None
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats @keras_parameterized.run_with_all_saved_model_formats
@test_util.run_in_graph_and_eager_modes
def test_weight_loading(self): def test_weight_loading(self):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir) self.addCleanup(shutil.rmtree, temp_dir)
@ -83,7 +83,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
y = model.predict(x) y = model.predict(x)
self.assertAllClose(ref_y, y) self.assertAllClose(ref_y, y)
@test_util.run_in_graph_and_eager_modes
def test_weight_preprocessing(self): def test_weight_preprocessing(self):
input_dim = 3 input_dim = 3
output_dim = 3 output_dim = 3
@ -210,7 +209,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
for (x, y) in zip(weights1, weights2) for (x, y) in zip(weights1, weights2)
] ]
@test_util.run_in_graph_and_eager_modes
def test_sequential_weight_loading(self): def test_sequential_weight_loading(self):
if h5py is None: if h5py is None:
return return
@ -243,7 +241,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y) self.assertAllClose(y, ref_y)
@keras_parameterized.run_with_all_saved_model_formats @keras_parameterized.run_with_all_saved_model_formats
@test_util.run_in_graph_and_eager_modes
def test_nested_model_weight_loading(self): def test_nested_model_weight_loading(self):
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
@ -282,7 +279,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y) self.assertAllClose(y, ref_y)
@test_util.run_in_graph_and_eager_modes
def test_sequential_weight_loading_group_name_with_incorrect_length(self): def test_sequential_weight_loading_group_name_with_incorrect_length(self):
if h5py is None: if h5py is None:
return return
@ -779,7 +775,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
self.assertRegexpMatches( self.assertRegexpMatches(
h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$') h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_functional_model_with_custom_loss_and_metric(self): def test_functional_model_with_custom_loss_and_metric(self):
def _make_model(): def _make_model():
inputs = keras.Input(shape=(4,)) inputs = keras.Input(shape=(4,))
@ -818,7 +814,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
evaluation_results['sparse_categorical_crossentropy'] + evaluation_results['sparse_categorical_crossentropy'] +
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6) evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_save_uncompiled_model_with_optimizer(self): def test_save_uncompiled_model_with_optimizer(self):
with self.cached_session() as session: with self.cached_session() as session:
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
@ -901,6 +897,7 @@ class _make_subclassed_built(_make_subclassed): # pylint: disable=invalid-name
self.build((None, input_size)) self.build((None, input_size))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase): class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
"""Tests saving a whole model that contains other models.""" """Tests saving a whole model that contains other models."""
@ -913,7 +910,6 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
('subclassed', _make_subclassed), ('subclassed', _make_subclassed),
('subclassed_built', _make_subclassed_built), ('subclassed_built', _make_subclassed_built),
]) ])
@test_util.run_in_graph_and_eager_modes
def test_functional(self, model_fn): def test_functional(self, model_fn):
"""Tests serializing a model that uses a nested model to share weights.""" """Tests serializing a model that uses a nested model to share weights."""
if h5py is None: if h5py is None:
@ -926,6 +922,7 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])]) outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])])
return keras.Model(inputs=inputs, outputs=outputs) return keras.Model(inputs=inputs, outputs=outputs)
with self.cached_session():
x = (np.random.normal(size=(16, 4)).astype(np.float32), x = (np.random.normal(size=(16, 4)).astype(np.float32),
np.random.normal(size=(16, 4)).astype(np.float32)) np.random.normal(size=(16, 4)).astype(np.float32))
model = _make_model() model = _make_model()
@ -955,7 +952,7 @@ class SubclassedModel(training.Model):
return self.b_layer(self.x_layer(a)) return self.b_layer(self.x_layer(a))
class TestWeightSavingAndLoadingTFFormat(test.TestCase): class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
def test_keras_optimizer_warning(self): def test_keras_optimizer_warning(self):
graph = ops.Graph() graph = ops.Graph()
@ -974,7 +971,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
str(mock_log.call_args), str(mock_log.call_args),
'Keras optimizer') 'Keras optimizer')
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_tensorflow_format_overwrite(self): def test_tensorflow_format_overwrite(self):
with self.cached_session() as session: with self.cached_session() as session:
model = SubclassedModel() model = SubclassedModel()
@ -1025,12 +1022,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
model.save_weights(prefix, save_format='tensorflow') model.save_weights(prefix, save_format='tensorflow')
op_count = len(graph.get_operations()) op_count = len(graph.get_operations())
model.save_weights(prefix, save_format='tensorflow') model.save_weights(prefix, save_format='tensorflow')
self.assertEqual(len(graph.get_operations()), op_count) self.assertLen(graph.get_operations(), op_count)
model.load_weights(prefix) model.load_weights(prefix)
op_count = len(graph.get_operations()) op_count = len(graph.get_operations())
model.load_weights(prefix) model.load_weights(prefix)
self.assertEqual(len(graph.get_operations()), op_count) self.assertLen(graph.get_operations(), op_count)
def _weight_loading_test_template(self, make_model_fn): def _weight_loading_test_template(self, make_model_fn):
with self.cached_session(): with self.cached_session():
@ -1079,7 +1076,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
load_model.train_on_batch(train_x, train_y) load_model.train_on_batch(train_x, train_y)
self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x))) self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model(self): def test_weight_loading_graph_model(self):
def _make_graph_model(): def _make_graph_model():
a = keras.layers.Input(shape=(2,)) a = keras.layers.Input(shape=(2,))
@ -1089,7 +1086,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._weight_loading_test_template(_make_graph_model) self._weight_loading_test_template(_make_graph_model)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_subclassed_model(self): def test_weight_loading_subclassed_model(self):
self._weight_loading_test_template(SubclassedModel) self._weight_loading_test_template(SubclassedModel)
@ -1127,7 +1124,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
y = self.evaluate(model(x)) y = self.evaluate(model(x))
self.assertAllClose(ref_y, y) self.assertAllClose(ref_y, y)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model_added_layer(self): def test_weight_loading_graph_model_added_layer(self):
def _save_graph_model(): def _save_graph_model():
a = keras.layers.Input(shape=(2,)) a = keras.layers.Input(shape=(2,))
@ -1144,7 +1141,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template( self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model) _save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model_added_no_weight_layer(self): def test_weight_loading_graph_model_added_no_weight_layer(self):
def _save_graph_model(): def _save_graph_model():
a = keras.layers.Input(shape=(2,)) a = keras.layers.Input(shape=(2,))
@ -1161,7 +1158,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template( self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model) _save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_subclassed_model_added_layer(self): def test_weight_loading_subclassed_model_added_layer(self):
class SubclassedModelRestore(training.Model): class SubclassedModelRestore(training.Model):
@ -1178,7 +1175,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template( self._new_layer_weight_loading_test_template(
SubclassedModel, SubclassedModelRestore) SubclassedModel, SubclassedModelRestore)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_incompatible_checkpoint(self): def test_incompatible_checkpoint(self):
save_path = trackable.Checkpoint().save( save_path = trackable.Checkpoint().save(
os.path.join(self.get_temp_dir(), 'ckpt')) os.path.join(self.get_temp_dir(), 'ckpt'))
@ -1191,19 +1188,22 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
AssertionError, 'Nothing except the root object matched'): AssertionError, 'Nothing except the root object matched'):
m.load_weights(save_path) m.load_weights(save_path)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_directory_passed(self): def test_directory_passed(self):
with self.cached_session():
m = keras.Model() m = keras.Model()
v = m.add_weight(name='v', shape=[]) v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.)) self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/') prefix = os.path.join(self.get_temp_dir(),
'{}'.format(ops.uid()), 'ckpt/')
m.save_weights(prefix) m.save_weights(prefix)
self.evaluate(v.assign(2.)) self.evaluate(v.assign(2.))
m.load_weights(prefix) m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v)) self.assertEqual(42., self.evaluate(v))
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_relative_path(self): def test_relative_path(self):
with self.cached_session():
m = keras.Model() m = keras.Model()
v = m.add_weight(name='v', shape=[]) v = m.add_weight(name='v', shape=[])
os.chdir(self.get_temp_dir()) os.chdir(self.get_temp_dir())
@ -1232,12 +1232,14 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
m.load_weights(prefix) m.load_weights(prefix)
self.assertEqual(44., self.evaluate(v)) self.assertEqual(44., self.evaluate(v))
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nonexistent_prefix_directory(self): def test_nonexistent_prefix_directory(self):
with self.cached_session():
m = keras.Model() m = keras.Model()
v = m.add_weight(name='v', shape=[]) v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.)) self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt') prefix = os.path.join(self.get_temp_dir(),
'{}'.format(ops.uid()), 'bckpt')
m.save_weights(prefix) m.save_weights(prefix)
self.evaluate(v.assign(2.)) self.evaluate(v.assign(2.))
m.load_weights(prefix) m.load_weights(prefix)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import os import os
import sys import sys
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
@ -28,6 +29,7 @@ from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving import model_config from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving import save
@ -43,7 +45,7 @@ except ImportError:
h5py = None h5py = None
class TestSaveModel(test.TestCase): class TestSaveModel(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(TestSaveModel, self).setUp() super(TestSaveModel, self).setUp()
@ -99,7 +101,7 @@ class TestSaveModel(test.TestCase):
save.save_model(self.model, path, save_format='tf') save.save_model(self.model, path, save_format='tf')
save.load_model(path) save.load_model(path)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_with_dense_features(self): def test_saving_with_dense_features(self):
cols = [ cols = [
feature_column_lib.numeric_column('a'), feature_column_lib.numeric_column('a'),
@ -128,13 +130,14 @@ class TestSaveModel(test.TestCase):
inputs_a = np.arange(10).reshape(10, 1) inputs_a = np.arange(10).reshape(10, 1)
inputs_b = np.arange(10).reshape(10, 1).astype('str') inputs_b = np.arange(10).reshape(10, 1).astype('str')
with self.cached_session():
# Initialize tables for V1 lookup. # Initialize tables for V1 lookup.
if not context.executing_eagerly(): if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer()) self.evaluate(lookup_ops.tables_initializer())
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10) self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_with_sequence_features(self): def test_saving_with_sequence_features(self):
cols = [ cols = [
feature_column_lib.sequence_numeric_column('a'), feature_column_lib.sequence_numeric_column('a'),
@ -182,6 +185,7 @@ class TestSaveModel(test.TestCase):
inputs_b = sparse_tensor.SparseTensor(indices_b, values_b, inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
(batch_size, timesteps, 1)) (batch_size, timesteps, 1))
with self.cached_session():
# Initialize tables for V1 lookup. # Initialize tables for V1 lookup.
if not context.executing_eagerly(): if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer()) self.evaluate(lookup_ops.tables_initializer())
@ -192,7 +196,7 @@ class TestSaveModel(test.TestCase):
'b': inputs_b 'b': inputs_b
}, steps=1), batch_size) }, steps=1), batch_size)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_h5_for_rnn_layers(self): def test_saving_h5_for_rnn_layers(self):
# See https://github.com/tensorflow/tensorflow/issues/35731 for details. # See https://github.com/tensorflow/tensorflow/issues/35731 for details.
inputs = keras.Input([10, 91], name='train_input') inputs = keras.Input([10, 91], name='train_input')
@ -213,7 +217,7 @@ class TestSaveModel(test.TestCase):
rnn_layers[1].kernel.name) rnn_layers[1].kernel.name)
self.assertIn('rnn_cell1', rnn_layers[1].kernel.name) self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_optimizer_weights(self): def test_saving_optimizer_weights(self):
class MyModel(keras.Model): class MyModel(keras.Model):

View File

@ -44,7 +44,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import regularizers from tensorflow.python.keras import regularizers
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
@ -700,7 +700,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(model(input_arr), loaded(input_arr)) self.assertAllClose(model(input_arr), loaded(input_arr))
class TestLayerCallTracing(test.TestCase): class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
def test_functions_have_same_trace(self): def test_functions_have_same_trace(self):
@ -773,7 +773,7 @@ class TestLayerCallTracing(test.TestCase):
assert_num_traces(LayerWithChildLayer, training_keyword=False) assert_num_traces(LayerWithChildLayer, training_keyword=False)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_maintains_losses(self): def test_maintains_losses(self):
layer = LayerWithLoss() layer = LayerWithLoss()
layer(np.ones((2, 3))) layer(np.ones((2, 3)))
@ -786,7 +786,7 @@ class TestLayerCallTracing(test.TestCase):
self.assertAllEqual(previous_losses, layer.losses) self.assertAllEqual(previous_losses, layer.losses)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MetricTest(test.TestCase, parameterized.TestCase): class MetricTest(test.TestCase, parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'): def _save_model_dir(self, dirname='saved_model'):
@ -870,6 +870,7 @@ class MetricTest(test.TestCase, parameterized.TestCase):
# while returning nothing. # while returning nothing.
super(CustomMetric, self).update_state(*args) super(CustomMetric, self).update_state(*args)
with self.cached_session():
metric = CustomMetric() metric = CustomMetric()
save_dir = self._save_model_dir('first_save') save_dir = self._save_model_dir('first_save')
@ -878,7 +879,8 @@ class MetricTest(test.TestCase, parameterized.TestCase):
self.evaluate([v.initializer for v in metric.variables]) self.evaluate([v.initializer for v in metric.variables])
with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'): with self.assertRaisesRegexp(ValueError,
'Unable to restore custom object'):
self._test_metric_save_and_load(metric, save_dir, num_tensor_args) self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}): with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
loaded = self._test_metric_save_and_load( loaded = self._test_metric_save_and_load(

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import sequential
@ -62,7 +63,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self.assertAllClose(expected, actual) self.assertAllClose(expected, actual)
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
@test_util.run_in_graph_and_eager_modes @keras_parameterized.run_all_keras_modes
def test_trace_model_outputs(self): def test_trace_model_outputs(self):
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
model = testing_utils.get_small_mlp(10, 3, input_dim) model = testing_utils.get_small_mlp(10, 3, input_dim)
@ -155,7 +156,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]} expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
self._assert_all_close(expected_outputs, signature_outputs) self._assert_all_close(expected_outputs, signature_outputs)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_trace_features_layer(self): def test_trace_features_layer(self):
columns = [feature_column_lib.numeric_column('x')] columns = [feature_column_lib.numeric_column('x')]
model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)]) model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
@ -176,7 +177,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self.assertAllClose({'output_1': [[1., 2.]]}, self.assertAllClose({'output_1': [[1., 2.]]},
fn({'x': [[1.]], 'y': [[2.]]})) fn({'x': [[1.]], 'y': [[2.]]}))
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_specify_input_signature(self): def test_specify_input_signature(self):
model = testing_utils.get_small_sequential_mlp(10, 3, None) model = testing_utils.get_small_sequential_mlp(10, 3, None)
inputs = array_ops.ones((8, 5)) inputs = array_ops.ones((8, 5))
@ -193,7 +194,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
expected_outputs = {'output_1': model(inputs)} expected_outputs = {'output_1': model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs) self._assert_all_close(expected_outputs, signature_outputs)
@test_util.run_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_subclassed_model_with_input_signature(self): def test_subclassed_model_with_input_signature(self):
class Model(keras.Model): class Model(keras.Model):
@ -218,7 +219,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self._assert_all_close(expected_outputs, signature_outputs) self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
@test_util.run_in_graph_and_eager_modes @keras_parameterized.run_all_keras_modes
def test_model_with_fixed_input_dim(self): def test_model_with_fixed_input_dim(self):
"""Ensure that the batch_dim is removed when saving. """Ensure that the batch_dim is removed when saving.