Update Keras instrumentation, and add tests to make sure it triggers. Specifically, make legacy TF layers no longer instrument as if they are Keras layers.

PiperOrigin-RevId: 337397565
Change-Id: I6f1290f05aec8b26525806c67ded2ce5607a4cab
This commit is contained in:
Tomer Kaftan 2020-10-15 15:41:34 -07:00 committed by TensorFlower Gardener
parent 7c595a218b
commit 30060737d7
6 changed files with 46 additions and 7 deletions

View File

@ -304,6 +304,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# not available to the restoration code).
_must_restore_from_config = False
def _instrument_layer_creation(self):
self._instrumented_keras_api = False
self._instrumented_keras_layer_class = False
self._instrumented_keras_model_class = False
if not getattr(self, '_disable_keras_instrumentation', False):
keras_api_gauge.get_cell('layer').set(True)
self._instrumented_keras_api = True
if getattr(self, '_is_model_for_instrumentation', False):
keras_models_gauge.get_cell(self.__class__.__name__).set(True)
self._instrumented_keras_model_class = True
else:
keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
self._instrumented_keras_layer_class = True
@trackable.no_automatic_dependency_tracking
def __init__(self,
trainable=True,
@ -311,11 +325,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
dtype=None,
dynamic=False,
**kwargs):
keras_api_gauge.get_cell('layer').set(True)
if getattr(self, '_is_model_for_instrumentation', False):
keras_models_gauge.get_cell(self.__class__.__name__).set(True)
else:
keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
self._instrument_layer_creation()
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
# are only applicable to input layers: do not pass these keywords

View File

@ -85,6 +85,13 @@ class InvalidLayer(base_layer.Layer):
class BaseLayerTest(keras_parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_layer_instrumentation(self):
layer = layers.Add()
self.assertTrue(layer._instrumented_keras_api)
self.assertTrue(layer._instrumented_keras_layer_class)
self.assertFalse(layer._instrumented_keras_model_class)
@combinations.generate(combinations.times(
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))

View File

@ -152,8 +152,8 @@ class Layer(base_layer.Layer):
@trackable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
**kwargs):
base_layer.keras_api_gauge.get_cell('layer').set(True)
base_layer.keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
self._instrument_layer_creation()
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
# are only applicable to input layers: do not pass these keywords

View File

@ -68,6 +68,19 @@ except ImportError:
class TrainingTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_with_all_model_types
def test_model_instrumentation(self):
layers = [
layers_module.Dense(10, dtype=np.float64),
layers_module.Dense(10, dtype=np.float64)
]
model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
self.assertTrue(model._instrumented_keras_api)
self.assertTrue(model._instrumented_keras_model_class)
self.assertFalse(model._instrumented_keras_layer_class)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_fit_training_arg(self):

View File

@ -210,6 +210,9 @@ class Layer(base_layer.Layer):
if 'autocast' not in kwargs:
kwargs['autocast'] = False
# Mark that legacy layers should not be instrumented as Keras usage
self._disable_keras_instrumentation = True
super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
**kwargs)

View File

@ -60,6 +60,9 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
layer = base_layers.Layer(name='my_layer', trainable=False)
self.assertEqual(layer.trainable, False)
# Assert that the layer was not instrumented as a Keras layer
self.assertFalse(layer._instrumented_keras_api)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testInt64Layer(self):
layer = base_layers.Layer(name='my_layer', dtype='int64')
@ -83,6 +86,8 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
with base_layers.keras_style_scope():
layer = base_layers.Layer(name='my_layer')
# Assert that the layer was not instrumented as a Keras layer
self.assertFalse(layer._instrumented_keras_api)
# Test basic variable creation.
with backend.name_scope('bar'):
variable = layer.add_variable(