Make sure model saving does not log deprecation warnings.
PiperOrigin-RevId: 316191312 Change-Id: I4545782287265bc238e9dd759ebd5571de415b18
This commit is contained in:
parent
b1a61df7a0
commit
69cbcb8c20
@ -420,6 +420,30 @@ def set_learning_phase(value):
|
||||
value: Learning phase value, either 0 or 1 (integers).
|
||||
0 = test, 1 = train
|
||||
|
||||
Raises:
|
||||
ValueError: if `value` is neither `0` nor `1`.
|
||||
"""
|
||||
deprecated_internal_set_learning_phase(value)
|
||||
|
||||
|
||||
def deprecated_internal_set_learning_phase(value):
|
||||
"""A deprecated internal implementation of set_learning_phase.
|
||||
|
||||
This method is an internal-only version of `set_learning_phase` that
|
||||
does not raise a deprecation error. It is required because
|
||||
saved_model needs to keep working with user code that uses the deprecated
|
||||
learning phase methods until those apis are fully removed from the public api.
|
||||
|
||||
Specifically SavedModel saving needs to make sure the learning phase is 0
|
||||
during tracing even if users overwrote it to a different value.
|
||||
|
||||
But, we don't want to raise deprecation warnings for users when savedmodel
|
||||
sets learning phase just for compatibility with code that relied on
|
||||
explicitly setting the learning phase for other values.
|
||||
|
||||
Arguments:
|
||||
value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
|
||||
|
||||
Raises:
|
||||
ValueError: if `value` is neither `0` nor `1`.
|
||||
"""
|
||||
@ -435,6 +459,9 @@ def set_learning_phase(value):
|
||||
_GRAPH_LEARNING_PHASES[get_graph()] = value
|
||||
|
||||
|
||||
@deprecated('2020-10-11',
|
||||
'Simply pass a True/False value to the `training` argument '
|
||||
'of the `__call__` method of your layer or model.')
|
||||
@keras_export('keras.backend.learning_phase_scope')
|
||||
@tf_contextlib.contextmanager
|
||||
def learning_phase_scope(value):
|
||||
@ -449,6 +476,35 @@ def learning_phase_scope(value):
|
||||
Yields:
|
||||
None.
|
||||
|
||||
Raises:
|
||||
ValueError: if `value` is neither `0` nor `1`.
|
||||
"""
|
||||
with deprecated_internal_learning_phase_scope(value):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def deprecated_internal_learning_phase_scope(value):
|
||||
"""An internal-only version of `learning_phase_scope`.
|
||||
|
||||
Unlike the public method, this method does not raise a deprecation warning.
|
||||
This is needed because saved model saving needs to set learning phase
|
||||
to maintain compatibility
|
||||
with code that sets/gets the learning phase, but saved model
|
||||
saving itself shouldn't raise a deprecation warning.
|
||||
|
||||
We can get rid of this method and its usages when the public api is
|
||||
removed.
|
||||
|
||||
Arguments:
|
||||
value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
|
||||
|
||||
Yields:
|
||||
None.
|
||||
|
||||
Raises:
|
||||
ValueError: if `value` is neither `0` nor `1`.
|
||||
"""
|
||||
@ -464,7 +520,7 @@ def learning_phase_scope(value):
|
||||
|
||||
learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
|
||||
try:
|
||||
set_learning_phase(value)
|
||||
deprecated_internal_set_learning_phase(value)
|
||||
yield
|
||||
finally:
|
||||
# Restore learning phase to initial value.
|
||||
|
||||
@ -68,9 +68,11 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
orig_optimizer = model.optimizer
|
||||
model.optimizer = None
|
||||
|
||||
# Trace all functions and signatures with `training=0` instead of using the
|
||||
# default learning phase placeholder.
|
||||
with K.learning_phase_scope(0):
|
||||
# Trace all functions and signatures with `training=0` instead of using an
|
||||
# already-set learning phase placeholder.
|
||||
# This is needed for compatibility reasons until learning phase setting
|
||||
# is removed from the public apis.
|
||||
with K.deprecated_internal_learning_phase_scope(0):
|
||||
# When saving a model involving batch norm layer within a strategy scope,
|
||||
# the replica context is not available when calling `add_update()`, and thus
|
||||
# we use the default replica context here.
|
||||
|
||||
@ -414,7 +414,7 @@ class LayerCallCollection(object):
|
||||
if self._expects_training_arg:
|
||||
def trace_with_training(value, fn=fn):
|
||||
utils.set_training_arg(value, self._training_arg_index, args, kwargs)
|
||||
with K.learning_phase_scope(value):
|
||||
with K.deprecated_internal_learning_phase_scope(value):
|
||||
fn.get_concrete_function(*args, **kwargs)
|
||||
|
||||
trace_with_training(True)
|
||||
|
||||
@ -26,6 +26,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -375,9 +376,16 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
keras.layers.BatchNormalization(input_shape=(1,)))
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
|
||||
with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
|
||||
# Assert that saving does not log deprecation warnings
|
||||
# (even if it needs to set learning phase for compat reasons)
|
||||
if context.executing_eagerly():
|
||||
self.assertNotIn('deprecated', captured_logs.contents())
|
||||
|
||||
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
|
||||
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
|
||||
self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0])
|
||||
|
||||
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
import functools
|
||||
import weakref
|
||||
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
@ -101,12 +103,20 @@ class AutoTrackable(base.Trackable):
|
||||
"""Return a dict of `Function`s of a trackable."""
|
||||
functions = {}
|
||||
for attribute_name in dir(self):
|
||||
# We get the attributes, suppressing warnings and exceptions.
|
||||
logging_verbosity = logging.get_verbosity()
|
||||
try:
|
||||
logging.set_verbosity(logging.FATAL)
|
||||
attribute_value = getattr(self, attribute_name, None)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# We really don't want to throw an exception just because some object's
|
||||
# attribute accessor is broken.
|
||||
attribute_value = None
|
||||
finally:
|
||||
# We reset the verbosity setting in a `finally` block, to make
|
||||
# sure it always happens, even if we make the exception catching above
|
||||
# be less broad.
|
||||
logging.set_verbosity(logging_verbosity)
|
||||
if isinstance(attribute_value, (def_function.Function,
|
||||
defun.ConcreteFunction)):
|
||||
functions[attribute_name] = attribute_value
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user