Update tests under keras.saving to use combinations.

Change all test_util.run_all_in_graph_and_eager_modes to combination.

PiperOrigin-RevId: 301396996
Change-Id: I1d79695f819bb289a428b3fd97965841a873bda9
This commit is contained in:
Scott Zhu 2020-03-17 10:06:33 -07:00 committed by TensorFlower Gardener
parent 8d01f78f82
commit f7896058b2
5 changed files with 144 additions and 131 deletions

View File

@ -92,6 +92,7 @@ tf_py_test(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@ -106,6 +107,7 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python/feature_column:feature_column_v2",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@ -142,6 +144,7 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@ -156,6 +159,7 @@ tf_py_test(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],

View File

@ -30,7 +30,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import testing_utils
@ -51,10 +51,10 @@ except ImportError:
h5py = None
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats
@test_util.run_in_graph_and_eager_modes
def test_weight_loading(self):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
@ -83,7 +83,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
y = model.predict(x)
self.assertAllClose(ref_y, y)
@test_util.run_in_graph_and_eager_modes
def test_weight_preprocessing(self):
input_dim = 3
output_dim = 3
@ -210,7 +209,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
for (x, y) in zip(weights1, weights2)
]
@test_util.run_in_graph_and_eager_modes
def test_sequential_weight_loading(self):
if h5py is None:
return
@ -243,7 +241,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y)
@keras_parameterized.run_with_all_saved_model_formats
@test_util.run_in_graph_and_eager_modes
def test_nested_model_weight_loading(self):
save_format = testing_utils.get_save_format()
temp_dir = self.get_temp_dir()
@ -282,7 +279,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y)
@test_util.run_in_graph_and_eager_modes
def test_sequential_weight_loading_group_name_with_incorrect_length(self):
if h5py is None:
return
@ -314,16 +310,16 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
model.compile(loss=keras.losses.MSE,
optimizer='rmsprop',
metrics=[keras.metrics.categorical_accuracy])
with self.assertRaisesRegexp(ValueError,
r'Layer #0 \(named \"d1\"\) expects 1 '
r'weight\(s\), but the saved weights have 2 '
r'element\(s\)\.'):
hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
with self.assertRaisesRegexp(ValueError,
r'Layer #0 \(named \"d1\"\) expects 1 '
r'weight\(s\), but the saved weights have 2 '
r'element\(s\)\.'):
hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
hdf5_format.load_weights_from_hdf5_group_by_name(
f_model, model.layers, skip_mismatch=True)
self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel),
keras.backend.get_value(model.layers[1].kernel))
hdf5_format.load_weights_from_hdf5_group_by_name(
f_model, model.layers, skip_mismatch=True)
self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel),
keras.backend.get_value(model.layers[1].kernel))
def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
if h5py is None:
@ -779,7 +775,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
self.assertRegexpMatches(
h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_functional_model_with_custom_loss_and_metric(self):
def _make_model():
inputs = keras.Input(shape=(4,))
@ -818,7 +814,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
evaluation_results['sparse_categorical_crossentropy'] +
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_save_uncompiled_model_with_optimizer(self):
with self.cached_session() as session:
saved_model_dir = self._save_model_dir()
@ -901,6 +897,7 @@ class _make_subclassed_built(_make_subclassed): # pylint: disable=invalid-name
self.build((None, input_size))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
"""Tests saving a whole model that contains other models."""
@ -913,7 +910,6 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
('subclassed', _make_subclassed),
('subclassed_built', _make_subclassed_built),
])
@test_util.run_in_graph_and_eager_modes
def test_functional(self, model_fn):
"""Tests serializing a model that uses a nested model to share weights."""
if h5py is None:
@ -926,22 +922,23 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])])
return keras.Model(inputs=inputs, outputs=outputs)
x = (np.random.normal(size=(16, 4)).astype(np.float32),
np.random.normal(size=(16, 4)).astype(np.float32))
model = _make_model()
predictions = model(x)
# Save and reload.
model_path = os.path.join(self.get_temp_dir(), 'model.h5')
model.save(model_path)
del model
loaded_model = keras.models.load_model(
model_path,
custom_objects={
'_make_subclassed': _make_subclassed,
'_make_subclassed_built': _make_subclassed_built,
},
compile=False)
self.assertAllClose(loaded_model(x), predictions, 1e-9)
with self.cached_session():
x = (np.random.normal(size=(16, 4)).astype(np.float32),
np.random.normal(size=(16, 4)).astype(np.float32))
model = _make_model()
predictions = model(x)
# Save and reload.
model_path = os.path.join(self.get_temp_dir(), 'model.h5')
model.save(model_path)
del model
loaded_model = keras.models.load_model(
model_path,
custom_objects={
'_make_subclassed': _make_subclassed,
'_make_subclassed_built': _make_subclassed_built,
},
compile=False)
self.assertAllClose(loaded_model(x), predictions, 1e-9)
class SubclassedModel(training.Model):
@ -955,7 +952,7 @@ class SubclassedModel(training.Model):
return self.b_layer(self.x_layer(a))
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
def test_keras_optimizer_warning(self):
graph = ops.Graph()
@ -974,7 +971,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
str(mock_log.call_args),
'Keras optimizer')
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_tensorflow_format_overwrite(self):
with self.cached_session() as session:
model = SubclassedModel()
@ -1025,12 +1022,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
model.save_weights(prefix, save_format='tensorflow')
op_count = len(graph.get_operations())
model.save_weights(prefix, save_format='tensorflow')
self.assertEqual(len(graph.get_operations()), op_count)
self.assertLen(graph.get_operations(), op_count)
model.load_weights(prefix)
op_count = len(graph.get_operations())
model.load_weights(prefix)
self.assertEqual(len(graph.get_operations()), op_count)
self.assertLen(graph.get_operations(), op_count)
def _weight_loading_test_template(self, make_model_fn):
with self.cached_session():
@ -1079,7 +1076,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
load_model.train_on_batch(train_x, train_y)
self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model(self):
def _make_graph_model():
a = keras.layers.Input(shape=(2,))
@ -1089,7 +1086,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._weight_loading_test_template(_make_graph_model)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_subclassed_model(self):
self._weight_loading_test_template(SubclassedModel)
@ -1127,7 +1124,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
y = self.evaluate(model(x))
self.assertAllClose(ref_y, y)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model_added_layer(self):
def _save_graph_model():
a = keras.layers.Input(shape=(2,))
@ -1144,7 +1141,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_graph_model_added_no_weight_layer(self):
def _save_graph_model():
a = keras.layers.Input(shape=(2,))
@ -1161,7 +1158,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template(
_save_graph_model, _restore_graph_model)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_weight_loading_subclassed_model_added_layer(self):
class SubclassedModelRestore(training.Model):
@ -1178,7 +1175,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._new_layer_weight_loading_test_template(
SubclassedModel, SubclassedModelRestore)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_incompatible_checkpoint(self):
save_path = trackable.Checkpoint().save(
os.path.join(self.get_temp_dir(), 'ckpt'))
@ -1191,57 +1188,62 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
AssertionError, 'Nothing except the root object matched'):
m.load_weights(save_path)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_directory_passed(self):
m = keras.Model()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/')
m.save_weights(prefix)
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
with self.cached_session():
m = keras.Model()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(),
'{}'.format(ops.uid()), 'ckpt/')
m.save_weights(prefix)
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_relative_path(self):
m = keras.Model()
v = m.add_weight(name='v', shape=[])
os.chdir(self.get_temp_dir())
with self.cached_session():
m = keras.Model()
v = m.add_weight(name='v', shape=[])
os.chdir(self.get_temp_dir())
prefix = 'ackpt'
self.evaluate(v.assign(42.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('ackpt.index'))
self.evaluate(v.assign(1.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
prefix = 'ackpt'
self.evaluate(v.assign(42.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('ackpt.index'))
self.evaluate(v.assign(1.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
prefix = 'subdir/ackpt'
self.evaluate(v.assign(43.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(43., self.evaluate(v))
prefix = 'subdir/ackpt'
self.evaluate(v.assign(43.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(43., self.evaluate(v))
prefix = 'ackpt/'
self.evaluate(v.assign(44.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('ackpt/.index'))
self.evaluate(v.assign(3.))
m.load_weights(prefix)
self.assertEqual(44., self.evaluate(v))
prefix = 'ackpt/'
self.evaluate(v.assign(44.))
m.save_weights(prefix)
self.assertTrue(file_io.file_exists('ackpt/.index'))
self.evaluate(v.assign(3.))
m.load_weights(prefix)
self.assertEqual(44., self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nonexistent_prefix_directory(self):
m = keras.Model()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt')
m.save_weights(prefix)
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
with self.cached_session():
m = keras.Model()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(),
'{}'.format(ops.uid()), 'bckpt')
m.save_weights(prefix)
self.evaluate(v.assign(2.))
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
if __name__ == '__main__':
test.main()

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import os
import sys
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
@ -28,6 +29,7 @@ from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
@ -43,7 +45,7 @@ except ImportError:
h5py = None
class TestSaveModel(test.TestCase):
class TestSaveModel(test.TestCase, parameterized.TestCase):
def setUp(self):
super(TestSaveModel, self).setUp()
@ -99,7 +101,7 @@ class TestSaveModel(test.TestCase):
save.save_model(self.model, path, save_format='tf')
save.load_model(path)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_with_dense_features(self):
cols = [
feature_column_lib.numeric_column('a'),
@ -128,13 +130,14 @@ class TestSaveModel(test.TestCase):
inputs_a = np.arange(10).reshape(10, 1)
inputs_b = np.arange(10).reshape(10, 1).astype('str')
# Initialize tables for V1 lookup.
if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer())
with self.cached_session():
# Initialize tables for V1 lookup.
if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer())
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_with_sequence_features(self):
cols = [
feature_column_lib.sequence_numeric_column('a'),
@ -182,17 +185,18 @@ class TestSaveModel(test.TestCase):
inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
(batch_size, timesteps, 1))
# Initialize tables for V1 lookup.
if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer())
with self.cached_session():
# Initialize tables for V1 lookup.
if not context.executing_eagerly():
self.evaluate(lookup_ops.tables_initializer())
self.assertLen(
loaded_model.predict({
'a': inputs_a,
'b': inputs_b
}, steps=1), batch_size)
self.assertLen(
loaded_model.predict({
'a': inputs_a,
'b': inputs_b
}, steps=1), batch_size)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_h5_for_rnn_layers(self):
# See https://github.com/tensorflow/tensorflow/issues/35731 for details.
inputs = keras.Input([10, 91], name='train_input')
@ -213,7 +217,7 @@ class TestSaveModel(test.TestCase):
rnn_layers[1].kernel.name)
self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_saving_optimizer_weights(self):
class MyModel(keras.Model):

View File

@ -44,7 +44,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import testing_utils
@ -700,7 +700,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(model(input_arr), loaded(input_arr))
class TestLayerCallTracing(test.TestCase):
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
def test_functions_have_same_trace(self):
@ -773,7 +773,7 @@ class TestLayerCallTracing(test.TestCase):
assert_num_traces(LayerWithChildLayer, training_keyword=False)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_maintains_losses(self):
layer = LayerWithLoss()
layer(np.ones((2, 3)))
@ -786,7 +786,7 @@ class TestLayerCallTracing(test.TestCase):
self.assertAllEqual(previous_losses, layer.losses)
@test_util.run_all_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MetricTest(test.TestCase, parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
@ -870,28 +870,30 @@ class MetricTest(test.TestCase, parameterized.TestCase):
# while returning nothing.
super(CustomMetric, self).update_state(*args)
metric = CustomMetric()
save_dir = self._save_model_dir('first_save')
with self.cached_session():
metric = CustomMetric()
save_dir = self._save_model_dir('first_save')
if requires_build:
metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable
if requires_build:
metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable
self.evaluate([v.initializer for v in metric.variables])
self.evaluate([v.initializer for v in metric.variables])
with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'):
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
loaded = self._test_metric_save_and_load(
metric,
save_dir,
num_tensor_args,
test_sample_weight=False)
with self.assertRaisesRegexp(ValueError,
'Unable to restore custom object'):
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
loaded = self._test_metric_save_and_load(
metric,
save_dir,
num_tensor_args,
test_sample_weight=False)
self._test_metric_save_and_load(
loaded,
self._save_model_dir('second_save'),
num_tensor_args,
test_sample_weight=False)
self._test_metric_save_and_load(
loaded,
self._save_model_dir('second_save'),
num_tensor_args,
test_sample_weight=False)
def test_custom_metric_wrapped_call(self):

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
@ -62,7 +63,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self.assertAllClose(expected, actual)
@keras_parameterized.run_with_all_model_types
@test_util.run_in_graph_and_eager_modes
@keras_parameterized.run_all_keras_modes
def test_trace_model_outputs(self):
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
model = testing_utils.get_small_mlp(10, 3, input_dim)
@ -155,7 +156,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
self._assert_all_close(expected_outputs, signature_outputs)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_trace_features_layer(self):
columns = [feature_column_lib.numeric_column('x')]
model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
@ -176,7 +177,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self.assertAllClose({'output_1': [[1., 2.]]},
fn({'x': [[1.]], 'y': [[2.]]}))
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_specify_input_signature(self):
model = testing_utils.get_small_sequential_mlp(10, 3, None)
inputs = array_ops.ones((8, 5))
@ -193,7 +194,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
expected_outputs = {'output_1': model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs)
@test_util.run_in_graph_and_eager_modes
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_subclassed_model_with_input_signature(self):
class Model(keras.Model):
@ -218,7 +219,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_with_all_model_types
@test_util.run_in_graph_and_eager_modes
@keras_parameterized.run_all_keras_modes
def test_model_with_fixed_input_dim(self):
"""Ensure that the batch_dim is removed when saving.