The layer class was incorrectly tracked by layer._layer, which should
only track layer instance. This should be mitigated once b/110718070 is
fixed.

PiperOrigin-RevId: 243103748
This commit is contained in:
Scott Zhu 2019-04-11 11:09:06 -07:00 committed by TensorFlower Gardener
parent 90a4b1ecfe
commit 9d724a8e60
2 changed files with 16 additions and 2 deletions

View File

@ -541,6 +541,19 @@ class NestedTrackingTest(test.TestCase):
self.assertEqual([layer.v], layer.variables)
def test_layer_class_not_tracked_as_sublayer(self):
# See https://github.com/tensorflow/tensorflow/issues/27431 for details.
class LayerWithClassAttribute(keras.layers.Layer):
def __init__(self):
super(LayerWithClassAttribute, self).__init__()
self.layer_fn = keras.layers.Dense
layer = LayerWithClassAttribute()
self.assertEmpty(layer.variables)
self.assertEmpty(layer.submodules)
@test_util.run_all_in_graph_and_eager_modes
class NameScopingTest(keras_parameterized.TestCase):

View File

@ -27,14 +27,15 @@ from tensorflow.python.training.tracking import object_identity
def is_layer(obj):
"""Implicit check for Layer-like objects."""
# TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
return hasattr(obj, "_is_layer")
return hasattr(obj, "_is_layer") and not isinstance(obj, type)
def has_weights(obj):
"""Implicit check for Layer-like objects."""
# TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
return (hasattr(obj, "trainable_weights")
and hasattr(obj, "non_trainable_weights"))
and hasattr(obj, "non_trainable_weights")
and not isinstance(obj, type))
def filter_empty_layer_containers(layer_list):