diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index b47012469e0..8ea5c96f4cc 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -1058,6 +1058,10 @@ class TrtGraphConverterV2(object): [tensor.name for tensor in func.outputs]) rebuilt_func.graph.structured_outputs = nest.pack_sequence_as( func.graph.structured_outputs, rebuilt_func.graph.structured_outputs) + # Copy structured input signature from original function (used during + # serialization) + rebuilt_func.graph.structured_input_signature = ( + func.structured_input_signature) return rebuilt_func # TODO(laigd): provide a utility function to optimize a ConcreteFunction and @@ -1110,6 +1114,10 @@ class TrtGraphConverterV2(object): self._converted_func.graph.structured_outputs = nest.pack_sequence_as( func.graph.structured_outputs, self._converted_func.graph.structured_outputs) + # Copy structured input signature from original function (used during + # serialization) + self._converted_func.graph.structured_input_signature = ( + func.structured_input_signature) if self._need_calibration: for inp in calibration_input_fn(): @@ -1265,6 +1273,8 @@ class TrtGraphConverterV2(object): reset_converted_func.graph.structured_outputs = nest.pack_sequence_as( self._converted_func.graph.structured_outputs, reset_converted_func.graph.structured_outputs) + reset_converted_func.graph.strucutred_input_signature = ( + self._converted_func.structured_input_signature) self._converted_func = reset_converted_func signatures[self._input_saved_model_signature_key] = self._converted_func diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 892773fa656..57ae9ee92a7 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -869,6 +869,13 @@ class Functional(training_lib.Model): def _trackable_saved_model_saver(self): return network_serialization.NetworkSavedModelSaver(self) + def _get_save_spec(self, dynamic_batch=True): + if getattr(self, '_has_explicit_input_shape', True): + # Functional models and Sequential models that have an explicit input + # shape should use the batch size set by the input layer. + dynamic_batch = False + return super(Functional, self)._get_save_spec(dynamic_batch) + def _make_node_key(layer_name, node_index): return layer_name + '_ib-' + str(node_index) diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index cf25bed9d9d..3de9d70e9dc 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -668,6 +668,23 @@ class SaveTest(test.TestCase, parameterized.TestCase): with self.assertRaises(ValueError): loader.load(session, [tag_constants.SERVING], export_dir) + def test_concrete_function_with_set_shape(self,): + # Serialized concrete function should retain the shape from the TensorSpec, + # instead of using the shape of the inputs (which are changed by set_shape). + @def_function.function + def f(x): + x.set_shape((5, 1)) + return x + + root = tracking.AutoTrackable() + path = os.path.join(self.get_temp_dir(), "saved_model") + concrete = f.get_concrete_function( + tensor_spec.TensorSpec((None, 1), name="name")) + save.save(root, path, signatures={"key": concrete}) + imported = load.load(path) + self.assertEqual(imported.signatures["key"].structured_input_signature[1], + {"name": tensor_spec.TensorSpec((None, 1), name="name")}) + class VariablePolicyEnumTest(test.TestCase): diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 74f76c690f2..14f4df81380 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -124,15 +124,26 @@ def canonicalize_signatures(signatures): structured_outputs = signature_function(**kwargs) return _normalize_outputs( structured_outputs, signature_function.name, signature_key) - # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names - # always match keyword arguments. tensor_spec_signature = {} - for keyword, tensor in zip( + if signature_function.structured_input_signature is not None: + # The structured input signature may contain other non-tensor arguments. + inputs = filter( + lambda x: isinstance(x, tensor_spec.TensorSpec), + nest.flatten(signature_function.structured_input_signature, + expand_composites=True)) + else: + # Structured input signature isn't always defined for some functions. + inputs = signature_function.inputs + + for keyword, inp in zip( signature_function._arg_keywords, # pylint: disable=protected-access - signature_function.inputs): + inputs): keyword = compat.as_str(keyword) - tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor( - tensor, name=keyword) + if isinstance(inp, tensor_spec.TensorSpec): + spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword) + else: + spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword) + tensor_spec_signature[keyword] = spec final_concrete = signature_wrapper._get_concrete_function_garbage_collected( # pylint: disable=protected-access **tensor_spec_signature) # pylint: disable=protected-access