Make the default for run_eagerly
the value of
`tf.config.experimental_run_functions_eagerly` PiperOrigin-RevId: 251570860
This commit is contained in:
parent
a169923759
commit
0054b6aaa2
@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import monitoring
|
from tensorflow.python.eager import monitoring
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
from tensorflow.python.framework import composite_tensor_utils
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -403,7 +404,9 @@ class Model(network.Network):
|
|||||||
'is enabled.')
|
'is enabled.')
|
||||||
if not self.dynamic:
|
if not self.dynamic:
|
||||||
if self._run_eagerly is None:
|
if self._run_eagerly is None:
|
||||||
return False
|
# Respect `tf.config.experimental_run_functions_eagerly` unless
|
||||||
|
# `run_eagerly` was explicitly passed to `compile`.
|
||||||
|
return def_function.RUN_FUNCTIONS_EAGERLY
|
||||||
else:
|
else:
|
||||||
return self._run_eagerly
|
return self._run_eagerly
|
||||||
else:
|
else:
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python import keras
|
|||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -81,6 +82,16 @@ class CompileTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(model.loss_functions[i].fn, loss_list[i])
|
self.assertEqual(model.loss_functions[i].fn, loss_list[i])
|
||||||
self.assertAllEqual(model._loss_weights_list, [1.] * len(loss_list))
|
self.assertAllEqual(model._loss_weights_list, [1.] * len(loss_list))
|
||||||
|
|
||||||
|
def test_respect_run_functions_eagerly(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
model = testing_utils.get_small_sequential_mlp(
|
||||||
|
num_hidden=10, num_classes=2, input_dim=3)
|
||||||
|
model.compile('sgd', 'mse')
|
||||||
|
def_function.run_functions_eagerly(True)
|
||||||
|
self.assertTrue(model.run_eagerly)
|
||||||
|
def_function.run_functions_eagerly(False)
|
||||||
|
self.assertFalse(model.run_eagerly)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
@parameterized.named_parameters(('loss_string', 'mse'),
|
@parameterized.named_parameters(('loss_string', 'mse'),
|
||||||
('loss_function', losses.mean_squared_error),
|
('loss_function', losses.mean_squared_error),
|
||||||
|
Loading…
Reference in New Issue
Block a user