Reduce 1-Layer Functional.__call__ overhead by ~10%.

Moves Model._set_save_spec to Layer. This allows Layer.__call__ to avoid a
hasattr check and also Model.__call__ to avoid an expensive call to a method
wrapped in trackable.no_automatic_dependency_tracking.

This should also allow SavedModel to use this spec in place of
build_input_shape in the future.

PiperOrigin-RevId: 313795786
Change-Id: Id7b23f98911468ed3e11261ac60989685de47aa1
This commit is contained in:
Thomas O'Malley 2020-05-29 09:34:47 -07:00 committed by TensorFlower Gardener
parent 6c5999948c
commit 2244921925
2 changed files with 18 additions and 9 deletions

View File

@ -316,6 +316,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
# submitted.
self._build_input_shape = None
self._saved_model_inputs_spec = None
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
self.supports_masking = False
@ -1002,7 +1003,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
outputs = self.call(cast_inputs, *args, **kwargs)
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
if hasattr(self, '_set_save_spec'):
if self._saved_model_inputs_spec is None:
self._set_save_spec(cast_inputs)
return outputs
@ -2809,6 +2810,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# SavedModel properties. Please see keras/saving/saved_model for details.
@trackable.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
if self._saved_model_inputs_spec is not None:
return # Already set.
self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
inputs)
def _get_save_spec(self, dynamic_batch=True):
if self._saved_model_inputs_spec is None:
return None
return nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_inputs_spec)
@property
def _trackable_saved_model_saver(self):
return layer_serialization.LayerSavedModelSaver(self)

View File

@ -2366,14 +2366,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
self._saved_model_inputs_spec = specs
def _get_save_spec(self, dynamic_batch=True):
if self._saved_model_inputs_spec is None:
return None
return nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_inputs_spec)
def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created.