Update load_v1_in_v2 to handle composite tensors.
PiperOrigin-RevId: 332868904 Change-Id: I64669ec391c1d6601d0f8212acb65f2ec2edcad6
This commit is contained in:
parent
2321f691fd
commit
ae1d38052c
@ -442,6 +442,7 @@ py_strict_library(
|
||||
":loader",
|
||||
":signature_serialization",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
@ -449,6 +450,7 @@ py_strict_library(
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:lift_to_graph",
|
||||
"//tensorflow/python/eager:wrap_function",
|
||||
|
@ -23,6 +23,7 @@ import functools
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.eager import wrap_function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
@ -36,6 +37,7 @@ from tensorflow.python.saved_model import signature_serialization
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class _Initializer(tracking.CapturableResource):
|
||||
@ -154,6 +156,11 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
dense_shape_name = "%s_dense_shape" % original_input_name
|
||||
input_names.extend([indices_name, values_name, dense_shape_name])
|
||||
input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
|
||||
elif isinstance(feed, composite_tensor.CompositeTensor):
|
||||
component_tensors = nest.flatten(feed, expand_composites=True)
|
||||
input_names.extend("%s_component_%d" % (original_input_name, n)
|
||||
for n in range(len(component_tensors)))
|
||||
input_tensors.extend(component_tensors)
|
||||
else:
|
||||
input_names.append(original_input_name)
|
||||
input_tensors.append(feed)
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.saved_model import builder_impl
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
@ -566,6 +567,25 @@ class LoadTest(test.TestCase):
|
||||
start_dense_shape=dense_shape)
|
||||
self.assertAllEqual([84, 86, 88], result["output"].values.numpy())
|
||||
|
||||
def _model_with_ragged_input(self):
|
||||
"""Generate a graph with a RaggedTensor input and serialize in V1 format."""
|
||||
export_graph = ops.Graph()
|
||||
with export_graph.as_default():
|
||||
x = ragged_factory_ops.placeholder(dtypes.float32, 1, [])
|
||||
y = x * 2
|
||||
with session_lib.Session() as sess:
|
||||
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
|
||||
simple_save.simple_save(sess, path, inputs={"x": x}, outputs={"y": y})
|
||||
return path
|
||||
|
||||
def test_load_ragged_inputs(self):
|
||||
path = self._model_with_ragged_input()
|
||||
imported = load.load(path)
|
||||
imported_fn = imported.signatures["serving_default"]
|
||||
x = ragged_factory_ops.constant([[10., 20.], [30.]])
|
||||
result = imported_fn(x_component_0=x.values, x_component_1=x.row_splits)
|
||||
self.assertAllEqual(result["y"], [[20., 40.], [60.]])
|
||||
|
||||
def _model_with_defun(self):
|
||||
"""Generate a graph with a Defun and serialize in V1 format."""
|
||||
export_graph = ops.Graph()
|
||||
|
Loading…
Reference in New Issue
Block a user