Remove error raising in keras compile for optimizers.
PiperOrigin-RevId: 234083146
This commit is contained in:
parent
4938573f6b
commit
e1ab41387a
@ -41,7 +41,6 @@ from tensorflow.python.keras.engine import training_eager
|
||||
from tensorflow.python.keras.engine import training_generator
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.engine.network import Network
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
@ -49,7 +48,6 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import optimizer as tf_optimizer_module
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -233,12 +231,6 @@ class Model(Network):
|
||||
# Validate that arguments passed by the user to `compile` are supported by
|
||||
# DistributionStrategy.
|
||||
if self._distribution_strategy:
|
||||
if not isinstance(optimizer,
|
||||
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
|
||||
optimizer_v2.OptimizerV2)):
|
||||
raise NotImplementedError(
|
||||
'optimizer must be an instance of '
|
||||
'tf.train.Optimizer, not a %s' % type(optimizer))
|
||||
if sample_weight_mode:
|
||||
raise NotImplementedError('sample_weight_mode is not supported with '
|
||||
'DistributionStrategy.')
|
||||
@ -250,13 +242,6 @@ class Model(Network):
|
||||
'DistributionStrategy.')
|
||||
|
||||
loss = loss or {}
|
||||
if self.run_eagerly and not isinstance(
|
||||
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
|
||||
optimizer_v2.OptimizerV2)):
|
||||
raise ValueError(
|
||||
'When running a model in eager execution, the optimizer must be an '
|
||||
'instance of tf.train.Optimizer. Received: '
|
||||
'%s' % optimizer)
|
||||
|
||||
self.optimizer = optimizer
|
||||
# We've disabled automatic dependency tracking for this method, but do want
|
||||
|
Loading…
Reference in New Issue
Block a user