Allow autograph to be applied in the internal helper utility. This prevents downstream calls from being confused about whether autograph is enabled in their context.
This adds a small overhead (~200ms) to the building of model metrics. The overhead should not be noticeable outside of tests that create very large numbers of models.

PiperOrigin-RevId: 341642593
Change-Id: I7d1e13d70d5df072b5215c69f9480f18480b92b5
commit ad95899595
parent 7d2981e88c
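For orientation, the flag this commit flips is `autograph` on `tf.function`. A minimal sketch (not part of the commit; the function name `branchy` is illustrative) of what the flag controls, using only public TF2 API: with autograph on (the default), Python control flow on tensors is rewritten into graph ops during tracing; with it off, tracing the same function fails.

```python
import tensorflow as tf

def branchy(x):
  # Python `if` on a symbolic tensor: only valid under tracing if autograph
  # rewrites it into tf.cond.
  if tf.reduce_sum(x) > 0:
    return x + 1
  return x - 1

with_autograph = tf.function(branchy)                      # autograph=True is the default
without_autograph = tf.function(branchy, autograph=False)  # the behavior being removed below

x = tf.ones((2, 2))
print(with_autograph(x))  # works: the `if` becomes a tf.cond in the traced graph
try:
  without_autograph(x)    # fails: a symbolic tensor cannot be used as a Python bool
except Exception as e:    # typically OperatorNotAllowedInGraphError, depending on TF version
  print(type(e).__name__)
```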
@@ -122,8 +122,7 @@ def trace_model_call(model, input_signature=None):
   if input_signature is None:
     raise_model_input_error(model)
 
-  # TODO(mdan): Should the model's call be autographed by default?
-  @def_function.function(input_signature=input_signature, autograph=False)
+  @def_function.function(input_signature=input_signature)
   def _wrapped_model(*args):
     """A concrete tf.function that wraps the model's call function."""
     # When given a single input, Keras models will call the model on the tensor
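A sketch of the wrapping pattern the hunk above modifies, written against the public `tf.function`/`tf.TensorSpec` API (the `wrapped_model` name, the small Sequential model, and the `[None, 5]` signature are illustrative, not from the commit): the decorator now leaves autograph at its default, so control flow reached through `model(...)`, for example in metric updates, can be converted during tracing.

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model.build((None, 5))

# Analogue of _wrapped_model: autograph is left at its default (True) rather
# than disabled, so downstream code does not have to guess whether it is on.
@tf.function(input_signature=[tf.TensorSpec([None, 5], tf.float32)])
def wrapped_model(inputs):
  return model(inputs, training=False)

concrete = wrapped_model.get_concrete_function()
print(concrete.structured_input_signature)
```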
@@ -45,6 +45,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import load as load_lib
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import save as save_lib
 from tensorflow.python.saved_model import signature_constants
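The next hunk adds tests that compare the model's predictions against a helper named `_import_and_infer` (only its closing line appears as context). For orientation, a hedged sketch of what a load-and-run helper of that kind typically looks like, written with public `tf.compat.v1` API; the body below is an assumption about the general shape, not the file's actual implementation.

```python
import tensorflow as tf

def import_and_infer_sketch(save_dir, inputs):
  """Loads a SavedModel with the TF1 loader and runs its serving signature."""
  with tf.Graph().as_default(), tf.compat.v1.Session() as session:
    # Load the graph and variables tagged for serving into this session.
    meta_graph = tf.compat.v1.saved_model.loader.load(
        session, [tf.saved_model.SERVING], save_dir)
    signature = meta_graph.signature_def[
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    graph = session.graph
    # Map the caller's input dict onto the signature's input tensors.
    feed_dict = {
        graph.get_tensor_by_name(signature.inputs[key].name): value
        for key, value in inputs.items()
    }
    # Fetch every output tensor named by the signature.
    fetches = {
        key: graph.get_tensor_by_name(info.name)
        for key, info in signature.outputs.items()
    }
    return session.run(fetches, feed_dict=feed_dict)
```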
@@ -268,10 +269,80 @@ def _import_and_infer(save_dir, inputs):
     return session.run(output_dict, feed_dict=feed_dict)
 
 
+class AutographedMetric(keras.metrics.Metric):
+
+  def build(self, input_shape):
+    pass
+
+  def update_state(self, values):
+    if constant_op.constant(False):
+      x = 1
+    else:
+      x = 2
+    return x
+
+  def reset_states(self):
+    pass
+
+  def result(self):
+    return constant_op.constant(0)
+
+  def GetMean(self):
+    return constant_op.constant(0)
+
+  def GetCount(self):
+    return constant_op.constant(0)
+
+
+class BasicAutographedMetricLayer(keras.layers.Layer):
+
+  def build(self, input_shape):
+    self._metric = AutographedMetric()
+
+  def call(self, inp):
+    self._metric.update_state(inp)
+    # TODO(b/172853147): Test control flow here.
+    return inp
+
+
+class BasicAutographedMetricModel(keras.models.Model):
+
+  def __init__(self):
+    super(BasicAutographedMetricModel, self).__init__(name='test_model')
+    self._layer = BasicAutographedMetricLayer()
+
+  def call(self, inputs, **kwargs):
+    return self._layer(inputs)
+
+
 @keras_parameterized.run_with_all_model_types
 @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
 class ModelSaveTest(keras_parameterized.TestCase):
 
+  def test_model_save_preserves_autograph(self):
+    model = BasicAutographedMetricModel()
+    inputs = array_ops.ones((8, 5))
+    model._set_inputs(inputs)
+
+    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
+    save_lib.save(model, save_dir)
+
+    if model.output_names:
+      output_name = model.output_names[0]
+      input_name = model.input_names[0]
+    else:
+      output_name = 'output_1'
+      input_name = 'input_1'
+
+    self.assertAllClose({output_name: model.predict_on_batch(inputs)},
+                        _import_and_infer(save_dir,
+                                          {input_name: np.ones((8, 5))}))
+
+    # Test v2 loading.
+    # TODO(mdan): tests using _import_and_infer should uniformly do this.
+    self.assertAllClose(model.predict_on_batch(inputs),
+                        load_lib.load(save_dir)(inputs))
+
   def test_model_save(self):
     input_dim = 5
     model = testing_utils.get_small_mlp(10, 3, input_dim)
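The new test exercises this through internal test utilities (`keras_parameterized`, `_set_inputs`, `_import_and_infer`). The same property can be exercised with public API only; a hedged end-to-end sketch mirroring the test classes above (the class names, the pass-through model, and the temporary save directory are illustrative):

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ControlFlowMetric(tf.keras.metrics.Metric):
  """Public-API analogue of AutographedMetric: tensor-valued `if` predicate."""

  def update_state(self, values):
    if tf.constant(False):  # needs autograph once the call is traced
      x = 1
    else:
      x = 2
    return x

  def result(self):
    return tf.constant(0)


class MetricLayer(tf.keras.layers.Layer):
  """Pass-through layer that updates the metric, like BasicAutographedMetricLayer."""

  def build(self, input_shape):
    self._metric = ControlFlowMetric()

  def call(self, inputs):
    self._metric.update_state(inputs)
    return inputs


class MetricModel(tf.keras.Model):

  def __init__(self):
    super().__init__(name='metric_model')
    self._layer = MetricLayer()

  def call(self, inputs, **kwargs):
    return self._layer(inputs)


model = MetricModel()
inputs = tf.ones((8, 5))
expected = model(inputs)  # first call also records an input spec used for saving

save_dir = os.path.join(tempfile.mkdtemp(), 'saved_model')
tf.saved_model.save(model, save_dir)

# Saving traces the model call; with autograph enabled in the wrapper, the
# metric's Python `if` is converted instead of failing, and the reloaded
# object reproduces the original outputs.
reloaded = tf.saved_model.load(save_dir)
np.testing.assert_allclose(expected.numpy(), reloaded(inputs).numpy())
```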