Remove run_deprecated_v1 qualifier from saved_model:simple_save_test.
PiperOrigin-RevId: 323613279 Change-Id: I96f174f589c203acb7303627a33131867d9ac5bb
This commit is contained in:
parent
482d273416
commit
111f48de6e
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader
|
||||
@ -32,7 +31,7 @@ from tensorflow.python.saved_model import tag_constants
|
||||
|
||||
class SimpleSaveTest(test.TestCase):
|
||||
|
||||
def _init_and_validate_variable(self, sess, variable_name, variable_value):
|
||||
def _init_and_validate_variable(self, variable_name, variable_value):
|
||||
v = variables.Variable(variable_value, name=variable_name)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(variable_value, self.evaluate(v))
|
||||
@ -54,50 +53,54 @@ class SimpleSaveTest(test.TestCase):
|
||||
self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size,
|
||||
expected_tensor.shape[i])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSimpleSave(self):
|
||||
"""Test simple_save that uses the default parameters."""
|
||||
export_dir = os.path.join(test.get_temp_dir(),
|
||||
"test_simple_save")
|
||||
|
||||
# Initialize input and output variables and save a prediction graph using
|
||||
# the default parameters.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
var_x = self._init_and_validate_variable(sess, "var_x", 1)
|
||||
var_y = self._init_and_validate_variable(sess, "var_y", 2)
|
||||
inputs = {"x": var_x}
|
||||
outputs = {"y": var_y}
|
||||
simple_save.simple_save(sess, export_dir, inputs, outputs)
|
||||
# Force the test to run in graph mode.
|
||||
# This tests a deprecated v1 API that both requires a session and uses
|
||||
# functionality that does not work with eager tensors (such as
|
||||
# build_tensor_info as called by predict_signature_def).
|
||||
with ops.Graph().as_default():
|
||||
# Initialize input and output variables and save a prediction graph using
|
||||
# the default parameters.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
var_x = self._init_and_validate_variable("var_x", 1)
|
||||
var_y = self._init_and_validate_variable("var_y", 2)
|
||||
inputs = {"x": var_x}
|
||||
outputs = {"y": var_y}
|
||||
simple_save.simple_save(sess, export_dir, inputs, outputs)
|
||||
|
||||
# Restore the graph with a valid tag and check the global variables and
|
||||
# signature def map.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
graph = loader.load(sess, [tag_constants.SERVING], export_dir)
|
||||
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
# Restore the graph with a valid tag and check the global variables and
|
||||
# signature def map.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
graph = loader.load(sess, [tag_constants.SERVING], export_dir)
|
||||
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
# Check value and metadata of the saved variables.
|
||||
self.assertEqual(len(collection_vars), 2)
|
||||
self.assertEqual(1, collection_vars[0].eval())
|
||||
self.assertEqual(2, collection_vars[1].eval())
|
||||
self._check_variable_info(collection_vars[0], var_x)
|
||||
self._check_variable_info(collection_vars[1], var_y)
|
||||
# Check value and metadata of the saved variables.
|
||||
self.assertEqual(len(collection_vars), 2)
|
||||
self.assertEqual(1, collection_vars[0].eval())
|
||||
self.assertEqual(2, collection_vars[1].eval())
|
||||
self._check_variable_info(collection_vars[0], var_x)
|
||||
self._check_variable_info(collection_vars[1], var_y)
|
||||
|
||||
# Check that the appropriate signature_def_map is created with the
|
||||
# default key and method name, and the specified inputs and outputs.
|
||||
signature_def_map = graph.signature_def
|
||||
self.assertEqual(1, len(signature_def_map))
|
||||
self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
|
||||
list(signature_def_map.keys())[0])
|
||||
# Check that the appropriate signature_def_map is created with the
|
||||
# default key and method name, and the specified inputs and outputs.
|
||||
signature_def_map = graph.signature_def
|
||||
self.assertEqual(1, len(signature_def_map))
|
||||
self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
|
||||
list(signature_def_map.keys())[0])
|
||||
|
||||
signature_def = signature_def_map[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
|
||||
signature_def.method_name)
|
||||
signature_def = signature_def_map[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
|
||||
signature_def.method_name)
|
||||
|
||||
self.assertEqual(1, len(signature_def.inputs))
|
||||
self._check_tensor_info(signature_def.inputs["x"], var_x)
|
||||
self.assertEqual(1, len(signature_def.outputs))
|
||||
self._check_tensor_info(signature_def.outputs["y"], var_y)
|
||||
self.assertEqual(1, len(signature_def.inputs))
|
||||
self._check_tensor_info(signature_def.inputs["x"], var_x)
|
||||
self.assertEqual(1, len(signature_def.outputs))
|
||||
self._check_tensor_info(signature_def.outputs["y"], var_y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user