Fix bug where v1 vs v2 model class swapping does not trigger correctly in rare circumstances.
Specifically: If model is subclassed and the subclass initializes like a functional model but does not explicitly extend the internal `Functional` class, class swapping according to if eager execution is enabled did not trigger on `Functional` the first time the model is constructed. This is unlikely to be an issue in practice unless code interleaves v1 eager-disabled code and v2 code a lot, and uses subclass models that initialize like functional models. We found it to cause rare nondeterministic test failures in our test suites, because we regularly test both legacy v1 graphs & v2 code. PiperOrigin-RevId: 315599954 Change-Id: I20eb52c67e5af12696425b93eeebe3664f8785ea
This commit is contained in:
parent
de0a617f4e
commit
81f28b5ad9
@ -134,6 +134,7 @@ def disable_multi_worker(method):
|
||||
|
||||
|
||||
def inject_functional_model_class(cls):
|
||||
"""Inject `Functional` into the hierarchy of this class if needed."""
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
|
||||
if cls == Model or cls == training_v1.Model:
|
||||
@ -141,6 +142,10 @@ def inject_functional_model_class(cls):
|
||||
|
||||
cls.__bases__ = tuple(inject_functional_model_class(base)
|
||||
for base in cls.__bases__)
|
||||
# Trigger any `__new__` class swapping that needed to happen on `Functional`
|
||||
# but did not because functional was not in the class hierarchy.
|
||||
cls.__new__(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
|
@ -56,6 +56,35 @@ class SplitUtilsTest(keras_parameterized.TestCase):
|
||||
self._check_model_class(model.__class__.__bases__[0])
|
||||
self._check_layer_class(model)
|
||||
|
||||
def test_subclass_model_with_functional_init(self):
|
||||
inputs = keras.Input(10)
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
|
||||
class MyModel(keras.Model):
|
||||
pass
|
||||
|
||||
model = MyModel(inputs, outputs)
|
||||
model_class = model.__class__.__bases__[0].__bases__[0]
|
||||
self._check_model_class(model_class)
|
||||
self._check_layer_class(model)
|
||||
|
||||
def test_subclass_model_with_functional_init_interleaved_v1_functional(self):
|
||||
with ops.Graph().as_default():
|
||||
inputs = keras.Input(10)
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
_ = keras.Model(inputs, outputs)
|
||||
|
||||
inputs = keras.Input(10)
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
|
||||
class MyModel(keras.Model):
|
||||
pass
|
||||
|
||||
model = MyModel(inputs, outputs)
|
||||
model_class = model.__class__.__bases__[0].__bases__[0]
|
||||
self._check_model_class(model_class)
|
||||
self._check_layer_class(model)
|
||||
|
||||
def test_sequential_model(self):
|
||||
model = keras.Sequential([keras.layers.Dense(1)])
|
||||
model_class = model.__class__.__bases__[0].__bases__[0]
|
||||
|
Loading…
Reference in New Issue
Block a user