Merge pull request #40169 from jackd:keras-model-attrs

PiperOrigin-RevId: 316723863
Change-Id: If1c6ea2eab56c3bbfe68a094752a05417d1cd789
This commit is contained in:
TensorFlower Gardener 2020-06-16 11:53:16 -07:00
commit 8d9fee90ba
2 changed files with 16 additions and 4 deletions

View File

@ -324,10 +324,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
super(Model, self).__setattr__(name, value)
return
if all(
isinstance(v, (base_layer.Layer,
data_structures.TrackableDataStructure)) or
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
if any(
isinstance(v, (base_layer.Layer, data_structures.TrackableDataStructure
)) or trackable_layer_utils.has_weights(v)
for v in nest.flatten(value)):
try:
self._base_model_initialized
except AttributeError:

View File

@ -3383,6 +3383,18 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
self.assertEqual([m.name for m in outer_model.metrics],
['loss', 'acc2', 'mean', 'mean1', 'mean2'])
def test_subclassed_model_with_empty_list_attr(self):
class ModelSubclass(training_module.Model):
def __init__(self):
self.empty_list = []
inputs = layers_module.Input(shape=())
outputs = inputs + 1
super(ModelSubclass, self).__init__(inputs, outputs)
ModelSubclass() # empty_list attr assignment should not raise
class BareUpdateLayer(layers_module.Layer):