Make the default for run_eagerly
the value of
`tf.config.experimental_run_functions_eagerly` PiperOrigin-RevId: 251570860
This commit is contained in:
parent
a169923759
commit
0054b6aaa2
@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -403,7 +404,9 @@ class Model(network.Network):
|
||||
'is enabled.')
|
||||
if not self.dynamic:
|
||||
if self._run_eagerly is None:
|
||||
return False
|
||||
# Respect `tf.config.experimental_run_functions_eagerly` unless
|
||||
# `run_eagerly` was explicitly passed to `compile`.
|
||||
return def_function.RUN_FUNCTIONS_EAGERLY
|
||||
else:
|
||||
return self._run_eagerly
|
||||
else:
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -81,6 +82,16 @@ class CompileTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(model.loss_functions[i].fn, loss_list[i])
|
||||
self.assertAllEqual(model._loss_weights_list, [1.] * len(loss_list))
|
||||
|
||||
def test_respect_run_functions_eagerly(self):
|
||||
with context.eager_mode():
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=2, input_dim=3)
|
||||
model.compile('sgd', 'mse')
|
||||
def_function.run_functions_eagerly(True)
|
||||
self.assertTrue(model.run_eagerly)
|
||||
def_function.run_functions_eagerly(False)
|
||||
self.assertFalse(model.run_eagerly)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@parameterized.named_parameters(('loss_string', 'mse'),
|
||||
('loss_function', losses.mean_squared_error),
|
||||
|
Loading…
Reference in New Issue
Block a user