diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index dd55b32a115..02df39ff418 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring from tensorflow.python.framework import composite_tensor_utils from tensorflow.python.framework import constant_op @@ -403,7 +404,9 @@ class Model(network.Network): 'is enabled.') if not self.dynamic: if self._run_eagerly is None: - return False + # Respect `tf.config.experimental_run_functions_eagerly` unless + # `run_eagerly` was explicitly passed to `compile`. + return def_function.RUN_FUNCTIONS_EAGERLY else: return self._run_eagerly else: diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 2d0bcda0336..9410b8df0f8 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -30,6 +30,7 @@ from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -81,6 +82,16 @@ class CompileTest(keras_parameterized.TestCase): self.assertEqual(model.loss_functions[i].fn, loss_list[i]) self.assertAllEqual(model._loss_weights_list, [1.] * len(loss_list)) + def test_respect_run_functions_eagerly(self): + with context.eager_mode(): + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=2, input_dim=3) + model.compile('sgd', 'mse') + def_function.run_functions_eagerly(True) + self.assertTrue(model.run_eagerly) + def_function.run_functions_eagerly(False) + self.assertFalse(model.run_eagerly) + @keras_parameterized.run_all_keras_modes @parameterized.named_parameters(('loss_string', 'mse'), ('loss_function', losses.mean_squared_error),