Serialize concrete function signature using structured_input_signature instead of function inputs.
This way, if the user calls set_shape on the inputs within the function body, the serialized signature is not affected. PiperOrigin-RevId: 338345710 Change-Id: Ie31b9e2206de57aca4e592bbd43fafbff0d2bda6
This commit is contained in:
parent
642c3e8498
commit
12d00c3e34
@ -1058,6 +1058,10 @@ class TrtGraphConverterV2(object):
|
||||
[tensor.name for tensor in func.outputs])
|
||||
rebuilt_func.graph.structured_outputs = nest.pack_sequence_as(
|
||||
func.graph.structured_outputs, rebuilt_func.graph.structured_outputs)
|
||||
# Copy structured input signature from original function (used during
|
||||
# serialization)
|
||||
rebuilt_func.graph.structured_input_signature = (
|
||||
func.structured_input_signature)
|
||||
return rebuilt_func
|
||||
|
||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||
@ -1110,6 +1114,10 @@ class TrtGraphConverterV2(object):
|
||||
self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
|
||||
func.graph.structured_outputs,
|
||||
self._converted_func.graph.structured_outputs)
|
||||
# Copy structured input signature from original function (used during
|
||||
# serialization)
|
||||
self._converted_func.graph.structured_input_signature = (
|
||||
func.structured_input_signature)
|
||||
|
||||
if self._need_calibration:
|
||||
for inp in calibration_input_fn():
|
||||
@ -1265,6 +1273,8 @@ class TrtGraphConverterV2(object):
|
||||
reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
|
||||
self._converted_func.graph.structured_outputs,
|
||||
reset_converted_func.graph.structured_outputs)
|
||||
reset_converted_func.graph.strucutred_input_signature = (
|
||||
self._converted_func.structured_input_signature)
|
||||
self._converted_func = reset_converted_func
|
||||
|
||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
||||
|
@ -869,6 +869,13 @@ class Functional(training_lib.Model):
|
||||
def _trackable_saved_model_saver(self):
|
||||
return network_serialization.NetworkSavedModelSaver(self)
|
||||
|
||||
def _get_save_spec(self, dynamic_batch=True):
|
||||
if getattr(self, '_has_explicit_input_shape', True):
|
||||
# Functional models and Sequential models that have an explicit input
|
||||
# shape should use the batch size set by the input layer.
|
||||
dynamic_batch = False
|
||||
return super(Functional, self)._get_save_spec(dynamic_batch)
|
||||
|
||||
|
||||
def _make_node_key(layer_name, node_index):
|
||||
return layer_name + '_ib-' + str(node_index)
|
||||
|
@ -668,6 +668,23 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
loader.load(session, [tag_constants.SERVING], export_dir)
|
||||
|
||||
def test_concrete_function_with_set_shape(self,):
|
||||
# Serialized concrete function should retain the shape from the TensorSpec,
|
||||
# instead of using the shape of the inputs (which are changed by set_shape).
|
||||
@def_function.function
|
||||
def f(x):
|
||||
x.set_shape((5, 1))
|
||||
return x
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
path = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
concrete = f.get_concrete_function(
|
||||
tensor_spec.TensorSpec((None, 1), name="name"))
|
||||
save.save(root, path, signatures={"key": concrete})
|
||||
imported = load.load(path)
|
||||
self.assertEqual(imported.signatures["key"].structured_input_signature[1],
|
||||
{"name": tensor_spec.TensorSpec((None, 1), name="name")})
|
||||
|
||||
|
||||
class VariablePolicyEnumTest(test.TestCase):
|
||||
|
||||
|
@ -124,15 +124,26 @@ def canonicalize_signatures(signatures):
|
||||
structured_outputs = signature_function(**kwargs)
|
||||
return _normalize_outputs(
|
||||
structured_outputs, signature_function.name, signature_key)
|
||||
# TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
|
||||
# always match keyword arguments.
|
||||
tensor_spec_signature = {}
|
||||
for keyword, tensor in zip(
|
||||
if signature_function.structured_input_signature is not None:
|
||||
# The structured input signature may contain other non-tensor arguments.
|
||||
inputs = filter(
|
||||
lambda x: isinstance(x, tensor_spec.TensorSpec),
|
||||
nest.flatten(signature_function.structured_input_signature,
|
||||
expand_composites=True))
|
||||
else:
|
||||
# Structured input signature isn't always defined for some functions.
|
||||
inputs = signature_function.inputs
|
||||
|
||||
for keyword, inp in zip(
|
||||
signature_function._arg_keywords, # pylint: disable=protected-access
|
||||
signature_function.inputs):
|
||||
inputs):
|
||||
keyword = compat.as_str(keyword)
|
||||
tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
|
||||
tensor, name=keyword)
|
||||
if isinstance(inp, tensor_spec.TensorSpec):
|
||||
spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
|
||||
else:
|
||||
spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
|
||||
tensor_spec_signature[keyword] = spec
|
||||
final_concrete = signature_wrapper._get_concrete_function_garbage_collected( # pylint: disable=protected-access
|
||||
**tensor_spec_signature)
|
||||
# pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user