PR #40169: ensure model initialized on ANY trackable attr set
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40169 In particular, empty tuples should not trigger this. Copybara import of the project: --17b7e16913
by Dominic Jack <thedomjack@gmail.com>: ensure model initialized on ANY trackable attr set --57eccc7bc2
by Dominic Jack <thedomjack@gmail.com>: added test PiperOrigin-RevId: 316743715 Change-Id: I038a0261fbb3a0dac50c62a50c787bade10abb6a
This commit is contained in:
parent
1a342fb760
commit
a71c78bcf9
@ -324,10 +324,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
super(Model, self).__setattr__(name, value)
|
||||
return
|
||||
|
||||
if any(
|
||||
isinstance(v, (base_layer.Layer, data_structures.TrackableDataStructure
|
||||
)) or trackable_layer_utils.has_weights(v)
|
||||
for v in nest.flatten(value)):
|
||||
if all(
|
||||
isinstance(v, (base_layer.Layer,
|
||||
data_structures.TrackableDataStructure)) or
|
||||
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
|
||||
try:
|
||||
self._base_model_initialized
|
||||
except AttributeError:
|
||||
|
@ -3383,18 +3383,6 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
|
||||
self.assertEqual([m.name for m in outer_model.metrics],
|
||||
['loss', 'acc2', 'mean', 'mean1', 'mean2'])
|
||||
|
||||
def test_subclassed_model_with_empty_list_attr(self):
|
||||
|
||||
class ModelSubclass(training_module.Model):
|
||||
|
||||
def __init__(self):
|
||||
self.empty_list = []
|
||||
inputs = layers_module.Input(shape=())
|
||||
outputs = inputs + 1
|
||||
super(ModelSubclass, self).__init__(inputs, outputs)
|
||||
|
||||
ModelSubclass() # empty_list attr assignment should not raise
|
||||
|
||||
|
||||
class BareUpdateLayer(layers_module.Layer):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user