From 30060737d73f17aa39277a217298972d910ee3e8 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan <kaftan@google.com> Date: Thu, 15 Oct 2020 15:41:34 -0700 Subject: [PATCH] Update Keras instrumentation, and add tests to make sure it triggers. Specifically, make legacy TF layers no longer instrument as if they are Keras layers. PiperOrigin-RevId: 337397565 Change-Id: I6f1290f05aec8b26525806c67ded2ce5607a4cab --- tensorflow/python/keras/engine/base_layer.py | 21 ++++++++++++++----- .../python/keras/engine/base_layer_test.py | 7 +++++++ .../python/keras/engine/base_layer_v1.py | 4 ++-- .../python/keras/engine/training_test.py | 13 ++++++++++++ .../python/keras/legacy_tf_layers/base.py | 3 +++ .../keras/legacy_tf_layers/base_test.py | 5 +++++ 6 files changed, 46 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 07c371465be..3a3f6363e3c 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -304,6 +304,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # not available to the restoration code). _must_restore_from_config = False + def _instrument_layer_creation(self): + self._instrumented_keras_api = False + self._instrumented_keras_layer_class = False + self._instrumented_keras_model_class = False + if not getattr(self, '_disable_keras_instrumentation', False): + keras_api_gauge.get_cell('layer').set(True) + self._instrumented_keras_api = True + if getattr(self, '_is_model_for_instrumentation', False): + keras_models_gauge.get_cell(self.__class__.__name__).set(True) + self._instrumented_keras_model_class = True + else: + keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrumented_keras_layer_class = True + @trackable.no_automatic_dependency_tracking def __init__(self, trainable=True, @@ -311,11 +325,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): dtype=None, dynamic=False, **kwargs): - keras_api_gauge.get_cell('layer').set(True) - if getattr(self, '_is_model_for_instrumentation', False): - keras_models_gauge.get_cell(self.__class__.__name__).set(True) - else: - keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrument_layer_creation() + # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 8334df112d6..029447b3cba 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -85,6 +85,13 @@ class InvalidLayer(base_layer.Layer): class BaseLayerTest(keras_parameterized.TestCase): + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_layer_instrumentation(self): + layer = layers.Add() + self.assertTrue(layer._instrumented_keras_api) + self.assertTrue(layer._instrumented_keras_layer_class) + self.assertFalse(layer._instrumented_keras_model_class) + @combinations.generate(combinations.times( combinations.keras_model_type_combinations(), combinations.keras_tensor_combinations())) diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index d69f9d07702..238f2ae5248 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -152,8 +152,8 @@ class Layer(base_layer.Layer): @trackable.no_automatic_dependency_tracking def __init__(self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs): - base_layer.keras_api_gauge.get_cell('layer').set(True) - base_layer.keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrument_layer_creation() + # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index a651dbf1f05..dee1055bbc4 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -68,6 +68,19 @@ except ImportError: class TrainingTest(keras_parameterized.TestCase): + @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_with_all_model_types + def test_model_instrumentation(self): + layers = [ + layers_module.Dense(10, dtype=np.float64), + layers_module.Dense(10, dtype=np.float64) + ] + model = testing_utils.get_model_from_layers(layers, input_shape=(1,)) + + self.assertTrue(model._instrumented_keras_api) + self.assertTrue(model._instrumented_keras_model_class) + self.assertFalse(model._instrumented_keras_layer_class) + @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes def test_fit_training_arg(self): diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py index 8052651efa7..1e2a7e7861c 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base.py +++ b/tensorflow/python/keras/legacy_tf_layers/base.py @@ -210,6 +210,9 @@ class Layer(base_layer.Layer): if 'autocast' not in kwargs: kwargs['autocast'] = False + # Mark that legacy layers should not be instrumented as Keras usage + self._disable_keras_instrumentation = True + super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) diff --git a/tensorflow/python/keras/legacy_tf_layers/base_test.py b/tensorflow/python/keras/legacy_tf_layers/base_test.py index 2c9810c4109..90d57fae407 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/base_test.py @@ -60,6 +60,9 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase): layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) + # Assert that the layer was not instrumented as a Keras layer + self.assertFalse(layer._instrumented_keras_api) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def testInt64Layer(self): layer = base_layers.Layer(name='my_layer', dtype='int64') @@ -83,6 +86,8 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase): with base_layers.keras_style_scope(): layer = base_layers.Layer(name='my_layer') + # Assert that the layer was not instrumented as a Keras layer + self.assertFalse(layer._instrumented_keras_api) # Test basic variable creation. with backend.name_scope('bar'): variable = layer.add_variable(