Merge pull request #31490 from k-w-w/cherrypicks_HH6F5

Set default training value to `False` when exporting to SavedModel
This commit is contained in:
Goldie Gadde 2019-08-12 10:46:06 -07:00 committed by GitHub
commit 78b028c2d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 1 deletions

View File

@@ -47,6 +47,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save
from tensorflow.python.util import tf_inspect
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
@@ -384,6 +385,56 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(model.predict(input_arr), outputs['predictions'])
self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])
def testTrainingDefaults(self):
  """Checks `training` arg defaults on signatures of a loaded SavedModel.

  After a round trip through `model.save(..., save_format='tf')` and
  `tf_load.load`, the exported callables should default `training` to
  False when the layer left it unset, preserve an explicit layer-level
  default (e.g. True), and keep it required (no default) when the layer
  declared it required.
  """

  def _check_training_default(fn, expected):
    # Positional defaults align with the *tail* of the args list, so the
    # default bound to 'training' sits at this negative offset.
    spec = tf_inspect.getfullargspec(fn)
    offset = spec.args.index('training') - len(spec.args)
    self.assertEqual(spec.defaults[offset], expected)

  class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):
    # `training` has no default: callers must always pass it explicitly.

    def call(self, inputs, training):
      return tf_utils.smart_cond(
          training, lambda: inputs * 0, lambda: array_ops.identity(inputs))

  class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):
    # `training` defaults to True; the exporter should preserve that.

    def call(self, inputs, training=True):
      return tf_utils.smart_cond(
          training, lambda: inputs * 0, lambda: array_ops.identity(inputs))

  class Model(keras.models.Model):

    def __init__(self):
      super(Model, self).__init__()
      self.layer_with_training_default_none = LayerWithLearningPhase()
      self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
      self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()

    def call(self, inputs):
      out = self.layer_with_training_default_none(inputs)
      out += self.layer_with_training_default_true(inputs)
      out += self.layer_with_required_training_arg(inputs, False)
      return out

  model = Model()
  # Run one prediction so the model is built with concrete input shapes.
  model.predict(np.ones([1, 3]).astype('float32'))
  saved_model_dir = self._save_model_dir()
  model.save(saved_model_dir, save_format='tf')

  loaded = tf_load.load(saved_model_dir)
  _check_training_default(loaded.__call__, False)
  _check_training_default(
      loaded.layer_with_training_default_none.__call__, False)
  _check_training_default(
      loaded.layer_with_training_default_true.__call__, True)

  # The layer with a required `training` arg must stay default-free.
  self.assertFalse(
      tf_inspect.getfullargspec(
          loaded.layer_with_required_training_arg.__call__).defaults)
class TestLayerCallTracing(test.TestCase):

View File

@@ -113,18 +113,27 @@ def maybe_add_training_arg(
# Create arg spec for decorated function. If 'training' is not defined in the
# args of the original arg spec, then add it to kwonlyargs.
arg_spec = tf_inspect.getfullargspec(original_call)
# Copy positional defaults into a mutable list so an entry can be edited in
# place; getfullargspec returns None when the callable has no defaults.
defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []
kwonlyargs = arg_spec.kwonlyargs
kwonlydefaults = arg_spec.kwonlydefaults or {}
# Add training arg if it does not exist, or set the default training value.
if 'training' not in arg_spec.args:
# 'training' is absent from the positional args: surface it as a
# keyword-only arg with the caller-chosen default.
kwonlyargs.append('training')
kwonlydefaults['training'] = default_training_value
else:
# 'training' is positional. Positional defaults align with the *tail* of
# the args list, so compute its offset from the end of args.
index = arg_spec.args.index('training')
training_default_index = len(arg_spec.args) - index
# Only override when 'training' actually has a default and that default
# is None; explicit defaults (e.g. True/False) are left untouched.
if (arg_spec.defaults and
len(arg_spec.defaults) >= training_default_index and
defaults[-training_default_index] is None):
defaults[-training_default_index] = default_training_value
# NOTE(review): the two consecutive 'defaults=' lines below are the old/new
# sides of a rendered diff (first removed, second added); in the actual
# source only 'defaults=defaults,' remains — confirm against the repository.
decorator_argspec = tf_inspect.FullArgSpec(
args=arg_spec.args,
varargs=arg_spec.varargs,
varkw=arg_spec.varkw,
defaults=arg_spec.defaults,
defaults=defaults,
kwonlyargs=kwonlyargs,
kwonlydefaults=kwonlydefaults,
annotations=arg_spec.annotations)