From 224492192533d83f2c2d03a769a6f69b0426f29a Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Fri, 29 May 2020 09:34:47 -0700 Subject: [PATCH] Reduce 1-Layer Functional.__call__ overhead by ~10%. Moves Model._set_save_spec to Layer. This allows Layer.__call__ to avoid a hasattr check and also Model.__call__ to avoid an expensive call to a method wrapped in trackable.no_automatic_dependency_tracking. This should also allow SavedModel to use this spec in place of build_input_shape in the future. PiperOrigin-RevId: 313795786 Change-Id: Id7b23f98911468ed3e11261ac60989685de47aa1 --- tensorflow/python/keras/engine/base_layer.py | 19 ++++++++++++++++++- tensorflow/python/keras/engine/training.py | 8 -------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 9958f70ed55..4a33e8f4e20 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -316,6 +316,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is # submitted. self._build_input_shape = None + self._saved_model_inputs_spec = None # Provides information about which inputs are compatible with the layer. self._input_spec = None self.supports_masking = False @@ -1002,7 +1003,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): outputs = self.call(cast_inputs, *args, **kwargs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks, build_graph) - if hasattr(self, '_set_save_spec'): + if self._saved_model_inputs_spec is None: self._set_save_spec(cast_inputs) return outputs @@ -2809,6 +2810,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # SavedModel properties. Please see keras/saving/saved_model for details. + @trackable.no_automatic_dependency_tracking + def _set_save_spec(self, inputs): + if self._saved_model_inputs_spec is not None: + return # Already set. + + self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec, + inputs) + + def _get_save_spec(self, dynamic_batch=True): + if self._saved_model_inputs_spec is None: + return None + + return nest.map_structure( + lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), + self._saved_model_inputs_spec) + @property def _trackable_saved_model_saver(self): return layer_serialization.LayerSavedModelSaver(self) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index d8c95b2a972..6c6d9ee897b 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -2366,14 +2366,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): self._saved_model_inputs_spec = specs - def _get_save_spec(self, dynamic_batch=True): - if self._saved_model_inputs_spec is None: - return None - - return nest.map_structure( - lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), - self._saved_model_inputs_spec) - def _assert_weights_created(self): """Asserts that all the weights for the model have been created.