Fix run_eagerly logic in Model.compile

PiperOrigin-RevId: 307629912
Change-Id: I7852a36835ef9857adb905b0e63a8fd64bccc17b
This commit is contained in:
Thomas O'Malley 2020-04-21 10:19:25 -07:00 committed by TensorFlower Gardener
parent c5fd4efc4c
commit 6f9a5289f9
2 changed files with 8 additions and 1 deletions

View File

@ -350,7 +350,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
_keras_api_gauge.get_cell('compile').set(True)
with self.distribute_strategy.scope():
self._validate_compile(optimizer, metrics, **kwargs)
self._run_eagerly = kwargs.pop('run_eagerly', None)
self._run_eagerly = run_eagerly
self.optimizer = self._get_optimizer(optimizer)
self.compiled_loss = compile_utils.LossesContainer(

View File

@ -88,6 +88,13 @@ class TrainingTest(keras_parameterized.TestCase):
hist = model.fit(x=np.array([0.]), y=np.array([0.]))
self.assertAllClose(hist.history['loss'][0], 10000)
@keras_parameterized.run_all_keras_modes
def test_run_eagerly_setting(self):
model = sequential.Sequential([layers_module.Dense(1)])
run_eagerly = testing_utils.should_run_eagerly()
model.compile('sgd', 'mse', run_eagerly=run_eagerly)
self.assertEqual(model.run_eagerly, run_eagerly)
@keras_parameterized.run_all_keras_modes
def test_fit_and_validate_learning_phase(self):