Remove error raising in keras compile for optimizers.

PiperOrigin-RevId: 234083146
This commit is contained in:
Zhenyu Tan 2019-02-14 21:37:24 -08:00 committed by TensorFlower Gardener
parent 4938573f6b
commit e1ab41387a

View File

@ -41,7 +41,6 @@ from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import losses_utils
@ -49,7 +48,6 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -233,12 +231,6 @@ class Model(Network):
# Validate that arguments passed by the user to `compile` are supported by
# DistributionStrategy.
if self._distribution_strategy:
if not isinstance(optimizer,
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
optimizer_v2.OptimizerV2)):
raise NotImplementedError(
'optimizer must be an instance of '
'tf.train.Optimizer, not a %s' % type(optimizer))
if sample_weight_mode:
raise NotImplementedError('sample_weight_mode is not supported with '
'DistributionStrategy.')
@ -250,13 +242,6 @@ class Model(Network):
'DistributionStrategy.')
loss = loss or {}
if self.run_eagerly and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
optimizer_v2.OptimizerV2)):
raise ValueError(
'When running a model in eager execution, the optimizer must be an '
'instance of tf.train.Optimizer. Received: '
'%s' % optimizer)
self.optimizer = optimizer
# We've disabled automatic dependency tracking for this method, but do want