Fix run_eagerly logic in Model.compile
PiperOrigin-RevId: 307629912 Change-Id: I7852a36835ef9857adb905b0e63a8fd64bccc17b
This commit is contained in:
parent
c5fd4efc4c
commit
6f9a5289f9
@ -350,7 +350,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
_keras_api_gauge.get_cell('compile').set(True)
|
||||
with self.distribute_strategy.scope():
|
||||
self._validate_compile(optimizer, metrics, **kwargs)
|
||||
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
||||
self._run_eagerly = run_eagerly
|
||||
|
||||
self.optimizer = self._get_optimizer(optimizer)
|
||||
self.compiled_loss = compile_utils.LossesContainer(
|
||||
|
@ -88,6 +88,13 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
hist = model.fit(x=np.array([0.]), y=np.array([0.]))
|
||||
self.assertAllClose(hist.history['loss'][0], 10000)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_run_eagerly_setting(self):
|
||||
model = sequential.Sequential([layers_module.Dense(1)])
|
||||
run_eagerly = testing_utils.should_run_eagerly()
|
||||
model.compile('sgd', 'mse', run_eagerly=run_eagerly)
|
||||
self.assertEqual(model.run_eagerly, run_eagerly)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_fit_and_validate_learning_phase(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user