Update tests under keras.saving to use combinations.
Change all test_util.run_all_in_graph_and_eager_modes to combination. PiperOrigin-RevId: 301396996 Change-Id: I1d79695f819bb289a428b3fd97965841a873bda9
This commit is contained in:
parent
8d01f78f82
commit
f7896058b2
@ -92,6 +92,7 @@ tf_py_test(
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -106,6 +107,7 @@ tf_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/feature_column:feature_column_v2",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -142,6 +144,7 @@ tf_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -156,6 +159,7 @@ tf_py_test(
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -30,7 +30,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -51,10 +51,10 @@ except ImportError:
|
||||
h5py = None
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@keras_parameterized.run_with_all_saved_model_formats
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_weight_loading(self):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
@ -83,7 +83,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
y = model.predict(x)
|
||||
self.assertAllClose(ref_y, y)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_weight_preprocessing(self):
|
||||
input_dim = 3
|
||||
output_dim = 3
|
||||
@ -210,7 +209,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
for (x, y) in zip(weights1, weights2)
|
||||
]
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_weight_loading(self):
|
||||
if h5py is None:
|
||||
return
|
||||
@ -243,7 +241,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(y, ref_y)
|
||||
|
||||
@keras_parameterized.run_with_all_saved_model_formats
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_nested_model_weight_loading(self):
|
||||
save_format = testing_utils.get_save_format()
|
||||
temp_dir = self.get_temp_dir()
|
||||
@ -282,7 +279,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllClose(y, ref_y)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_weight_loading_group_name_with_incorrect_length(self):
|
||||
if h5py is None:
|
||||
return
|
||||
@ -314,16 +310,16 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
model.compile(loss=keras.losses.MSE,
|
||||
optimizer='rmsprop',
|
||||
metrics=[keras.metrics.categorical_accuracy])
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Layer #0 \(named \"d1\"\) expects 1 '
|
||||
r'weight\(s\), but the saved weights have 2 '
|
||||
r'element\(s\)\.'):
|
||||
hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Layer #0 \(named \"d1\"\) expects 1 '
|
||||
r'weight\(s\), but the saved weights have 2 '
|
||||
r'element\(s\)\.'):
|
||||
hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
|
||||
|
||||
hdf5_format.load_weights_from_hdf5_group_by_name(
|
||||
f_model, model.layers, skip_mismatch=True)
|
||||
self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel),
|
||||
keras.backend.get_value(model.layers[1].kernel))
|
||||
hdf5_format.load_weights_from_hdf5_group_by_name(
|
||||
f_model, model.layers, skip_mismatch=True)
|
||||
self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel),
|
||||
keras.backend.get_value(model.layers[1].kernel))
|
||||
|
||||
def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
|
||||
if h5py is None:
|
||||
@ -779,7 +775,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
||||
self.assertRegexpMatches(
|
||||
h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_functional_model_with_custom_loss_and_metric(self):
|
||||
def _make_model():
|
||||
inputs = keras.Input(shape=(4,))
|
||||
@ -818,7 +814,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
||||
evaluation_results['sparse_categorical_crossentropy'] +
|
||||
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_save_uncompiled_model_with_optimizer(self):
|
||||
with self.cached_session() as session:
|
||||
saved_model_dir = self._save_model_dir()
|
||||
@ -901,6 +897,7 @@ class _make_subclassed_built(_make_subclassed): # pylint: disable=invalid-name
|
||||
self.build((None, input_size))
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
||||
"""Tests saving a whole model that contains other models."""
|
||||
|
||||
@ -913,7 +910,6 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
||||
('subclassed', _make_subclassed),
|
||||
('subclassed_built', _make_subclassed_built),
|
||||
])
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_functional(self, model_fn):
|
||||
"""Tests serializing a model that uses a nested model to share weights."""
|
||||
if h5py is None:
|
||||
@ -926,22 +922,23 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
||||
outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])])
|
||||
return keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
x = (np.random.normal(size=(16, 4)).astype(np.float32),
|
||||
np.random.normal(size=(16, 4)).astype(np.float32))
|
||||
model = _make_model()
|
||||
predictions = model(x)
|
||||
# Save and reload.
|
||||
model_path = os.path.join(self.get_temp_dir(), 'model.h5')
|
||||
model.save(model_path)
|
||||
del model
|
||||
loaded_model = keras.models.load_model(
|
||||
model_path,
|
||||
custom_objects={
|
||||
'_make_subclassed': _make_subclassed,
|
||||
'_make_subclassed_built': _make_subclassed_built,
|
||||
},
|
||||
compile=False)
|
||||
self.assertAllClose(loaded_model(x), predictions, 1e-9)
|
||||
with self.cached_session():
|
||||
x = (np.random.normal(size=(16, 4)).astype(np.float32),
|
||||
np.random.normal(size=(16, 4)).astype(np.float32))
|
||||
model = _make_model()
|
||||
predictions = model(x)
|
||||
# Save and reload.
|
||||
model_path = os.path.join(self.get_temp_dir(), 'model.h5')
|
||||
model.save(model_path)
|
||||
del model
|
||||
loaded_model = keras.models.load_model(
|
||||
model_path,
|
||||
custom_objects={
|
||||
'_make_subclassed': _make_subclassed,
|
||||
'_make_subclassed_built': _make_subclassed_built,
|
||||
},
|
||||
compile=False)
|
||||
self.assertAllClose(loaded_model(x), predictions, 1e-9)
|
||||
|
||||
|
||||
class SubclassedModel(training.Model):
|
||||
@ -955,7 +952,7 @@ class SubclassedModel(training.Model):
|
||||
return self.b_layer(self.x_layer(a))
|
||||
|
||||
|
||||
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_keras_optimizer_warning(self):
|
||||
graph = ops.Graph()
|
||||
@ -974,7 +971,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
str(mock_log.call_args),
|
||||
'Keras optimizer')
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_tensorflow_format_overwrite(self):
|
||||
with self.cached_session() as session:
|
||||
model = SubclassedModel()
|
||||
@ -1025,12 +1022,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
model.save_weights(prefix, save_format='tensorflow')
|
||||
op_count = len(graph.get_operations())
|
||||
model.save_weights(prefix, save_format='tensorflow')
|
||||
self.assertEqual(len(graph.get_operations()), op_count)
|
||||
self.assertLen(graph.get_operations(), op_count)
|
||||
|
||||
model.load_weights(prefix)
|
||||
op_count = len(graph.get_operations())
|
||||
model.load_weights(prefix)
|
||||
self.assertEqual(len(graph.get_operations()), op_count)
|
||||
self.assertLen(graph.get_operations(), op_count)
|
||||
|
||||
def _weight_loading_test_template(self, make_model_fn):
|
||||
with self.cached_session():
|
||||
@ -1079,7 +1076,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
load_model.train_on_batch(train_x, train_y)
|
||||
self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_weight_loading_graph_model(self):
|
||||
def _make_graph_model():
|
||||
a = keras.layers.Input(shape=(2,))
|
||||
@ -1089,7 +1086,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
|
||||
self._weight_loading_test_template(_make_graph_model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_weight_loading_subclassed_model(self):
|
||||
self._weight_loading_test_template(SubclassedModel)
|
||||
|
||||
@ -1127,7 +1124,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
y = self.evaluate(model(x))
|
||||
self.assertAllClose(ref_y, y)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_weight_loading_graph_model_added_layer(self):
|
||||
def _save_graph_model():
|
||||
a = keras.layers.Input(shape=(2,))
|
||||
@ -1144,7 +1141,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
self._new_layer_weight_loading_test_template(
|
||||
_save_graph_model, _restore_graph_model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_weight_loading_graph_model_added_no_weight_layer(self):
|
||||
def _save_graph_model():
|
||||
a = keras.layers.Input(shape=(2,))
|
||||
@ -1161,7 +1158,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
self._new_layer_weight_loading_test_template(
|
||||
_save_graph_model, _restore_graph_model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_weight_loading_subclassed_model_added_layer(self):
|
||||
|
||||
class SubclassedModelRestore(training.Model):
|
||||
@ -1178,7 +1175,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
self._new_layer_weight_loading_test_template(
|
||||
SubclassedModel, SubclassedModelRestore)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_incompatible_checkpoint(self):
|
||||
save_path = trackable.Checkpoint().save(
|
||||
os.path.join(self.get_temp_dir(), 'ckpt'))
|
||||
@ -1191,57 +1188,62 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
||||
AssertionError, 'Nothing except the root object matched'):
|
||||
m.load_weights(save_path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_directory_passed(self):
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
self.evaluate(v.assign(42.))
|
||||
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/')
|
||||
m.save_weights(prefix)
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
with self.cached_session():
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
self.evaluate(v.assign(42.))
|
||||
prefix = os.path.join(self.get_temp_dir(),
|
||||
'{}'.format(ops.uid()), 'ckpt/')
|
||||
m.save_weights(prefix)
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_relative_path(self):
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
os.chdir(self.get_temp_dir())
|
||||
with self.cached_session():
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
os.chdir(self.get_temp_dir())
|
||||
|
||||
prefix = 'ackpt'
|
||||
self.evaluate(v.assign(42.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('ackpt.index'))
|
||||
self.evaluate(v.assign(1.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
prefix = 'ackpt'
|
||||
self.evaluate(v.assign(42.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('ackpt.index'))
|
||||
self.evaluate(v.assign(1.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
|
||||
prefix = 'subdir/ackpt'
|
||||
self.evaluate(v.assign(43.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(43., self.evaluate(v))
|
||||
prefix = 'subdir/ackpt'
|
||||
self.evaluate(v.assign(43.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(43., self.evaluate(v))
|
||||
|
||||
prefix = 'ackpt/'
|
||||
self.evaluate(v.assign(44.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('ackpt/.index'))
|
||||
self.evaluate(v.assign(3.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(44., self.evaluate(v))
|
||||
prefix = 'ackpt/'
|
||||
self.evaluate(v.assign(44.))
|
||||
m.save_weights(prefix)
|
||||
self.assertTrue(file_io.file_exists('ackpt/.index'))
|
||||
self.evaluate(v.assign(3.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(44., self.evaluate(v))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_nonexistent_prefix_directory(self):
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
self.evaluate(v.assign(42.))
|
||||
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt')
|
||||
m.save_weights(prefix)
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
with self.cached_session():
|
||||
m = keras.Model()
|
||||
v = m.add_weight(name='v', shape=[])
|
||||
self.evaluate(v.assign(42.))
|
||||
prefix = os.path.join(self.get_temp_dir(),
|
||||
'{}'.format(ops.uid()), 'bckpt')
|
||||
m.save_weights(prefix)
|
||||
self.evaluate(v.assign(2.))
|
||||
m.load_weights(prefix)
|
||||
self.assertEqual(42., self.evaluate(v))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
@ -28,6 +29,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.feature_column import feature_column_lib
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.saving import model_config
|
||||
from tensorflow.python.keras.saving import save
|
||||
@ -43,7 +45,7 @@ except ImportError:
|
||||
h5py = None
|
||||
|
||||
|
||||
class TestSaveModel(test.TestCase):
|
||||
class TestSaveModel(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestSaveModel, self).setUp()
|
||||
@ -99,7 +101,7 @@ class TestSaveModel(test.TestCase):
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
save.load_model(path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_saving_with_dense_features(self):
|
||||
cols = [
|
||||
feature_column_lib.numeric_column('a'),
|
||||
@ -128,13 +130,14 @@ class TestSaveModel(test.TestCase):
|
||||
inputs_a = np.arange(10).reshape(10, 1)
|
||||
inputs_b = np.arange(10).reshape(10, 1).astype('str')
|
||||
|
||||
# Initialize tables for V1 lookup.
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(lookup_ops.tables_initializer())
|
||||
with self.cached_session():
|
||||
# Initialize tables for V1 lookup.
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(lookup_ops.tables_initializer())
|
||||
|
||||
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
|
||||
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_saving_with_sequence_features(self):
|
||||
cols = [
|
||||
feature_column_lib.sequence_numeric_column('a'),
|
||||
@ -182,17 +185,18 @@ class TestSaveModel(test.TestCase):
|
||||
inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
|
||||
(batch_size, timesteps, 1))
|
||||
|
||||
# Initialize tables for V1 lookup.
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(lookup_ops.tables_initializer())
|
||||
with self.cached_session():
|
||||
# Initialize tables for V1 lookup.
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(lookup_ops.tables_initializer())
|
||||
|
||||
self.assertLen(
|
||||
loaded_model.predict({
|
||||
'a': inputs_a,
|
||||
'b': inputs_b
|
||||
}, steps=1), batch_size)
|
||||
self.assertLen(
|
||||
loaded_model.predict({
|
||||
'a': inputs_a,
|
||||
'b': inputs_b
|
||||
}, steps=1), batch_size)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_saving_h5_for_rnn_layers(self):
|
||||
# See https://github.com/tensorflow/tensorflow/issues/35731 for details.
|
||||
inputs = keras.Input([10, 91], name='train_input')
|
||||
@ -213,7 +217,7 @@ class TestSaveModel(test.TestCase):
|
||||
rnn_layers[1].kernel.name)
|
||||
self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_saving_optimizer_weights(self):
|
||||
|
||||
class MyModel(keras.Model):
|
||||
|
@ -44,7 +44,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -700,7 +700,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.assertAllClose(model(input_arr), loaded(input_arr))
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase):
|
||||
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_functions_have_same_trace(self):
|
||||
|
||||
@ -773,7 +773,7 @@ class TestLayerCallTracing(test.TestCase):
|
||||
|
||||
assert_num_traces(LayerWithChildLayer, training_keyword=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_maintains_losses(self):
|
||||
layer = LayerWithLoss()
|
||||
layer(np.ones((2, 3)))
|
||||
@ -786,7 +786,7 @@ class TestLayerCallTracing(test.TestCase):
|
||||
self.assertAllEqual(previous_losses, layer.losses)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
@ -870,28 +870,30 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
# while returning nothing.
|
||||
super(CustomMetric, self).update_state(*args)
|
||||
|
||||
metric = CustomMetric()
|
||||
save_dir = self._save_model_dir('first_save')
|
||||
with self.cached_session():
|
||||
metric = CustomMetric()
|
||||
save_dir = self._save_model_dir('first_save')
|
||||
|
||||
if requires_build:
|
||||
metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable
|
||||
if requires_build:
|
||||
metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable
|
||||
|
||||
self.evaluate([v.initializer for v in metric.variables])
|
||||
self.evaluate([v.initializer for v in metric.variables])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'):
|
||||
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
|
||||
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
|
||||
loaded = self._test_metric_save_and_load(
|
||||
metric,
|
||||
save_dir,
|
||||
num_tensor_args,
|
||||
test_sample_weight=False)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Unable to restore custom object'):
|
||||
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
|
||||
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
|
||||
loaded = self._test_metric_save_and_load(
|
||||
metric,
|
||||
save_dir,
|
||||
num_tensor_args,
|
||||
test_sample_weight=False)
|
||||
|
||||
self._test_metric_save_and_load(
|
||||
loaded,
|
||||
self._save_model_dir('second_save'),
|
||||
num_tensor_args,
|
||||
test_sample_weight=False)
|
||||
self._test_metric_save_and_load(
|
||||
loaded,
|
||||
self._save_model_dir('second_save'),
|
||||
num_tensor_args,
|
||||
test_sample_weight=False)
|
||||
|
||||
def test_custom_metric_wrapped_call(self):
|
||||
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
@ -62,7 +63,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_trace_model_outputs(self):
|
||||
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
|
||||
model = testing_utils.get_small_mlp(10, 3, input_dim)
|
||||
@ -155,7 +156,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
|
||||
self._assert_all_close(expected_outputs, signature_outputs)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_trace_features_layer(self):
|
||||
columns = [feature_column_lib.numeric_column('x')]
|
||||
model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
|
||||
@ -176,7 +177,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose({'output_1': [[1., 2.]]},
|
||||
fn({'x': [[1.]], 'y': [[2.]]}))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_specify_input_signature(self):
|
||||
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
||||
inputs = array_ops.ones((8, 5))
|
||||
@ -193,7 +194,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
expected_outputs = {'output_1': model(inputs)}
|
||||
self._assert_all_close(expected_outputs, signature_outputs)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_subclassed_model_with_input_signature(self):
|
||||
|
||||
class Model(keras.Model):
|
||||
@ -218,7 +219,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
self._assert_all_close(expected_outputs, signature_outputs)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_model_with_fixed_input_dim(self):
|
||||
"""Ensure that the batch_dim is removed when saving.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user