PR #40169: ensure model initialized on ANY trackable attr set

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40169

In particular, empty tuples should not trigger this.
Copybara import of the project:

--
17b7e16913 by Dominic Jack <thedomjack@gmail.com>:

ensure model initialized on ANY trackable attr set

--
57eccc7bc2 by Dominic Jack <thedomjack@gmail.com>:

added test

PiperOrigin-RevId: 316743715
Change-Id: I038a0261fbb3a0dac50c62a50c787bade10abb6a
This commit is contained in:
A. Unique TensorFlower 2020-06-16 13:14:37 -07:00 committed by TensorFlower Gardener
parent 1a342fb760
commit a71c78bcf9
2 changed files with 4 additions and 16 deletions

View File

@ -324,10 +324,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
super(Model, self).__setattr__(name, value)
return
if any(
isinstance(v, (base_layer.Layer, data_structures.TrackableDataStructure
)) or trackable_layer_utils.has_weights(v)
for v in nest.flatten(value)):
if all(
isinstance(v, (base_layer.Layer,
data_structures.TrackableDataStructure)) or
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
try:
self._base_model_initialized
except AttributeError:

View File

@ -3383,18 +3383,6 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
self.assertEqual([m.name for m in outer_model.metrics],
['loss', 'acc2', 'mean', 'mean1', 'mean2'])
def test_subclassed_model_with_empty_list_attr(self):
class ModelSubclass(training_module.Model):
def __init__(self):
self.empty_list = []
inputs = layers_module.Input(shape=())
outputs = inputs + 1
super(ModelSubclass, self).__init__(inputs, outputs)
ModelSubclass() # empty_list attr assignment should not raise
class BareUpdateLayer(layers_module.Layer):