Update load_v1_in_v2 to handle composite tensors.

PiperOrigin-RevId: 332868904
Change-Id: I64669ec391c1d6601d0f8212acb65f2ec2edcad6
This commit is contained in:
Edward Loper 2020-09-21 09:40:48 -07:00 committed by TensorFlower Gardener
parent 2321f691fd
commit ae1d38052c
3 changed files with 29 additions and 0 deletions

View File

@ -442,6 +442,7 @@ py_strict_library(
":loader",
":signature_serialization",
"//tensorflow/python:array_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
@ -449,6 +450,7 @@ py_strict_library(
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:lift_to_graph",
"//tensorflow/python/eager:wrap_function",

View File

@ -23,6 +23,7 @@ import functools
from tensorflow.python.eager import context
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
@ -36,6 +37,7 @@ from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
class _Initializer(tracking.CapturableResource):
@ -154,6 +156,11 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
dense_shape_name = "%s_dense_shape" % original_input_name
input_names.extend([indices_name, values_name, dense_shape_name])
input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
elif isinstance(feed, composite_tensor.CompositeTensor):
component_tensors = nest.flatten(feed, expand_composites=True)
input_names.extend("%s_component_%d" % (original_input_name, n)
for n in range(len(component_tensors)))
input_tensors.extend(component_tensors)
else:
input_names.append(original_input_name)
input_tensors.append(feed)

View File

@ -45,6 +45,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
@ -566,6 +567,25 @@ class LoadTest(test.TestCase):
start_dense_shape=dense_shape)
self.assertAllEqual([84, 86, 88], result["output"].values.numpy())
def _model_with_ragged_input(self):
"""Generate a graph with a RaggedTensor input and serialize in V1 format."""
export_graph = ops.Graph()
with export_graph.as_default():
x = ragged_factory_ops.placeholder(dtypes.float32, 1, [])
y = x * 2
with session_lib.Session() as sess:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(sess, path, inputs={"x": x}, outputs={"y": y})
return path
def test_load_ragged_inputs(self):
path = self._model_with_ragged_input()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
x = ragged_factory_ops.constant([[10., 20.], [30.]])
result = imported_fn(x_component_0=x.values, x_component_1=x.row_splits)
self.assertAllEqual(result["y"], [[20., 40.], [60.]])
def _model_with_defun(self):
"""Generate a graph with a Defun and serialize in V1 format."""
export_graph = ops.Graph()