Remove irrelevant warnings.
- tf.train is not longer part of the API. - `Sequence`-based generator training is a LTS API and should not print warning when used the way it's intended to be used. PiperOrigin-RevId: 351188267 Change-Id: Ibf8e47f4cc475351df732f91a8cb2105fc42ea13
This commit is contained in:
parent
a17e1fc7ef
commit
4544864a34
tensorflow/python/keras
@ -2160,16 +2160,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
session = None
|
||||
else:
|
||||
session = backend.get_session()
|
||||
optimizer = getattr(self, 'optimizer', None)
|
||||
if (optimizer
|
||||
and not isinstance(optimizer, trackable.Trackable)):
|
||||
logging.warning(
|
||||
('This model was compiled with a Keras optimizer (%s) but is being '
|
||||
'saved in TensorFlow format with `save_weights`. The model\'s '
|
||||
'weights will be saved, but unlike with TensorFlow optimizers in '
|
||||
'the TensorFlow format the optimizer\'s state will not be '
|
||||
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
|
||||
% (optimizer,))
|
||||
self._trackable_saver.save(filepath, session=session, options=options)
|
||||
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
|
||||
checkpoint_management.update_checkpoint_state_internal(
|
||||
|
@ -40,7 +40,6 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import training as training_module
|
||||
from tensorflow.python.training.tracking import util as trackable
|
||||
@ -402,21 +401,6 @@ class SubclassedModel(training.Model):
|
||||
|
||||
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_keras_optimizer_warning(self):
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), self.session(graph):
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(2, input_shape=(3,)))
|
||||
model.add(keras.layers.Dense(3))
|
||||
model.compile(loss='mse', optimizer=optimizer_v1.Adam(), metrics=['acc'])
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
model._make_train_function()
|
||||
temp_dir = self.get_temp_dir()
|
||||
prefix = os.path.join(temp_dir, 'ckpt')
|
||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
||||
model.save_weights(prefix)
|
||||
self.assertRegex(str(mock_log.call_args), 'Keras optimizer')
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_tensorflow_format_overwrite(self):
|
||||
with self.cached_session() as session:
|
||||
|
@ -46,7 +46,6 @@ from six.moves.urllib.request import urlopen
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -530,10 +529,6 @@ def get_pool_class(use_multiprocessing):
|
||||
global _FORCE_THREADPOOL
|
||||
if not use_multiprocessing or _FORCE_THREADPOOL:
|
||||
return multiprocessing.dummy.Pool # ThreadPool
|
||||
logging.warning(
|
||||
'multiprocessing can interact badly with TensorFlow, causing '
|
||||
'nondeterministic deadlocks. For high performance data pipelines tf.data '
|
||||
'is recommended.')
|
||||
return multiprocessing.Pool
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user