Remove irrelevant warnings.

- tf.train is no longer part of the API.
- `Sequence`-based generator training is an LTS API and should not print a warning when used the way it's intended to be used.

PiperOrigin-RevId: 351188267
Change-Id: Ibf8e47f4cc475351df732f91a8cb2105fc42ea13
This commit is contained in:
Francois Chollet 2021-01-11 10:47:29 -08:00 committed by TensorFlower Gardener
parent a17e1fc7ef
commit 4544864a34
3 changed files with 0 additions and 31 deletions
tensorflow/python/keras

View File

@@ -2160,16 +2160,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
session = None
else:
session = backend.get_session()
optimizer = getattr(self, 'optimizer', None)
if (optimizer
and not isinstance(optimizer, trackable.Trackable)):
logging.warning(
('This model was compiled with a Keras optimizer (%s) but is being '
'saved in TensorFlow format with `save_weights`. The model\'s '
'weights will be saved, but unlike with TensorFlow optimizers in '
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
self._trackable_saver.save(filepath, session=session, options=options)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(

View File

@@ -40,7 +40,6 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
from tensorflow.python.training.tracking import util as trackable
@ -402,21 +401,6 @@ class SubclassedModel(training.Model):
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
def test_keras_optimizer_warning(self):
  """Saving weights with a Keras (v1) optimizer should emit a warning."""
  g = ops.Graph()
  with g.as_default(), self.session(g):
    # Minimal two-layer model compiled with a legacy Keras optimizer,
    # which is not `Trackable` and therefore cannot be checkpointed.
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.Dense(3))
    model.compile(loss='mse', optimizer=optimizer_v1.Adam(), metrics=['acc'])
    # In graph mode the train function must exist before weights are saved.
    if not ops.executing_eagerly_outside_functions():
      model._make_train_function()
    checkpoint_prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    # Capture logging.warning and verify the optimizer warning fired.
    with test.mock.patch.object(logging, 'warning') as warn_mock:
      model.save_weights(checkpoint_prefix)
      self.assertRegex(str(warn_mock.call_args), 'Keras optimizer')
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_tensorflow_format_overwrite(self):
with self.cached_session() as session:

View File

@@ -46,7 +46,6 @@ from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
def get_pool_class(use_multiprocessing):
  """Return the pool class to use for parallel data processing.

  Args:
    use_multiprocessing: If True, request a process-based pool; otherwise a
      thread-based pool is returned.

  Returns:
    `multiprocessing.Pool` when `use_multiprocessing` is True and the
    module-level `_FORCE_THREADPOOL` override is unset; otherwise
    `multiprocessing.dummy.Pool` (a ThreadPool).
  """
  global _FORCE_THREADPOOL
  if use_multiprocessing and not _FORCE_THREADPOOL:
    # Process pools can interact badly with TensorFlow (nondeterministic
    # deadlocks); the caller explicitly opted in, so honor the request.
    logging.warning(
        'multiprocessing can interact badly with TensorFlow, causing '
        'nondeterministic deadlocks. For high performance data pipelines tf.data '
        'is recommended.')
    return multiprocessing.Pool
  return multiprocessing.dummy.Pool  # ThreadPool