Update Keras instrumentation, and add tests to make sure it triggers. Specifically, make legacy TF layers no longer instrument as if they are Keras layers.
PiperOrigin-RevId: 337397565 Change-Id: I6f1290f05aec8b26525806c67ded2ce5607a4cab
This commit is contained in:
parent
7c595a218b
commit
30060737d7
tensorflow/python/keras
engine
legacy_tf_layers
@ -304,6 +304,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# not available to the restoration code).
|
||||
_must_restore_from_config = False
|
||||
|
||||
def _instrument_layer_creation(self):
|
||||
self._instrumented_keras_api = False
|
||||
self._instrumented_keras_layer_class = False
|
||||
self._instrumented_keras_model_class = False
|
||||
if not getattr(self, '_disable_keras_instrumentation', False):
|
||||
keras_api_gauge.get_cell('layer').set(True)
|
||||
self._instrumented_keras_api = True
|
||||
if getattr(self, '_is_model_for_instrumentation', False):
|
||||
keras_models_gauge.get_cell(self.__class__.__name__).set(True)
|
||||
self._instrumented_keras_model_class = True
|
||||
else:
|
||||
keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
|
||||
self._instrumented_keras_layer_class = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self,
|
||||
trainable=True,
|
||||
@ -311,11 +325,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
dtype=None,
|
||||
dynamic=False,
|
||||
**kwargs):
|
||||
keras_api_gauge.get_cell('layer').set(True)
|
||||
if getattr(self, '_is_model_for_instrumentation', False):
|
||||
keras_models_gauge.get_cell(self.__class__.__name__).set(True)
|
||||
else:
|
||||
keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
|
||||
self._instrument_layer_creation()
|
||||
|
||||
# These properties should be set by the user via keyword arguments.
|
||||
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
||||
# are only applicable to input layers: do not pass these keywords
|
||||
|
@ -85,6 +85,13 @@ class InvalidLayer(base_layer.Layer):
|
||||
|
||||
class BaseLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_layer_instrumentation(self):
|
||||
layer = layers.Add()
|
||||
self.assertTrue(layer._instrumented_keras_api)
|
||||
self.assertTrue(layer._instrumented_keras_layer_class)
|
||||
self.assertFalse(layer._instrumented_keras_model_class)
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
|
@ -152,8 +152,8 @@ class Layer(base_layer.Layer):
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
|
||||
**kwargs):
|
||||
base_layer.keras_api_gauge.get_cell('layer').set(True)
|
||||
base_layer.keras_layers_gauge.get_cell(self.__class__.__name__).set(True)
|
||||
self._instrument_layer_creation()
|
||||
|
||||
# These properties should be set by the user via keyword arguments.
|
||||
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
||||
# are only applicable to input layers: do not pass these keywords
|
||||
|
@ -68,6 +68,19 @@ except ImportError:
|
||||
|
||||
class TrainingTest(keras_parameterized.TestCase):
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_model_instrumentation(self):
|
||||
layers = [
|
||||
layers_module.Dense(10, dtype=np.float64),
|
||||
layers_module.Dense(10, dtype=np.float64)
|
||||
]
|
||||
model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
|
||||
|
||||
self.assertTrue(model._instrumented_keras_api)
|
||||
self.assertTrue(model._instrumented_keras_model_class)
|
||||
self.assertFalse(model._instrumented_keras_layer_class)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_fit_training_arg(self):
|
||||
|
@ -210,6 +210,9 @@ class Layer(base_layer.Layer):
|
||||
if 'autocast' not in kwargs:
|
||||
kwargs['autocast'] = False
|
||||
|
||||
# Mark that legacy layers should not be instrumented as Keras usage
|
||||
self._disable_keras_instrumentation = True
|
||||
|
||||
super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
|
||||
**kwargs)
|
||||
|
||||
|
@ -60,6 +60,9 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
|
||||
layer = base_layers.Layer(name='my_layer', trainable=False)
|
||||
self.assertEqual(layer.trainable, False)
|
||||
|
||||
# Assert that the layer was not instrumented as a Keras layer
|
||||
self.assertFalse(layer._instrumented_keras_api)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testInt64Layer(self):
|
||||
layer = base_layers.Layer(name='my_layer', dtype='int64')
|
||||
@ -83,6 +86,8 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
with base_layers.keras_style_scope():
|
||||
layer = base_layers.Layer(name='my_layer')
|
||||
# Assert that the layer was not instrumented as a Keras layer
|
||||
self.assertFalse(layer._instrumented_keras_api)
|
||||
# Test basic variable creation.
|
||||
with backend.name_scope('bar'):
|
||||
variable = layer.add_variable(
|
||||
|
Loading…
Reference in New Issue
Block a user