Fix bug where v1 vs v2 model class swapping does not trigger correctly in rare circumstances.

Specifically: If model is subclassed and the subclass initializes like a functional model but does not explicitly extend the internal `Functional` class, class swapping according to if eager execution is enabled did not trigger on `Functional` the first time the model is constructed.

This is unlikely to be an issue in practice unless code interleaves v1 eager-disabled code and v2 code a lot, and uses subclass models that initialize like functional models.

We found it to cause rare nondeterministic test failures in our test suites, because we regularly test both legacy v1 graphs & v2 code.

PiperOrigin-RevId: 315599954
Change-Id: I20eb52c67e5af12696425b93eeebe3664f8785ea
This commit is contained in:
Tomer Kaftan 2020-06-09 18:04:47 -07:00 committed by TensorFlower Gardener
parent de0a617f4e
commit 81f28b5ad9
2 changed files with 34 additions and 0 deletions

View File

@ -134,6 +134,7 @@ def disable_multi_worker(method):
def inject_functional_model_class(cls):
"""Inject `Functional` into the hierarchy of this class if needed."""
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
if cls == Model or cls == training_v1.Model:
@ -141,6 +142,10 @@ def inject_functional_model_class(cls):
cls.__bases__ = tuple(inject_functional_model_class(base)
for base in cls.__bases__)
# Trigger any `__new__` class swapping that needed to happen on `Functional`
# but did not because functional was not in the class hierarchy.
cls.__new__(cls)
return cls

View File

@ -56,6 +56,35 @@ class SplitUtilsTest(keras_parameterized.TestCase):
self._check_model_class(model.__class__.__bases__[0])
self._check_layer_class(model)
def test_subclass_model_with_functional_init(self):
inputs = keras.Input(10)
outputs = keras.layers.Dense(1)(inputs)
class MyModel(keras.Model):
pass
model = MyModel(inputs, outputs)
model_class = model.__class__.__bases__[0].__bases__[0]
self._check_model_class(model_class)
self._check_layer_class(model)
def test_subclass_model_with_functional_init_interleaved_v1_functional(self):
with ops.Graph().as_default():
inputs = keras.Input(10)
outputs = keras.layers.Dense(1)(inputs)
_ = keras.Model(inputs, outputs)
inputs = keras.Input(10)
outputs = keras.layers.Dense(1)(inputs)
class MyModel(keras.Model):
pass
model = MyModel(inputs, outputs)
model_class = model.__class__.__bases__[0].__bases__[0]
self._check_model_class(model_class)
self._check_layer_class(model)
def test_sequential_model(self):
model = keras.Sequential([keras.layers.Dense(1)])
model_class = model.__class__.__bases__[0].__bases__[0]