Serialize concrete function signature using structured_input_signature instead of function inputs.

This way, if the user calls set_shape on the inputs within the function body, the serialized signature is not affected.

PiperOrigin-RevId: 338345710
Change-Id: Ie31b9e2206de57aca4e592bbd43fafbff0d2bda6
This commit is contained in:
Katherine Wu 2020-10-21 14:52:32 -07:00 committed by TensorFlower Gardener
parent 642c3e8498
commit 12d00c3e34
4 changed files with 51 additions and 6 deletions

View File

@ -1058,6 +1058,10 @@ class TrtGraphConverterV2(object):
[tensor.name for tensor in func.outputs])
rebuilt_func.graph.structured_outputs = nest.pack_sequence_as(
func.graph.structured_outputs, rebuilt_func.graph.structured_outputs)
# Copy structured input signature from original function (used during
# serialization)
rebuilt_func.graph.structured_input_signature = (
func.structured_input_signature)
return rebuilt_func
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
@ -1110,6 +1114,10 @@ class TrtGraphConverterV2(object):
self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
func.graph.structured_outputs,
self._converted_func.graph.structured_outputs)
# Copy structured input signature from original function (used during
# serialization)
self._converted_func.graph.structured_input_signature = (
func.structured_input_signature)
if self._need_calibration:
for inp in calibration_input_fn():
@ -1265,6 +1273,8 @@ class TrtGraphConverterV2(object):
reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
self._converted_func.graph.structured_outputs,
reset_converted_func.graph.structured_outputs)
reset_converted_func.graph.strucutred_input_signature = (
self._converted_func.structured_input_signature)
self._converted_func = reset_converted_func
signatures[self._input_saved_model_signature_key] = self._converted_func

View File

@ -869,6 +869,13 @@ class Functional(training_lib.Model):
def _trackable_saved_model_saver(self):
return network_serialization.NetworkSavedModelSaver(self)
def _get_save_spec(self, dynamic_batch=True):
if getattr(self, '_has_explicit_input_shape', True):
# Functional models and Sequential models that have an explicit input
# shape should use the batch size set by the input layer.
dynamic_batch = False
return super(Functional, self)._get_save_spec(dynamic_batch)
def _make_node_key(layer_name, node_index):
return layer_name + '_ib-' + str(node_index)

View File

@ -668,6 +668,23 @@ class SaveTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
loader.load(session, [tag_constants.SERVING], export_dir)
def test_concrete_function_with_set_shape(self,):
# Serialized concrete function should retain the shape from the TensorSpec,
# instead of using the shape of the inputs (which are changed by set_shape).
@def_function.function
def f(x):
x.set_shape((5, 1))
return x
root = tracking.AutoTrackable()
path = os.path.join(self.get_temp_dir(), "saved_model")
concrete = f.get_concrete_function(
tensor_spec.TensorSpec((None, 1), name="name"))
save.save(root, path, signatures={"key": concrete})
imported = load.load(path)
self.assertEqual(imported.signatures["key"].structured_input_signature[1],
{"name": tensor_spec.TensorSpec((None, 1), name="name")})
class VariablePolicyEnumTest(test.TestCase):

View File

@ -124,15 +124,26 @@ def canonicalize_signatures(signatures):
structured_outputs = signature_function(**kwargs)
return _normalize_outputs(
structured_outputs, signature_function.name, signature_key)
# TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
# always match keyword arguments.
tensor_spec_signature = {}
for keyword, tensor in zip(
if signature_function.structured_input_signature is not None:
# The structured input signature may contain other non-tensor arguments.
inputs = filter(
lambda x: isinstance(x, tensor_spec.TensorSpec),
nest.flatten(signature_function.structured_input_signature,
expand_composites=True))
else:
# Structured input signature isn't always defined for some functions.
inputs = signature_function.inputs
for keyword, inp in zip(
signature_function._arg_keywords, # pylint: disable=protected-access
signature_function.inputs):
inputs):
keyword = compat.as_str(keyword)
tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
tensor, name=keyword)
if isinstance(inp, tensor_spec.TensorSpec):
spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
else:
spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
tensor_spec_signature[keyword] = spec
final_concrete = signature_wrapper._get_concrete_function_garbage_collected( # pylint: disable=protected-access
**tensor_spec_signature)
# pylint: disable=protected-access