diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 146199fd7c1..5b838295bee 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -541,6 +541,19 @@ class NestedTrackingTest(test.TestCase): self.assertEqual([layer.v], layer.variables) + def test_layer_class_not_tracked_as_sublayer(self): + # See https://github.com/tensorflow/tensorflow/issues/27431 for details. + + class LayerWithClassAttribute(keras.layers.Layer): + + def __init__(self): + super(LayerWithClassAttribute, self).__init__() + self.layer_fn = keras.layers.Dense + + layer = LayerWithClassAttribute() + self.assertEmpty(layer.variables) + self.assertEmpty(layer.submodules) + @test_util.run_all_in_graph_and_eager_modes class NameScopingTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/training/tracking/layer_utils.py b/tensorflow/python/training/tracking/layer_utils.py index 818563c32fa..66f8e3a4f69 100644 --- a/tensorflow/python/training/tracking/layer_utils.py +++ b/tensorflow/python/training/tracking/layer_utils.py @@ -27,14 +27,15 @@ from tensorflow.python.training.tracking import object_identity def is_layer(obj): """Implicit check for Layer-like objects.""" # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). - return hasattr(obj, "_is_layer") + return hasattr(obj, "_is_layer") and not isinstance(obj, type) def has_weights(obj): """Implicit check for Layer-like objects.""" # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). return (hasattr(obj, "trainable_weights") - and hasattr(obj, "non_trainable_weights")) + and hasattr(obj, "non_trainable_weights") + and not isinstance(obj, type)) def filter_empty_layer_containers(layer_list):