From 81f28b5ad92df51435d8e800bab0ddd85a77ffd8 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Tue, 9 Jun 2020 18:04:47 -0700 Subject: [PATCH] Fix bug where v1 vs v2 model class swapping does not trigger correctly in rare circumstances. Specifically: If model is subclassed and the subclass initializes like a functional model but does not explicitly extend the internal `Functional` class, class swapping according to if eager execution is enabled did not trigger on `Functional` the first time the model is constructed. This is unlikely to be an issue in practice unless code interleaves v1 eager-disabled code and v2 code a lot, and uses subclass models that initialize like functional models. We found it to cause rare nondeterministic test failures in our test suites, because we regularly test both legacy v1 graphs & v2 code. PiperOrigin-RevId: 315599954 Change-Id: I20eb52c67e5af12696425b93eeebe3664f8785ea --- tensorflow/python/keras/engine/training.py | 5 ++++ .../python/keras/utils/version_utils_test.py | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 29ff31d56db..d7918f1a1e1 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -134,6 +134,7 @@ def disable_multi_worker(method): def inject_functional_model_class(cls): + """Inject `Functional` into the hierarchy of this class if needed.""" from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top if cls == Model or cls == training_v1.Model: @@ -141,6 +142,10 @@ def inject_functional_model_class(cls): cls.__bases__ = tuple(inject_functional_model_class(base) for base in cls.__bases__) + # Trigger any `__new__` class swapping that needed to happen on `Functional` + # but did not because functional was not in the class hierarchy. + cls.__new__(cls) + return cls diff --git a/tensorflow/python/keras/utils/version_utils_test.py b/tensorflow/python/keras/utils/version_utils_test.py index 0a3cd53f3c0..41370e316af 100644 --- a/tensorflow/python/keras/utils/version_utils_test.py +++ b/tensorflow/python/keras/utils/version_utils_test.py @@ -56,6 +56,35 @@ class SplitUtilsTest(keras_parameterized.TestCase): self._check_model_class(model.__class__.__bases__[0]) self._check_layer_class(model) + def test_subclass_model_with_functional_init(self): + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + + class MyModel(keras.Model): + pass + + model = MyModel(inputs, outputs) + model_class = model.__class__.__bases__[0].__bases__[0] + self._check_model_class(model_class) + self._check_layer_class(model) + + def test_subclass_model_with_functional_init_interleaved_v1_functional(self): + with ops.Graph().as_default(): + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + _ = keras.Model(inputs, outputs) + + inputs = keras.Input(10) + outputs = keras.layers.Dense(1)(inputs) + + class MyModel(keras.Model): + pass + + model = MyModel(inputs, outputs) + model_class = model.__class__.__bases__[0].__bases__[0] + self._check_model_class(model_class) + self._check_layer_class(model) + def test_sequential_model(self): model = keras.Sequential([keras.layers.Dense(1)]) model_class = model.__class__.__bases__[0].__bases__[0]