Move keras-related saved_model tests to keras integration test and unit test.

The memory test couldn't be converted to an integration test since the annotation for eager garbage collection is not publicly visible.

PiperOrigin-RevId: 315702765
Change-Id: I4c8d54b074364d3884af64a2b4d00ad615ef319d
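
For context, the non-public annotation referenced above (`test_util.assert_no_garbage_created` in the unit test below) behaves roughly like the following sketch built only on Python's public gc module. This body is an illustration of the idea, not the TensorFlow-internal implementation:

    import functools
    import gc

    def assert_no_garbage_created(test_fn):
      """Illustrative stand-in for the non-public eager GC test annotation."""
      @functools.wraps(test_fn)
      def wrapper(self, *args, **kwargs):
        gc.disable()
        previous_flags = gc.get_debug()
        gc.set_debug(gc.DEBUG_SAVEALL)  # collected unreachable objects go to gc.garbage
        gc.collect()
        garbage_before = len(gc.garbage)
        result = test_fn(self, *args, **kwargs)
        gc.collect()
        # New entries in gc.garbage are objects kept alive only by reference
        # cycles, which is exactly what the memory test guards against.
        self.assertEqual(garbage_before, len(gc.garbage))
        gc.set_debug(previous_flags)
        gc.enable()
        return result
      return wrapper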
Scott Zhu 2020-06-10 09:17:33 -07:00 committed by TensorFlower Gardener
parent 9affecb6c0
commit f2306d9f25
5 changed files with 176 additions and 93 deletions


@@ -43,8 +43,8 @@ tf_py_test(
 )
 
 cuda_py_test(
-    name = "load_test",
-    srcs = ["load_test.py"],
+    name = "saved_model_test",
+    srcs = ["saved_model_test.py"],
     python_version = "PY3",
     deps = [
         "//tensorflow:tensorflow_py",


@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
+import tempfile
 
 from absl.testing import parameterized
@@ -40,6 +41,79 @@ def cycle(obj, cycles, signatures=None):
   return loaded
 
 
+class _ModelWithOptimizer(tf.train.Checkpoint):
+
+  def __init__(self):
+    super(_ModelWithOptimizer, self).__init__()
+    self.dense = tf.keras.layers.Dense(1)
+    self.optimizer = tf.keras.optimizers.Adam(0.01)
+
+  @tf.function(
+      input_signature=(tf.TensorSpec([None, 2], tf.float32),
+                       tf.TensorSpec([None], tf.float32)))
+  def call(self, x, y):
+    with tf.GradientTape() as tape:
+      loss = tf.math.reduce_mean((self.dense(x) - y) ** 2.)
+    trainable_variables = self.dense.trainable_variables
+    gradients = tape.gradient(loss, trainable_variables)
+    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
+    return {"loss": loss}
+
+
+def _import_and_infer(save_dir, inputs, signature_key="serving_default"):
+  """Import a SavedModel into a TF 1.x-style graph and run `signature_key`."""
+  graph = tf.Graph()
+  with graph.as_default(), tf.compat.v1.Session() as session:
+    model = tf.compat.v1.saved_model.load(session, ["serve"], save_dir)
+    return _run_signature(session, model, inputs, signature_key)
+
+
+def _run_signature(session, meta_graph_def, inputs, signature_key):
+  signature = meta_graph_def.signature_def[signature_key]
+  assert set(inputs.keys()) == set(signature.inputs.keys())
+  feed_dict = {}
+  for arg_name in inputs.keys():
+    input_tensor = session.graph.get_tensor_by_name(
+        signature.inputs[arg_name].name)
+    feed_dict[input_tensor] = inputs[arg_name]
+  output_dict = {}
+  for output_name, output_tensor_info in signature.outputs.items():
+    output_dict[output_name] = session.graph.get_tensor_by_name(
+        output_tensor_info.name)
+  return session.run(output_dict, feed_dict=feed_dict)
+
+
+class SaveTest(tf.test.TestCase):
+
+  def test_unbuilt_model_does_not_prevent_saving(self):
+    root = tf.train.Checkpoint(
+        model=tf.keras.Sequential([tf.keras.layers.Dense(2)]))
+    tf.saved_model.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
+
+  def test_optimizer(self):
+    x = tf.constant([[3., 4.]])
+    y = tf.constant([2.])
+    model = _ModelWithOptimizer()
+    first_loss = model.call(x, y)
+    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+    tf.saved_model.save(model, save_dir, model.call)
+    second_loss = model.call(x, y)
+    self.assertNotEqual(first_loss, second_loss)
+    self.assertAllClose(
+        second_loss,
+        _import_and_infer(save_dir, {"x": [[3., 4.]], "y": [2.]}))
+
+  def test_single_method_default_signature(self):
+    model = _ModelWithOptimizer()
+    x = tf.constant([[3., 4.]])
+    y = tf.constant([2.])
+    model.call(x, y)
+    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+    tf.saved_model.save(model, save_dir)
+    self.assertIn("loss",
+                  _import_and_infer(save_dir,
+                                    {"x": [[3., 4.]], "y": [2.]}))
+
+
 @parameterized.named_parameters(
     dict(testcase_name="ReloadOnce", cycles=1),
     dict(testcase_name="ReloadTwice", cycles=2),


@@ -277,6 +277,26 @@ cuda_py_test(
     ],
 )
 
+tf_py_test(
+    name = "saved_model_test",
+    size = "small",
+    srcs = ["saved_model_test.py"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:function",
+        "//tensorflow/python/keras/layers:core",
+        "//tensorflow/python/keras/optimizer_v2",
+        "//tensorflow/python/saved_model:save",
+        "//tensorflow/python/training/tracking:util",
+    ],
+)
+
 tf_py_test(
     name = "temporal_sample_weights_correctness_test",
     srcs = ["temporal_sample_weights_correctness_test.py"],


@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for trackable object SavedModel save."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import save
+from tensorflow.python.training.tracking import util
+
+
+class _ModelWithOptimizerUsingDefun(util.Checkpoint):
+
+  def __init__(self):
+    super(_ModelWithOptimizerUsingDefun, self).__init__()
+    self.dense = core.Dense(1)
+    self.optimizer = adam.Adam(0.01)
+
+  # Using defun due to control flow v2 cycles, b/121159261. def_function uses
+  # conds to gate variable initialization and so triggers cond reference cycles,
+  # but the thing being wrapped here does not use cond itself.
+  @function.defun(
+      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
+                       tensor_spec.TensorSpec([None], dtypes.float32)),
+  )
+  def call(self, x, y):
+    with backprop.GradientTape() as tape:
+      loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
+    trainable_variables = self.dense.trainable_variables
+    gradients = tape.gradient(loss, trainable_variables)
+    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
+    return {"loss": loss}
+
+
+class MemoryTests(test.TestCase):
+
+  def setUp(self):
+    super(MemoryTests, self).setUp()
+    self._model = _ModelWithOptimizerUsingDefun()
+
+  @test_util.assert_no_garbage_created
+  def test_no_reference_cycles(self):
+    x = constant_op.constant([[3., 4.]])
+    y = constant_op.constant([2.])
+    self._model.call(x, y)
+    if sys.version_info[0] < 3:
+      # TODO(allenl): debug reference cycles in Python 2.x
+      self.skipTest("This test only works in Python 3+. Reference cycles are "
+                    "created in older Python versions.")
+    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+    save.save(self._model, save_dir, self._model.call)
+
+
+if __name__ == "__main__":
+  test.main()
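
As a usage note (not part of the commit): a SavedModel written the way test_no_reference_cycles writes it can be read back through the public TF 2.x API roughly as follows, with `save_dir` standing for the directory used in the test above:

    import tensorflow as tf

    loaded = tf.saved_model.load(save_dir)  # `save_dir` from the test above
    serving_fn = loaded.signatures["serving_default"]
    # The signature returns the dict produced by `call`, i.e. {"loss": ...}.
    print(serving_fn(x=tf.constant([[3., 4.]]), y=tf.constant([2.])))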


@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import sys
 
 from google.protobuf import text_format
@@ -29,7 +28,6 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -38,9 +36,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
-from tensorflow.python.keras.engine import sequential
-from tensorflow.python.keras.layers import core
-from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
@@ -61,24 +56,6 @@ from tensorflow.python.training.tracking import util
 from tensorflow.python.util import compat
 
 
-class _ModelWithOptimizer(util.Checkpoint):
-
-  def __init__(self):
-    super(_ModelWithOptimizer, self).__init__()
-    self.dense = core.Dense(1)
-    self.optimizer = adam.Adam(0.01)
-
-  @def_function.function(
-      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
-                       tensor_spec.TensorSpec([None], dtypes.float32)))
-  def call(self, x, y):
-    with backprop.GradientTape() as tape:
-      loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
-    trainable_variables = self.dense.trainable_variables
-    gradients = tape.gradient(loss, trainable_variables)
-    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
-    return {"loss": loss}
-
-
 def _run_signature(session, meta_graph_def, inputs, signature_key):
   signature = meta_graph_def.signature_def[signature_key]
   assert set(inputs.keys()) == set(signature.inputs.keys())
@@ -186,10 +163,6 @@ class SaveTest(test.TestCase):
         _import_and_infer(
             save_dir, {"z": 1.}, signature_key="non_default_key"))
 
-  def test_unbuilt_model_does_not_prevent_saving(self):
-    root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
-    save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
-
   def test_unsaveable_func_graph(self):
     root = module.Module()
@@ -288,30 +261,6 @@ class SaveTest(test.TestCase):
     self.assertAllEqual({"output_0": 12.},
                         _import_and_infer(save_dir, {"x": 2.}))
 
-  def test_optimizer(self):
-    x = constant_op.constant([[3., 4.]])
-    y = constant_op.constant([2.])
-    model = _ModelWithOptimizer()
-    first_loss = model.call(x, y)
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(model, save_dir, model.call)
-    second_loss = model.call(x, y)
-    self.assertNotEqual(first_loss, second_loss)
-    self.assertAllClose(
-        second_loss,
-        _import_and_infer(save_dir, {"x": [[3., 4.]], "y": [2.]}))
-
-  def test_single_method_default_signature(self):
-    model = _ModelWithOptimizer()
-    x = constant_op.constant([[3., 4.]])
-    y = constant_op.constant([2.])
-    model.call(x, y)
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(model, save_dir)
-    self.assertIn("loss",
-                  _import_and_infer(save_dir,
-                                    {"x": [[3., 4.]], "y": [2.]}))
-
   def test_single_function_default_signature(self):
     model = tracking.AutoTrackable()
     model.f = def_function.function(lambda: 3., input_signature=())
@@ -677,46 +626,6 @@ class AssetTests(test.TestCase):
     _calls_save()
 
 
-class _ModelWithOptimizerUsingDefun(util.Checkpoint):
-
-  def __init__(self):
-    super(_ModelWithOptimizerUsingDefun, self).__init__()
-    self.dense = core.Dense(1)
-    self.optimizer = adam.Adam(0.01)
-
-  # Using defun due to control flow v2 cycles, b/121159261. def_function uses
-  # conds to gate variable initialization and so triggers cond reference cycles,
-  # but the thing being wrapped here does not use cond itself.
-  @function.defun(
-      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
-                       tensor_spec.TensorSpec([None], dtypes.float32)),
-  )
-  def call(self, x, y):
-    with backprop.GradientTape() as tape:
-      loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
-    trainable_variables = self.dense.trainable_variables
-    gradients = tape.gradient(loss, trainable_variables)
-    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
-    return {"loss": loss}
-
-
-class MemoryTests(test.TestCase):
-
-  def setUp(self):
-    self._model = _ModelWithOptimizerUsingDefun()
-
-  @test_util.assert_no_garbage_created
-  def test_no_reference_cycles(self):
-    x = constant_op.constant([[3., 4.]])
-    y = constant_op.constant([2.])
-    self._model.call(x, y)
-    if sys.version_info[0] < 3:
-      # TODO(allenl): debug reference cycles in Python 2.x
-      self.skipTest("This test only works in Python 3+. Reference cycles are "
-                    "created in older Python versions.")
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(self._model, save_dir, self._model.call)
-
-
 class ExportMetaGraphTests(test.TestCase):
 
   def test_export_meta_graph(self):