Update tests under keras.saving to use combinations.
Change all test_util.run_all_in_graph_and_eager_modes to combination. PiperOrigin-RevId: 301396996 Change-Id: I1d79695f819bb289a428b3fd97965841a873bda9
This commit is contained in:
parent
8d01f78f82
commit
f7896058b2
@ -92,6 +92,7 @@ tf_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
|
"//tensorflow/python/keras:combinations",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
@ -106,6 +107,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/feature_column:feature_column_v2",
|
"//tensorflow/python/feature_column:feature_column_v2",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
|
"//tensorflow/python/keras:combinations",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
@ -142,6 +144,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/distribute:mirrored_strategy",
|
"//tensorflow/python/distribute:mirrored_strategy",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
|
"//tensorflow/python/keras:combinations",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
@ -156,6 +159,7 @@ tf_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
|
"//tensorflow/python/keras:combinations",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
|
@ -30,7 +30,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import optimizers
|
from tensorflow.python.keras import optimizers
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
@ -51,10 +51,10 @@ except ImportError:
|
|||||||
h5py = None
|
h5py = None
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_saved_model_formats
|
@keras_parameterized.run_with_all_saved_model_formats
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_weight_loading(self):
|
def test_weight_loading(self):
|
||||||
temp_dir = self.get_temp_dir()
|
temp_dir = self.get_temp_dir()
|
||||||
self.addCleanup(shutil.rmtree, temp_dir)
|
self.addCleanup(shutil.rmtree, temp_dir)
|
||||||
@ -83,7 +83,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||||||
y = model.predict(x)
|
y = model.predict(x)
|
||||||
self.assertAllClose(ref_y, y)
|
self.assertAllClose(ref_y, y)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_weight_preprocessing(self):
|
def test_weight_preprocessing(self):
|
||||||
input_dim = 3
|
input_dim = 3
|
||||||
output_dim = 3
|
output_dim = 3
|
||||||
@ -210,7 +209,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||||||
for (x, y) in zip(weights1, weights2)
|
for (x, y) in zip(weights1, weights2)
|
||||||
]
|
]
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_sequential_weight_loading(self):
|
def test_sequential_weight_loading(self):
|
||||||
if h5py is None:
|
if h5py is None:
|
||||||
return
|
return
|
||||||
@ -243,7 +241,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(y, ref_y)
|
self.assertAllClose(y, ref_y)
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_saved_model_formats
|
@keras_parameterized.run_with_all_saved_model_formats
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_nested_model_weight_loading(self):
|
def test_nested_model_weight_loading(self):
|
||||||
save_format = testing_utils.get_save_format()
|
save_format = testing_utils.get_save_format()
|
||||||
temp_dir = self.get_temp_dir()
|
temp_dir = self.get_temp_dir()
|
||||||
@ -282,7 +279,6 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(y, ref_y)
|
self.assertAllClose(y, ref_y)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_sequential_weight_loading_group_name_with_incorrect_length(self):
|
def test_sequential_weight_loading_group_name_with_incorrect_length(self):
|
||||||
if h5py is None:
|
if h5py is None:
|
||||||
return
|
return
|
||||||
@ -779,7 +775,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
|
h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_functional_model_with_custom_loss_and_metric(self):
|
def test_functional_model_with_custom_loss_and_metric(self):
|
||||||
def _make_model():
|
def _make_model():
|
||||||
inputs = keras.Input(shape=(4,))
|
inputs = keras.Input(shape=(4,))
|
||||||
@ -818,7 +814,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||||||
evaluation_results['sparse_categorical_crossentropy'] +
|
evaluation_results['sparse_categorical_crossentropy'] +
|
||||||
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
|
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_save_uncompiled_model_with_optimizer(self):
|
def test_save_uncompiled_model_with_optimizer(self):
|
||||||
with self.cached_session() as session:
|
with self.cached_session() as session:
|
||||||
saved_model_dir = self._save_model_dir()
|
saved_model_dir = self._save_model_dir()
|
||||||
@ -901,6 +897,7 @@ class _make_subclassed_built(_make_subclassed): # pylint: disable=invalid-name
|
|||||||
self.build((None, input_size))
|
self.build((None, input_size))
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
||||||
"""Tests saving a whole model that contains other models."""
|
"""Tests saving a whole model that contains other models."""
|
||||||
|
|
||||||
@ -913,7 +910,6 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
|||||||
('subclassed', _make_subclassed),
|
('subclassed', _make_subclassed),
|
||||||
('subclassed_built', _make_subclassed_built),
|
('subclassed_built', _make_subclassed_built),
|
||||||
])
|
])
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_functional(self, model_fn):
|
def test_functional(self, model_fn):
|
||||||
"""Tests serializing a model that uses a nested model to share weights."""
|
"""Tests serializing a model that uses a nested model to share weights."""
|
||||||
if h5py is None:
|
if h5py is None:
|
||||||
@ -926,6 +922,7 @@ class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
|
|||||||
outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])])
|
outputs = keras.layers.add([base_model(inputs[0]), base_model(inputs[1])])
|
||||||
return keras.Model(inputs=inputs, outputs=outputs)
|
return keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
x = (np.random.normal(size=(16, 4)).astype(np.float32),
|
x = (np.random.normal(size=(16, 4)).astype(np.float32),
|
||||||
np.random.normal(size=(16, 4)).astype(np.float32))
|
np.random.normal(size=(16, 4)).astype(np.float32))
|
||||||
model = _make_model()
|
model = _make_model()
|
||||||
@ -955,7 +952,7 @@ class SubclassedModel(training.Model):
|
|||||||
return self.b_layer(self.x_layer(a))
|
return self.b_layer(self.x_layer(a))
|
||||||
|
|
||||||
|
|
||||||
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_keras_optimizer_warning(self):
|
def test_keras_optimizer_warning(self):
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
@ -974,7 +971,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
str(mock_log.call_args),
|
str(mock_log.call_args),
|
||||||
'Keras optimizer')
|
'Keras optimizer')
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_tensorflow_format_overwrite(self):
|
def test_tensorflow_format_overwrite(self):
|
||||||
with self.cached_session() as session:
|
with self.cached_session() as session:
|
||||||
model = SubclassedModel()
|
model = SubclassedModel()
|
||||||
@ -1025,12 +1022,12 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
model.save_weights(prefix, save_format='tensorflow')
|
model.save_weights(prefix, save_format='tensorflow')
|
||||||
op_count = len(graph.get_operations())
|
op_count = len(graph.get_operations())
|
||||||
model.save_weights(prefix, save_format='tensorflow')
|
model.save_weights(prefix, save_format='tensorflow')
|
||||||
self.assertEqual(len(graph.get_operations()), op_count)
|
self.assertLen(graph.get_operations(), op_count)
|
||||||
|
|
||||||
model.load_weights(prefix)
|
model.load_weights(prefix)
|
||||||
op_count = len(graph.get_operations())
|
op_count = len(graph.get_operations())
|
||||||
model.load_weights(prefix)
|
model.load_weights(prefix)
|
||||||
self.assertEqual(len(graph.get_operations()), op_count)
|
self.assertLen(graph.get_operations(), op_count)
|
||||||
|
|
||||||
def _weight_loading_test_template(self, make_model_fn):
|
def _weight_loading_test_template(self, make_model_fn):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -1079,7 +1076,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
load_model.train_on_batch(train_x, train_y)
|
load_model.train_on_batch(train_x, train_y)
|
||||||
self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
|
self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_weight_loading_graph_model(self):
|
def test_weight_loading_graph_model(self):
|
||||||
def _make_graph_model():
|
def _make_graph_model():
|
||||||
a = keras.layers.Input(shape=(2,))
|
a = keras.layers.Input(shape=(2,))
|
||||||
@ -1089,7 +1086,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
|
|
||||||
self._weight_loading_test_template(_make_graph_model)
|
self._weight_loading_test_template(_make_graph_model)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_weight_loading_subclassed_model(self):
|
def test_weight_loading_subclassed_model(self):
|
||||||
self._weight_loading_test_template(SubclassedModel)
|
self._weight_loading_test_template(SubclassedModel)
|
||||||
|
|
||||||
@ -1127,7 +1124,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
y = self.evaluate(model(x))
|
y = self.evaluate(model(x))
|
||||||
self.assertAllClose(ref_y, y)
|
self.assertAllClose(ref_y, y)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_weight_loading_graph_model_added_layer(self):
|
def test_weight_loading_graph_model_added_layer(self):
|
||||||
def _save_graph_model():
|
def _save_graph_model():
|
||||||
a = keras.layers.Input(shape=(2,))
|
a = keras.layers.Input(shape=(2,))
|
||||||
@ -1144,7 +1141,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
self._new_layer_weight_loading_test_template(
|
self._new_layer_weight_loading_test_template(
|
||||||
_save_graph_model, _restore_graph_model)
|
_save_graph_model, _restore_graph_model)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_weight_loading_graph_model_added_no_weight_layer(self):
|
def test_weight_loading_graph_model_added_no_weight_layer(self):
|
||||||
def _save_graph_model():
|
def _save_graph_model():
|
||||||
a = keras.layers.Input(shape=(2,))
|
a = keras.layers.Input(shape=(2,))
|
||||||
@ -1161,7 +1158,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
self._new_layer_weight_loading_test_template(
|
self._new_layer_weight_loading_test_template(
|
||||||
_save_graph_model, _restore_graph_model)
|
_save_graph_model, _restore_graph_model)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_weight_loading_subclassed_model_added_layer(self):
|
def test_weight_loading_subclassed_model_added_layer(self):
|
||||||
|
|
||||||
class SubclassedModelRestore(training.Model):
|
class SubclassedModelRestore(training.Model):
|
||||||
@ -1178,7 +1175,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
self._new_layer_weight_loading_test_template(
|
self._new_layer_weight_loading_test_template(
|
||||||
SubclassedModel, SubclassedModelRestore)
|
SubclassedModel, SubclassedModelRestore)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_incompatible_checkpoint(self):
|
def test_incompatible_checkpoint(self):
|
||||||
save_path = trackable.Checkpoint().save(
|
save_path = trackable.Checkpoint().save(
|
||||||
os.path.join(self.get_temp_dir(), 'ckpt'))
|
os.path.join(self.get_temp_dir(), 'ckpt'))
|
||||||
@ -1191,19 +1188,22 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
AssertionError, 'Nothing except the root object matched'):
|
AssertionError, 'Nothing except the root object matched'):
|
||||||
m.load_weights(save_path)
|
m.load_weights(save_path)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_directory_passed(self):
|
def test_directory_passed(self):
|
||||||
|
with self.cached_session():
|
||||||
m = keras.Model()
|
m = keras.Model()
|
||||||
v = m.add_weight(name='v', shape=[])
|
v = m.add_weight(name='v', shape=[])
|
||||||
self.evaluate(v.assign(42.))
|
self.evaluate(v.assign(42.))
|
||||||
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/')
|
prefix = os.path.join(self.get_temp_dir(),
|
||||||
|
'{}'.format(ops.uid()), 'ckpt/')
|
||||||
m.save_weights(prefix)
|
m.save_weights(prefix)
|
||||||
self.evaluate(v.assign(2.))
|
self.evaluate(v.assign(2.))
|
||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
self.assertEqual(42., self.evaluate(v))
|
self.assertEqual(42., self.evaluate(v))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_relative_path(self):
|
def test_relative_path(self):
|
||||||
|
with self.cached_session():
|
||||||
m = keras.Model()
|
m = keras.Model()
|
||||||
v = m.add_weight(name='v', shape=[])
|
v = m.add_weight(name='v', shape=[])
|
||||||
os.chdir(self.get_temp_dir())
|
os.chdir(self.get_temp_dir())
|
||||||
@ -1232,12 +1232,14 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
|
|||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
self.assertEqual(44., self.evaluate(v))
|
self.assertEqual(44., self.evaluate(v))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_nonexistent_prefix_directory(self):
|
def test_nonexistent_prefix_directory(self):
|
||||||
|
with self.cached_session():
|
||||||
m = keras.Model()
|
m = keras.Model()
|
||||||
v = m.add_weight(name='v', shape=[])
|
v = m.add_weight(name='v', shape=[])
|
||||||
self.evaluate(v.assign(42.))
|
self.evaluate(v.assign(42.))
|
||||||
prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt')
|
prefix = os.path.join(self.get_temp_dir(),
|
||||||
|
'{}'.format(ops.uid()), 'bckpt')
|
||||||
m.save_weights(prefix)
|
m.save_weights(prefix)
|
||||||
self.evaluate(v.assign(2.))
|
self.evaluate(v.assign(2.))
|
||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
@ -28,6 +29,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.feature_column import feature_column_lib
|
from tensorflow.python.feature_column import feature_column_lib
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.saving import model_config
|
from tensorflow.python.keras.saving import model_config
|
||||||
from tensorflow.python.keras.saving import save
|
from tensorflow.python.keras.saving import save
|
||||||
@ -43,7 +45,7 @@ except ImportError:
|
|||||||
h5py = None
|
h5py = None
|
||||||
|
|
||||||
|
|
||||||
class TestSaveModel(test.TestCase):
|
class TestSaveModel(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestSaveModel, self).setUp()
|
super(TestSaveModel, self).setUp()
|
||||||
@ -99,7 +101,7 @@ class TestSaveModel(test.TestCase):
|
|||||||
save.save_model(self.model, path, save_format='tf')
|
save.save_model(self.model, path, save_format='tf')
|
||||||
save.load_model(path)
|
save.load_model(path)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_saving_with_dense_features(self):
|
def test_saving_with_dense_features(self):
|
||||||
cols = [
|
cols = [
|
||||||
feature_column_lib.numeric_column('a'),
|
feature_column_lib.numeric_column('a'),
|
||||||
@ -128,13 +130,14 @@ class TestSaveModel(test.TestCase):
|
|||||||
inputs_a = np.arange(10).reshape(10, 1)
|
inputs_a = np.arange(10).reshape(10, 1)
|
||||||
inputs_b = np.arange(10).reshape(10, 1).astype('str')
|
inputs_b = np.arange(10).reshape(10, 1).astype('str')
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
# Initialize tables for V1 lookup.
|
# Initialize tables for V1 lookup.
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.evaluate(lookup_ops.tables_initializer())
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
|
|
||||||
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
|
self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_saving_with_sequence_features(self):
|
def test_saving_with_sequence_features(self):
|
||||||
cols = [
|
cols = [
|
||||||
feature_column_lib.sequence_numeric_column('a'),
|
feature_column_lib.sequence_numeric_column('a'),
|
||||||
@ -182,6 +185,7 @@ class TestSaveModel(test.TestCase):
|
|||||||
inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
|
inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
|
||||||
(batch_size, timesteps, 1))
|
(batch_size, timesteps, 1))
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
# Initialize tables for V1 lookup.
|
# Initialize tables for V1 lookup.
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.evaluate(lookup_ops.tables_initializer())
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
@ -192,7 +196,7 @@ class TestSaveModel(test.TestCase):
|
|||||||
'b': inputs_b
|
'b': inputs_b
|
||||||
}, steps=1), batch_size)
|
}, steps=1), batch_size)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_saving_h5_for_rnn_layers(self):
|
def test_saving_h5_for_rnn_layers(self):
|
||||||
# See https://github.com/tensorflow/tensorflow/issues/35731 for details.
|
# See https://github.com/tensorflow/tensorflow/issues/35731 for details.
|
||||||
inputs = keras.Input([10, 91], name='train_input')
|
inputs = keras.Input([10, 91], name='train_input')
|
||||||
@ -213,7 +217,7 @@ class TestSaveModel(test.TestCase):
|
|||||||
rnn_layers[1].kernel.name)
|
rnn_layers[1].kernel.name)
|
||||||
self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)
|
self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_saving_optimizer_weights(self):
|
def test_saving_optimizer_weights(self):
|
||||||
|
|
||||||
class MyModel(keras.Model):
|
class MyModel(keras.Model):
|
||||||
|
@ -44,7 +44,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import regularizers
|
from tensorflow.python.keras import regularizers
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
@ -700,7 +700,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
|||||||
self.assertAllClose(model(input_arr), loaded(input_arr))
|
self.assertAllClose(model(input_arr), loaded(input_arr))
|
||||||
|
|
||||||
|
|
||||||
class TestLayerCallTracing(test.TestCase):
|
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_functions_have_same_trace(self):
|
def test_functions_have_same_trace(self):
|
||||||
|
|
||||||
@ -773,7 +773,7 @@ class TestLayerCallTracing(test.TestCase):
|
|||||||
|
|
||||||
assert_num_traces(LayerWithChildLayer, training_keyword=False)
|
assert_num_traces(LayerWithChildLayer, training_keyword=False)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_maintains_losses(self):
|
def test_maintains_losses(self):
|
||||||
layer = LayerWithLoss()
|
layer = LayerWithLoss()
|
||||||
layer(np.ones((2, 3)))
|
layer(np.ones((2, 3)))
|
||||||
@ -786,7 +786,7 @@ class TestLayerCallTracing(test.TestCase):
|
|||||||
self.assertAllEqual(previous_losses, layer.losses)
|
self.assertAllEqual(previous_losses, layer.losses)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class MetricTest(test.TestCase, parameterized.TestCase):
|
class MetricTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def _save_model_dir(self, dirname='saved_model'):
|
def _save_model_dir(self, dirname='saved_model'):
|
||||||
@ -870,6 +870,7 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
|||||||
# while returning nothing.
|
# while returning nothing.
|
||||||
super(CustomMetric, self).update_state(*args)
|
super(CustomMetric, self).update_state(*args)
|
||||||
|
|
||||||
|
with self.cached_session():
|
||||||
metric = CustomMetric()
|
metric = CustomMetric()
|
||||||
save_dir = self._save_model_dir('first_save')
|
save_dir = self._save_model_dir('first_save')
|
||||||
|
|
||||||
@ -878,7 +879,8 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.evaluate([v.initializer for v in metric.variables])
|
self.evaluate([v.initializer for v in metric.variables])
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'):
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
'Unable to restore custom object'):
|
||||||
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
|
self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
|
||||||
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
|
with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
|
||||||
loaded = self._test_metric_save_and_load(
|
loaded = self._test_metric_save_and_load(
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
@ -62,7 +63,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllClose(expected, actual)
|
self.assertAllClose(expected, actual)
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_trace_model_outputs(self):
|
def test_trace_model_outputs(self):
|
||||||
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
|
input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
|
||||||
model = testing_utils.get_small_mlp(10, 3, input_dim)
|
model = testing_utils.get_small_mlp(10, 3, input_dim)
|
||||||
@ -155,7 +156,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
|
expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
|
||||||
self._assert_all_close(expected_outputs, signature_outputs)
|
self._assert_all_close(expected_outputs, signature_outputs)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_trace_features_layer(self):
|
def test_trace_features_layer(self):
|
||||||
columns = [feature_column_lib.numeric_column('x')]
|
columns = [feature_column_lib.numeric_column('x')]
|
||||||
model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
|
model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)])
|
||||||
@ -176,7 +177,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllClose({'output_1': [[1., 2.]]},
|
self.assertAllClose({'output_1': [[1., 2.]]},
|
||||||
fn({'x': [[1.]], 'y': [[2.]]}))
|
fn({'x': [[1.]], 'y': [[2.]]}))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_specify_input_signature(self):
|
def test_specify_input_signature(self):
|
||||||
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
||||||
inputs = array_ops.ones((8, 5))
|
inputs = array_ops.ones((8, 5))
|
||||||
@ -193,7 +194,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
expected_outputs = {'output_1': model(inputs)}
|
expected_outputs = {'output_1': model(inputs)}
|
||||||
self._assert_all_close(expected_outputs, signature_outputs)
|
self._assert_all_close(expected_outputs, signature_outputs)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
def test_subclassed_model_with_input_signature(self):
|
def test_subclassed_model_with_input_signature(self):
|
||||||
|
|
||||||
class Model(keras.Model):
|
class Model(keras.Model):
|
||||||
@ -218,7 +219,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
self._assert_all_close(expected_outputs, signature_outputs)
|
self._assert_all_close(expected_outputs, signature_outputs)
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_model_with_fixed_input_dim(self):
|
def test_model_with_fixed_input_dim(self):
|
||||||
"""Ensure that the batch_dim is removed when saving.
|
"""Ensure that the batch_dim is removed when saving.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user