Make sure model saving does not log deprecation warnings.

PiperOrigin-RevId: 316191312
Change-Id: I4545782287265bc238e9dd759ebd5571de415b18
Author:    Tomer Kaftan, 2020-06-12 15:42:46 -07:00
Committed: TensorFlower Gardener
Parent:    b1a61df7a0
Commit:    69cbcb8c20
5 changed files with 84 additions and 8 deletions


@@ -420,6 +420,30 @@ def set_learning_phase(value):
       value: Learning phase value, either 0 or 1 (integers).
              0 = test, 1 = train
 
   Raises:
       ValueError: if `value` is neither `0` nor `1`.
   """
+  deprecated_internal_set_learning_phase(value)
+
+
+def deprecated_internal_set_learning_phase(value):
+  """A deprecated internal implementation of `set_learning_phase`.
+
+  This method is an internal-only version of `set_learning_phase` that
+  does not raise a deprecation warning. It is required because saved_model
+  needs to keep working with user code that uses the deprecated learning
+  phase methods until those APIs are fully removed from the public API.
+
+  Specifically, SavedModel saving needs to make sure the learning phase is 0
+  during tracing even if users overrode it to a different value.
+
+  But we don't want to raise deprecation warnings for users when SavedModel
+  sets the learning phase just for compatibility with code that relied on
+  explicitly setting the learning phase to other values.
+
+  Arguments:
+      value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
+
+  Raises:
+      ValueError: if `value` is neither `0` nor `1`.
+  """
@@ -435,6 +459,9 @@ def set_learning_phase(value):
   _GRAPH_LEARNING_PHASES[get_graph()] = value
 
 
+@deprecated('2020-10-11',
+            'Simply pass a True/False value to the `training` argument '
+            'of the `__call__` method of your layer or model.')
 @keras_export('keras.backend.learning_phase_scope')
 @tf_contextlib.contextmanager
 def learning_phase_scope(value):
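
For context, here is a hand-rolled sketch of what a `@deprecated(date, instructions)` decorator like the one added above does: wrap the function and log a message (once) the first time it is called. The real decorator lives in `tensorflow.python.util.deprecation`; this simplified stand-in only assumes `absl` is installed:

import functools
from absl import logging


def deprecated(date, instructions):
  """Simplified stand-in: logs a warning the first time the function runs."""
  def decorator(func):
    warned = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      nonlocal warned
      if not warned:
        warned = True
        logging.warning('%s is deprecated and will be removed after %s: %s',
                        func.__name__, date, instructions)
      return func(*args, **kwargs)
    return wrapper
  return decorator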
@@ -449,6 +476,35 @@ def learning_phase_scope(value):
   Yields:
     None.
 
   Raises:
       ValueError: if `value` is neither `0` nor `1`.
   """
+  with deprecated_internal_learning_phase_scope(value):
+    try:
+      yield
+    finally:
+      pass
+
+
+@tf_contextlib.contextmanager
+def deprecated_internal_learning_phase_scope(value):
+  """An internal-only version of `learning_phase_scope`.
+
+  Unlike the public method, this method does not raise a deprecation warning.
+  This is needed because SavedModel saving needs to set the learning phase to
+  maintain compatibility with code that sets/gets the learning phase, but
+  saving itself shouldn't raise a deprecation warning.
+
+  We can get rid of this method and its usages when the public API is
+  removed.
+
+  Arguments:
+      value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
+
+  Yields:
+    None.
+
+  Raises:
+      ValueError: if `value` is neither `0` nor `1`.
+  """
@@ -464,7 +520,7 @@ def learning_phase_scope(value):
   learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
   try:
-    set_learning_phase(value)
+    deprecated_internal_set_learning_phase(value)
     yield
   finally:
     # Restore learning phase to initial value.


@@ -68,9 +68,11 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
   orig_optimizer = model.optimizer
   model.optimizer = None
 
-  # Trace all functions and signatures with `training=0` instead of using the
-  # default learning phase placeholder.
-  with K.learning_phase_scope(0):
+  # Trace all functions and signatures with `training=0` instead of using an
+  # already-set learning phase placeholder.
+  # This is needed for compatibility reasons until learning phase setting
+  # is removed from the public APIs.
+  with K.deprecated_internal_learning_phase_scope(0):
     # When saving a model involving a batch norm layer within a strategy
     # scope, the replica context is not available when calling `add_update()`,
     # and thus we use the default replica context here.
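
Why tracing must pin the learning phase to 0: layers such as dropout behave differently in train and test mode, and whatever phase is active while SavedModel traces the forward pass becomes the exported default. A hedged sketch; the path and shapes are illustrative:

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dropout(0.5, input_shape=(4,))])

# Saving traces the model's call function. The internal learning phase scope
# guarantees the traced default is inference behavior (training=0), even if
# user code previously called the deprecated set_learning_phase(1).
model.save('/tmp/dropout_model')  # SavedModel format in TF2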


@@ -414,7 +414,7 @@ class LayerCallCollection(object):
     if self._expects_training_arg:
       def trace_with_training(value, fn=fn):
         utils.set_training_arg(value, self._training_arg_index, args, kwargs)
-        with K.learning_phase_scope(value):
+        with K.deprecated_internal_learning_phase_scope(value):
           fn.get_concrete_function(*args, **kwargs)
 
       trace_with_training(True)
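
What `trace_with_training` does, reduced to plain `tf.function`: `get_concrete_function` records one graph per argument signature, so tracing with both training values captures both behaviors. An illustrative sketch:

import tensorflow as tf


@tf.function
def call(x, training=False):
  # Training-dependent behavior, like the layer calls being traced above.
  return tf.nn.dropout(x, rate=0.5) if training else x


spec = tf.TensorSpec([None, 4], tf.float32)
train_fn = call.get_concrete_function(spec, training=True)   # training graph
infer_fn = call.get_concrete_function(spec, training=False)  # inference graph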


@@ -26,6 +26,7 @@ from __future__ import print_function
 
 import os
 import shutil
+import sys
 
 from absl.testing import parameterized
 import numpy as np
@@ -375,9 +376,16 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
         keras.layers.BatchNormalization(input_shape=(1,)))
     self.evaluate(variables.variables_initializer(model.variables))
     saved_model_dir = self._save_model_dir()
-    model.save(saved_model_dir, save_format='tf')
-    loaded = keras_load.load(saved_model_dir)
-    self.evaluate(variables.variables_initializer(loaded.variables))
+    with self.captureWritesToStream(sys.stderr) as captured_logs:
+      model.save(saved_model_dir, save_format='tf')
+      loaded = keras_load.load(saved_model_dir)
+
+    # Assert that saving does not log deprecation warnings
+    # (even if it needs to set the learning phase for compat reasons).
+    if context.executing_eagerly():
+      self.assertNotIn('deprecated', captured_logs.contents())
+
+    self.evaluate(variables.variables_initializer(loaded.variables))
     input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
     input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
     self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0])
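
A note on the capture mechanism: the test listens on `sys.stderr` rather than using the `warnings` module because TF's `@deprecated` decorator emits through absl logging, which writes to stderr. `captureWritesToStream` intercepts at the OS file-descriptor level; a rough stdlib sketch of that idea, with `do_save` as a hypothetical stand-in for the code under test:

import os
import sys
import tempfile


def capture_stderr_fd(fn):
  """Runs fn() with stderr's file descriptor redirected; returns the text."""
  fd = sys.stderr.fileno()
  saved_fd = os.dup(fd)
  with tempfile.TemporaryFile(mode='w+') as tmp:
    os.dup2(tmp.fileno(), fd)
    try:
      fn()
    finally:
      sys.stderr.flush()
      os.dup2(saved_fd, fd)
      os.close(saved_fd)
    tmp.seek(0)
    return tmp.read()


logs = capture_stderr_fd(do_save)  # do_save: hypothetical
assert 'deprecated' not in logs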


@@ -20,6 +20,8 @@ from __future__ import print_function
 import functools
 import weakref
 
+from absl import logging
+
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
@@ -101,12 +103,20 @@ class AutoTrackable(base.Trackable):
     """Return a dict of `Function`s of a trackable."""
     functions = {}
     for attribute_name in dir(self):
+      # We get the attributes, suppressing warnings and exceptions.
+      logging_verbosity = logging.get_verbosity()
       try:
+        logging.set_verbosity(logging.FATAL)
         attribute_value = getattr(self, attribute_name, None)
       except Exception:  # pylint: disable=broad-except
         # We really don't want to throw an exception just because some object's
         # attribute accessor is broken.
         attribute_value = None
+      finally:
+        # We reset the verbosity setting in a `finally` block, to make sure
+        # it always happens, even if we later make the exception catching
+        # above less broad.
+        logging.set_verbosity(logging_verbosity)
       if isinstance(attribute_value, (def_function.Function,
                                       defun.ConcreteFunction)):
         functions[attribute_name] = attribute_value
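
The suppress-and-restore dance above, factored into a reusable context manager for clarity (a sketch; the commit inlines it because `dir()`/`getattr` on a Keras object can itself trigger the deprecation log):

import contextlib
from absl import logging


@contextlib.contextmanager
def absl_logs_silenced():
  """Temporarily raises absl's verbosity threshold to FATAL, then restores."""
  verbosity = logging.get_verbosity()
  logging.set_verbosity(logging.FATAL)
  try:
    yield
  finally:
    logging.set_verbosity(verbosity)


# Usage, mirroring the attribute loop above:
# with absl_logs_silenced():
#   attribute_value = getattr(obj, attribute_name, None)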