From 30060737d73f17aa39277a217298972d910ee3e8 Mon Sep 17 00:00:00 2001
From: Tomer Kaftan <kaftan@google.com>
Date: Thu, 15 Oct 2020 15:41:34 -0700
Subject: [PATCH] Update Keras instrumentation, and add tests to make sure it
 triggers. Specifically, make legacy TF layers no longer instrument as if they
 are Keras layers.

PiperOrigin-RevId: 337397565
Change-Id: I6f1290f05aec8b26525806c67ded2ce5607a4cab
---
 tensorflow/python/keras/engine/base_layer.py  | 21 ++++++++++++++-----
 .../python/keras/engine/base_layer_test.py    |  7 +++++++
 .../python/keras/engine/base_layer_v1.py      |  4 ++--
 .../python/keras/engine/training_test.py      | 13 ++++++++++++
 .../python/keras/legacy_tf_layers/base.py     |  3 +++
 .../keras/legacy_tf_layers/base_test.py       |  5 +++++
 6 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 07c371465be..3a3f6363e3c 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -304,6 +304,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   # not available to the restoration code).
   _must_restore_from_config = False
 
+  def _instrument_layer_creation(self):
+    self._instrumented_keras_api = False
+    self._instrumented_keras_layer_class = False
+    self._instrumented_keras_model_class = False
+    if not getattr(self, '_disable_keras_instrumentation', False):
+      keras_api_gauge.get_cell('layer').set(True)
+      self._instrumented_keras_api = True
+      if getattr(self, '_is_model_for_instrumentation', False):
+        keras_models_gauge.get_cell(self.__class__.__name__).set(True)
+        self._instrumented_keras_model_class = True
+      else:
+        keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
+        self._instrumented_keras_layer_class = True
+
   @trackable.no_automatic_dependency_tracking
   def __init__(self,
                trainable=True,
@@ -311,11 +325,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
                dtype=None,
                dynamic=False,
                **kwargs):
-    keras_api_gauge.get_cell('layer').set(True)
-    if getattr(self, '_is_model_for_instrumentation', False):
-      keras_models_gauge.get_cell(self.__class__.__name__).set(True)
-    else:
-      keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
+    self._instrument_layer_creation()
+
     # These properties should be set by the user via keyword arguments.
     # note that 'dtype', 'input_shape' and 'batch_input_shape'
     # are only applicable to input layers: do not pass these keywords
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 8334df112d6..029447b3cba 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -85,6 +85,13 @@ class InvalidLayer(base_layer.Layer):
 
 class BaseLayerTest(keras_parameterized.TestCase):
 
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_layer_instrumentation(self):
+    layer = layers.Add()
+    self.assertTrue(layer._instrumented_keras_api)
+    self.assertTrue(layer._instrumented_keras_layer_class)
+    self.assertFalse(layer._instrumented_keras_model_class)
+
   @combinations.generate(combinations.times(
       combinations.keras_model_type_combinations(),
       combinations.keras_tensor_combinations()))
diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py
index d69f9d07702..238f2ae5248 100644
--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py
@@ -152,8 +152,8 @@ class Layer(base_layer.Layer):
   @trackable.no_automatic_dependency_tracking
   def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
                **kwargs):
-    base_layer.keras_api_gauge.get_cell('layer').set(True)
-    base_layer.keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
+    self._instrument_layer_creation()
+
     # These properties should be set by the user via keyword arguments.
     # note that 'dtype', 'input_shape' and 'batch_input_shape'
     # are only applicable to input layers: do not pass these keywords
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index a651dbf1f05..dee1055bbc4 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -68,6 +68,19 @@ except ImportError:
 
 class TrainingTest(keras_parameterized.TestCase):
 
+  @keras_parameterized.run_all_keras_modes
+  @keras_parameterized.run_with_all_model_types
+  def test_model_instrumentation(self):
+    layers = [
+        layers_module.Dense(10, dtype=np.float64),
+        layers_module.Dense(10, dtype=np.float64)
+    ]
+    model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
+
+    self.assertTrue(model._instrumented_keras_api)
+    self.assertTrue(model._instrumented_keras_model_class)
+    self.assertFalse(model._instrumented_keras_layer_class)
+
   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
   def test_fit_training_arg(self):
diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py
index 8052651efa7..1e2a7e7861c 100644
--- a/tensorflow/python/keras/legacy_tf_layers/base.py
+++ b/tensorflow/python/keras/legacy_tf_layers/base.py
@@ -210,6 +210,9 @@ class Layer(base_layer.Layer):
     if 'autocast' not in kwargs:
       kwargs['autocast'] = False
 
+    # Mark that legacy layers should not be instrumented as Keras usage
+    self._disable_keras_instrumentation = True
+
     super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                 **kwargs)
 
diff --git a/tensorflow/python/keras/legacy_tf_layers/base_test.py b/tensorflow/python/keras/legacy_tf_layers/base_test.py
index 2c9810c4109..90d57fae407 100644
--- a/tensorflow/python/keras/legacy_tf_layers/base_test.py
+++ b/tensorflow/python/keras/legacy_tf_layers/base_test.py
@@ -60,6 +60,9 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
     layer = base_layers.Layer(name='my_layer', trainable=False)
     self.assertEqual(layer.trainable, False)
 
+    # Assert that the layer was not instrumented as a Keras layer
+    self.assertFalse(layer._instrumented_keras_api)
+
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def testInt64Layer(self):
     layer = base_layers.Layer(name='my_layer', dtype='int64')
@@ -83,6 +86,8 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
 
     with base_layers.keras_style_scope():
       layer = base_layers.Layer(name='my_layer')
+    # Assert that the layer was not instrumented as a Keras layer
+    self.assertFalse(layer._instrumented_keras_api)
     # Test basic variable creation.
     with backend.name_scope('bar'):
       variable = layer.add_variable(