Merge pull request #31490 from k-w-w/cherrypicks_HH6F5
Set default training value to `False` when exporting to SavedModel
This commit is contained in:
commit
78b028c2d5
@ -47,6 +47,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
|
||||
@ -384,6 +385,56 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.assertAllClose(model.predict(input_arr), outputs['predictions'])
|
||||
self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])
|
||||
|
||||
def testTrainingDefaults(self):
|
||||
def assert_training_default(fn, default_value):
|
||||
arg_spec = tf_inspect.getfullargspec(fn)
|
||||
index = len(arg_spec.args) - arg_spec.args.index('training')
|
||||
self.assertEqual(arg_spec.defaults[-index], default_value)
|
||||
|
||||
class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):
|
||||
|
||||
def call(self, inputs, training):
|
||||
return tf_utils.smart_cond(
|
||||
training, lambda: inputs * 0, lambda: array_ops.identity(inputs))
|
||||
|
||||
class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
return tf_utils.smart_cond(
|
||||
training, lambda: inputs * 0, lambda: array_ops.identity(inputs))
|
||||
|
||||
class Model(keras.models.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.layer_with_training_default_none = LayerWithLearningPhase()
|
||||
self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
|
||||
self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.layer_with_training_default_none(inputs)
|
||||
x += self.layer_with_training_default_true(inputs)
|
||||
x += self.layer_with_required_training_arg(inputs, False)
|
||||
return x
|
||||
|
||||
model = Model()
|
||||
# Build and set model inputs
|
||||
model.predict(np.ones([1, 3]).astype('float32'))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
load = tf_load.load(saved_model_dir)
|
||||
|
||||
assert_training_default(load.__call__, False)
|
||||
assert_training_default(
|
||||
load.layer_with_training_default_none.__call__, False)
|
||||
assert_training_default(
|
||||
load.layer_with_training_default_true.__call__, True)
|
||||
|
||||
# Assert that there are no defaults for layer with required training arg
|
||||
arg_spec = tf_inspect.getfullargspec(
|
||||
load.layer_with_required_training_arg.__call__)
|
||||
self.assertFalse(arg_spec.defaults) # defaults is None or empty
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase):
|
||||
|
||||
|
||||
@ -113,18 +113,27 @@ def maybe_add_training_arg(
|
||||
# Create arg spec for decorated function. If 'training' is not defined in the
|
||||
# args of the original arg spec, then add it to kwonlyargs.
|
||||
arg_spec = tf_inspect.getfullargspec(original_call)
|
||||
defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []
|
||||
|
||||
kwonlyargs = arg_spec.kwonlyargs
|
||||
kwonlydefaults = arg_spec.kwonlydefaults or {}
|
||||
# Add training arg if it does not exist, or set the default training value.
|
||||
if 'training' not in arg_spec.args:
|
||||
kwonlyargs.append('training')
|
||||
kwonlydefaults['training'] = default_training_value
|
||||
else:
|
||||
index = arg_spec.args.index('training')
|
||||
training_default_index = len(arg_spec.args) - index
|
||||
if (arg_spec.defaults and
|
||||
len(arg_spec.defaults) >= training_default_index and
|
||||
defaults[-training_default_index] is None):
|
||||
defaults[-training_default_index] = default_training_value
|
||||
|
||||
decorator_argspec = tf_inspect.FullArgSpec(
|
||||
args=arg_spec.args,
|
||||
varargs=arg_spec.varargs,
|
||||
varkw=arg_spec.varkw,
|
||||
defaults=arg_spec.defaults,
|
||||
defaults=defaults,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwonlydefaults=kwonlydefaults,
|
||||
annotations=arg_spec.annotations)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user