diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 10d36bc09da..391c695b18f 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -420,6 +420,30 @@ def set_learning_phase(value): value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train + Raises: + ValueError: if `value` is neither `0` nor `1`. + """ + deprecated_internal_set_learning_phase(value) + + +def deprecated_internal_set_learning_phase(value): + """A deprecated internal implementation of set_learning_phase. + + This method is an internal-only version of `set_learning_phase` that + does not raise a deprecation error. It is required because + saved_model needs to keep working with user code that uses the deprecated + learning phase methods until those apis are fully removed from the public api. + + Specifically SavedModel saving needs to make sure the learning phase is 0 + during tracing even if users overwrote it to a different value. + + But, we don't want to raise deprecation warnings for users when savedmodel + sets learning phase just for compatibility with code that relied on + explicitly setting the learning phase for other values. + + Arguments: + value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train + Raises: ValueError: if `value` is neither `0` nor `1`. """ @@ -435,6 +459,9 @@ def set_learning_phase(value): _GRAPH_LEARNING_PHASES[get_graph()] = value +@deprecated('2020-10-11', + 'Simply pass a True/False value to the `training` argument ' + 'of the `__call__` method of your layer or model.') @keras_export('keras.backend.learning_phase_scope') @tf_contextlib.contextmanager def learning_phase_scope(value): @@ -449,6 +476,35 @@ def learning_phase_scope(value): Yields: None. + Raises: + ValueError: if `value` is neither `0` nor `1`. + """ + with deprecated_internal_learning_phase_scope(value): + try: + yield + finally: + pass + + +@tf_contextlib.contextmanager +def deprecated_internal_learning_phase_scope(value): + """An internal-only version of `learning_phase_scope`. + + Unlike the public method, this method does not raise a deprecation warning. + This is needed because saved model saving needs to set learning phase + to maintain compatibility + with code that sets/gets the learning phase, but saved model + saving itself shouldn't raise a deprecation warning. + + We can get rid of this method and its usages when the public api is + removed. + + Arguments: + value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train + + Yields: + None. + Raises: ValueError: if `value` is neither `0` nor `1`. """ @@ -464,7 +520,7 @@ def learning_phase_scope(value): learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set try: - set_learning_phase(value) + deprecated_internal_set_learning_phase(value) yield finally: # Restore learning phase to initial value. diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py index 9338b4b5434..9d4ca5e2c59 100644 --- a/tensorflow/python/keras/saving/saved_model/save.py +++ b/tensorflow/python/keras/saving/saved_model/save.py @@ -68,9 +68,11 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, orig_optimizer = model.optimizer model.optimizer = None - # Trace all functions and signatures with `training=0` instead of using the - # default learning phase placeholder. - with K.learning_phase_scope(0): + # Trace all functions and signatures with `training=0` instead of using an + # already-set learning phase placeholder. + # This is needed for compatibility reasons until learning phase setting + # is removed from the public apis. + with K.deprecated_internal_learning_phase_scope(0): # When saving a model involving batch norm layer within a strategy scope, # the replica context is not available when calling `add_update()`, and thus # we use the default replica context here. diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py index f2e6c967b14..c2e4f96e127 100644 --- a/tensorflow/python/keras/saving/saved_model/save_impl.py +++ b/tensorflow/python/keras/saving/saved_model/save_impl.py @@ -414,7 +414,7 @@ class LayerCallCollection(object): if self._expects_training_arg: def trace_with_training(value, fn=fn): utils.set_training_arg(value, self._training_arg_index, args, kwargs) - with K.learning_phase_scope(value): + with K.deprecated_internal_learning_phase_scope(value): fn.get_concrete_function(*args, **kwargs) trace_with_training(True) diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index c6cc2f7a1d5..c208805686d 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -26,6 +26,7 @@ from __future__ import print_function import os import shutil +import sys from absl.testing import parameterized import numpy as np @@ -375,9 +376,16 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): keras.layers.BatchNormalization(input_shape=(1,))) self.evaluate(variables.variables_initializer(model.variables)) saved_model_dir = self._save_model_dir() - model.save(saved_model_dir, save_format='tf') - loaded = keras_load.load(saved_model_dir) - self.evaluate(variables.variables_initializer(loaded.variables)) + + with self.captureWritesToStream(sys.stderr) as captured_logs: + model.save(saved_model_dir, save_format='tf') + loaded = keras_load.load(saved_model_dir) + + # Assert that saving does not log deprecation warnings + # (even if it needs to set learning phase for compat reasons) + if context.executing_eagerly(): + self.assertNotIn('deprecated', captured_logs.contents()) + input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0]) diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py index bc55aee1eff..553f0ec73bf 100644 --- a/tensorflow/python/training/tracking/tracking.py +++ b/tensorflow/python/training/tracking/tracking.py @@ -20,6 +20,8 @@ from __future__ import print_function import functools import weakref +from absl import logging + from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun @@ -101,12 +103,20 @@ class AutoTrackable(base.Trackable): """Return a dict of `Function`s of a trackable.""" functions = {} for attribute_name in dir(self): + # We get the attributes, suppressing warnings and exceptions. + logging_verbosity = logging.get_verbosity() try: + logging.set_verbosity(logging.FATAL) attribute_value = getattr(self, attribute_name, None) except Exception: # pylint: disable=broad-except # We really don't want to throw an exception just because some object's # attribute accessor is broken. attribute_value = None + finally: + # We reset the verbosity setting in a `finally` block, to make + # sure it always happens, even if we make the exception catching above + # be less broad. + logging.set_verbosity(logging_verbosity) if isinstance(attribute_value, (def_function.Function, defun.ConcreteFunction)): functions[attribute_name] = attribute_value