Reduce 1-Layer Functional.__call__ overhead by ~10%.

Moves Model._set_save_spec to Layer. This allows Layer.__call__ to avoid a hasattr check, and allows Model.__call__ to avoid an expensive call to a method wrapped in trackable.no_automatic_dependency_tracking. This should also allow SavedModel to use this spec in place of build_input_shape in the future.

PiperOrigin-RevId: 313795786
Change-Id: Id7b23f98911468ed3e11261ac60989685de47aa1
parent 6c5999948c
commit 2244921925
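The pattern behind the speedup, sketched as standalone Python. This is a simplified illustration, not the real Keras code: expensive_decorator stands in for trackable.no_automatic_dependency_tracking, and LayerSketch is a toy class. Initializing _saved_model_inputs_spec = None in __init__ lets __call__ use a cheap `is None` test and enter the decorated method only on the first call, instead of paying for a hasattr check plus the wrapped call every time.

# A minimal sketch of the pattern this commit applies. `expensive_decorator`
# is a stand-in for trackable.no_automatic_dependency_tracking and the class
# is heavily simplified; not the real Keras implementation.
import functools


def expensive_decorator(fn):
  """Stand-in for a wrapper whose every invocation carries extra overhead."""
  @functools.wraps(fn)
  def wrapped(self, *args, **kwargs):
    # Imagine bookkeeping here that makes each call through the wrapper costly.
    return fn(self, *args, **kwargs)
  return wrapped


class LayerSketch(object):

  def __init__(self):
    # Defining the attribute up front lets __call__ use a cheap `is None`
    # test instead of hasattr(self, '_set_save_spec').
    self._saved_model_inputs_spec = None

  @expensive_decorator
  def _set_save_spec(self, inputs):
    if self._saved_model_inputs_spec is not None:
      return  # Already set.
    self._saved_model_inputs_spec = inputs

  def __call__(self, inputs):
    # Before this change: a hasattr check ran on every call, and the wrapped
    # method was entered each time only to find the spec was already recorded.
    # After: the wrapped method is entered once, on the first call.
    if self._saved_model_inputs_spec is None:
      self._set_save_spec(inputs)
    return inputs


layer = LayerSketch()
layer([1, 2, 3])  # First call records the "spec" through the wrapper.
layer([1, 2, 3])  # Subsequent calls skip the wrapper entirely.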
@@ -316,6 +316,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
     # submitted.
     self._build_input_shape = None
+    self._saved_model_inputs_spec = None
     # Provides information about which inputs are compatible with the layer.
     self._input_spec = None
     self.supports_masking = False
@@ -1002,7 +1003,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
-        if hasattr(self, '_set_save_spec'):
+        if self._saved_model_inputs_spec is None:
           self._set_save_spec(cast_inputs)

         return outputs
@@ -2809,6 +2810,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector):

   # SavedModel properties. Please see keras/saving/saved_model for details.

+  @trackable.no_automatic_dependency_tracking
+  def _set_save_spec(self, inputs):
+    if self._saved_model_inputs_spec is not None:
+      return  # Already set.
+
+    self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
+                                                       inputs)
+
+  def _get_save_spec(self, dynamic_batch=True):
+    if self._saved_model_inputs_spec is None:
+      return None
+
+    return nest.map_structure(
+        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
+        self._saved_model_inputs_spec)
+
   @property
   def _trackable_saved_model_saver(self):
     return layer_serialization.LayerSavedModelSaver(self)
@@ -2366,14 +2366,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):

     self._saved_model_inputs_spec = specs

-  def _get_save_spec(self, dynamic_batch=True):
-    if self._saved_model_inputs_spec is None:
-      return None
-
-    return nest.map_structure(
-        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
-        self._saved_model_inputs_spec)
-
   def _assert_weights_created(self):
     """Asserts that all the weights for the model have been created.

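For reference, roughly what the moved _set_save_spec / _get_save_spec record, sketched with public TensorFlow APIs. tf.TensorSpec.from_tensor and the shape-relaxing lambda below are approximations of the internal tf_utils.get_tensor_spec helper (and its dynamic_batch handling), not the exact code path.

# Approximate behaviour of Layer._set_save_spec / Layer._get_save_spec,
# sketched with public TensorFlow APIs. tf.TensorSpec.from_tensor and the
# shape-relaxing lambda stand in for the internal tf_utils.get_tensor_spec
# helper; the real implementation differs in detail.
import tensorflow as tf

inputs = {'ids': tf.zeros([8, 16], dtype=tf.int32),
          'mask': tf.ones([8, 16], dtype=tf.float32)}

# _set_save_spec(cast_inputs): record a TensorSpec for every tensor in the
# (possibly nested) input structure.
saved_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, inputs)

# _get_save_spec(dynamic_batch=True): the same structure, with the batch
# dimension relaxed to None so a SavedModel signature accepts any batch size.
dynamic_spec = tf.nest.map_structure(
    lambda s: tf.TensorSpec([None] + s.shape.as_list()[1:], s.dtype),
    saved_spec)

print(dynamic_spec['ids'])  # TensorSpec(shape=(None, 16), dtype=tf.int32, name=None)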