Allow autograph to be applied in internal helper utility. That avoids downstream calls from being confused about whether autograph is enabled in their context.

This adds a small overhead to the building of model metrics (~200ms). This overhead should not be noticeable outside of tests which create very large numbers of models.

PiperOrigin-RevId: 341642593
Change-Id: I7d1e13d70d5df072b5215c69f9480f18480b92b5
This commit is contained in:
Dan Moldovan 2020-11-10 10:04:38 -08:00 committed by TensorFlower Gardener
parent 7d2981e88c
commit ad95899595
2 changed files with 72 additions and 2 deletions

View File

@ -122,8 +122,7 @@ def trace_model_call(model, input_signature=None):
if input_signature is None:
raise_model_input_error(model)
# TODO(mdan): Should the model's call be autographed by default?
@def_function.function(input_signature=input_signature, autograph=False)
@def_function.function(input_signature=input_signature)
def _wrapped_model(*args):
"""A concrete tf.function that wraps the model's call function."""
# When given a single input, Keras models will call the model on the tensor

View File

@ -45,6 +45,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as load_lib
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import signature_constants
@ -268,10 +269,80 @@ def _import_and_infer(save_dir, inputs):
return session.run(output_dict, feed_dict=feed_dict)
class AutographedMetric(keras.metrics.Metric):
def build(self, input_shape):
pass
def update_state(self, values):
if constant_op.constant(False):
x = 1
else:
x = 2
return x
def reset_states(self):
pass
def result(self):
return constant_op.constant(0)
def GetMean(self):
return constant_op.constant(0)
def GetCount(self):
return constant_op.constant(0)
class BasicAutographedMetricLayer(keras.layers.Layer):
def build(self, input_shape):
self._metric = AutographedMetric()
def call(self, inp):
self._metric.update_state(inp)
# TODO(b/172853147): Test control flow here.
return inp
class BasicAutographedMetricModel(keras.models.Model):
def __init__(self):
super(BasicAutographedMetricModel, self).__init__(name='test_model')
self._layer = BasicAutographedMetricLayer()
def call(self, inputs, **kwargs):
return self._layer(inputs)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class ModelSaveTest(keras_parameterized.TestCase):
def test_model_save_preserves_autograph(self):
model = BasicAutographedMetricModel()
inputs = array_ops.ones((8, 5))
model._set_inputs(inputs)
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save_lib.save(model, save_dir)
if model.output_names:
output_name = model.output_names[0]
input_name = model.input_names[0]
else:
output_name = 'output_1'
input_name = 'input_1'
self.assertAllClose({output_name: model.predict_on_batch(inputs)},
_import_and_infer(save_dir,
{input_name: np.ones((8, 5))}))
# Test v2 loading.
# TODO(mdan): tests using _import_and_infer should uniformly do this.
self.assertAllClose(model.predict_on_batch(inputs),
load_lib.load(save_dir)(inputs))
def test_model_save(self):
input_dim = 5
model = testing_utils.get_small_mlp(10, 3, input_dim)