diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 10cf520f4e5..9c1cc090d7d 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -442,6 +442,7 @@ py_strict_library( ":loader", ":signature_serialization", "//tensorflow/python:array_ops", + "//tensorflow/python:composite_tensor", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -449,6 +450,7 @@ py_strict_library( "//tensorflow/python:platform", "//tensorflow/python:saver", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:lift_to_graph", "//tensorflow/python/eager:wrap_function", diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index add3b4e6320..a8627701bb8 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -23,6 +23,7 @@ import functools from tensorflow.python.eager import context from tensorflow.python.eager import lift_to_graph from tensorflow.python.eager import wrap_function +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph @@ -36,6 +37,7 @@ from tensorflow.python.saved_model import signature_serialization from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as tf_saver from tensorflow.python.training.tracking import tracking +from tensorflow.python.util import nest class _Initializer(tracking.CapturableResource): @@ -154,6 +156,11 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): dense_shape_name = "%s_dense_shape" % original_input_name input_names.extend([indices_name, values_name, dense_shape_name]) input_tensors.extend([feed.indices, feed.values, feed.dense_shape]) + elif isinstance(feed, composite_tensor.CompositeTensor): + component_tensors = nest.flatten(feed, expand_composites=True) + input_names.extend("%s_component_%d" % (original_input_name, n) + for n in range(len(component_tensors))) + input_tensors.extend(component_tensors) else: input_names.append(original_input_name) input_tensors.append(feed) diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 806a4db6fba..cab2c8bedb0 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import load from tensorflow.python.saved_model import save @@ -566,6 +567,25 @@ class LoadTest(test.TestCase): start_dense_shape=dense_shape) self.assertAllEqual([84, 86, 88], result["output"].values.numpy()) + def _model_with_ragged_input(self): + """Generate a graph with a RaggedTensor input and serialize in V1 format.""" + export_graph = ops.Graph() + with export_graph.as_default(): + x = ragged_factory_ops.placeholder(dtypes.float32, 1, []) + y = x * 2 + with session_lib.Session() as sess: + path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid())) + simple_save.simple_save(sess, path, inputs={"x": x}, outputs={"y": y}) + return path + + def test_load_ragged_inputs(self): + path = self._model_with_ragged_input() + imported = load.load(path) + imported_fn = imported.signatures["serving_default"] + x = ragged_factory_ops.constant([[10., 20.], [30.]]) + result = imported_fn(x_component_0=x.values, x_component_1=x.row_splits) + self.assertAllEqual(result["y"], [[20., 40.], [60.]]) + def _model_with_defun(self): """Generate a graph with a Defun and serialize in V1 format.""" export_graph = ops.Graph()