Add an attribute to Keras models that generates an exportable tf.function. SavedModel save now looks for this attribute when searching for a function to export.
PiperOrigin-RevId: 224861089
parent: 841f5d9fc9
commit: ee418c8ee2
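
In effect, a Keras model whose input shapes are set can now be exported without an explicit signatures argument: save() probes the model's _default_save_signature attribute, which wraps Model.call in a tf.function via training_utils.trace_model_call. A minimal sketch of the intended flow, assuming the public tf.saved_model.save endpoint dispatches to the save() changed below (the layer sizes and save path are illustrative):

    import tensorflow as tf

    # A small functional model; building it fixes the input TensorSpecs.
    x = tf.keras.layers.Input(shape=(5,), name="x")
    y = tf.keras.layers.Dense(3, name="out")(x)
    model = tf.keras.Model(x, y)

    # No signatures argument: save() falls back to model._default_save_signature,
    # a tf.function that traces model.call with the inferred input signature.
    tf.saved_model.save(model, "/tmp/exported_model")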
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -342,6 +342,10 @@ class PolymorphicFunction(object):
     """The python function wrapped in this tf.function."""
     return self._python_function
 
+  @property
+  def input_signature(self):
+    return self._input_signature
+
   def get_initialization_function(self, *args, **kwargs):
     """Returns a `Function` object which initializes this function's variables.
 
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -848,6 +848,7 @@ py_test(
     deps = [
         ":keras",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/saved_model:save_test",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1539,8 +1539,7 @@ class Model(Network):
 
     outputs = nest.flatten(outputs)
     self.outputs = outputs
-    self.output_names = [
-        'output_%d' % (i + 1) for i in range(len(self.outputs))]
+    self.output_names = training_utils.generic_output_names(outputs)
     self.built = True
 
   def fit(self,
@@ -2580,6 +2579,10 @@ class Model(Network):
       batch_size = 32
     return batch_size
 
+  @property
+  def _default_save_signature(self):
+    return training_utils.trace_model_call(self)
+
 
 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with DistributionStrategy."""
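
A note on the hunk above: _default_save_signature is a property, so the signature is traced lazily, only when something (such as SavedModel export) reads the attribute. A minimal sketch, assuming a built Keras Model instance:

    sig_fn = model._default_save_signature    # re-traces model.call on each access
    outputs = sig_fn(array_ops.ones((8, 5)))  # dict keyed by model.output_names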
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -27,9 +27,11 @@ import six
 
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
@@ -1191,3 +1193,61 @@ def get_static_batch_size(layer):
   if batch_input_shape is not None:
     return tensor_shape.as_dimension(batch_input_shape[0]).value
   return None
+
+
+def generic_output_names(outputs_list):
+  return ['output_%d' % (i + 1) for i in range(len(outputs_list))]
+
+
+def trace_model_call(model, input_signature=None):
+  """Trace the model call to create a tf.function for exporting a Keras model.
+
+  Args:
+    model: A Keras model.
+    input_signature: optional, a list of tf.TensorSpec objects specifying the
+      inputs to the model.
+
+  Returns:
+    A tf.function wrapping the model's call function with input signatures set.
+
+  Raises:
+    ValueError: if input signature cannot be inferred from the model.
+  """
+  if input_signature is None:
+    if isinstance(model.call, def_function.PolymorphicFunction):
+      input_signature = model.call.input_signature
+
+  if input_signature is None:
+    try:
+      inputs = model.inputs
+      input_names = model.input_names
+    except AttributeError:
+      raise ValueError(
+          'Model {} cannot be saved because the input shapes have not been '
+          'set. Usually, input shapes are automatically determined from calling'
+          ' .fit() or .predict(). To manually set the shapes, call '
+          'model._set_inputs(inputs).'.format(model))
+    input_specs = []
+    for input_tensor, input_name in zip(inputs, input_names):
+      input_specs.append(tensor_spec.TensorSpec(
+          shape=input_tensor.shape, dtype=input_tensor.dtype,
+          name=input_name))
+    # The input signature of the call function is a list with one element, since
+    # all tensor inputs must be passed in as the first argument.
+    input_signature = [input_specs] if len(input_specs) > 1 else input_specs
+
+  @def_function.function(input_signature=input_signature)
+  def _wrapped_model(*args):
+    """A concrete tf.function that wraps the model's call function."""
+    # When given a single input, Keras models will call the model on the tensor
+    # rather than a list consisting of the single tensor.
+    inputs = args[0] if len(input_signature) == 1 else list(args)
+    outputs_list = nest.flatten(model(inputs=inputs))
+    try:
+      output_names = model.output_names
+    except AttributeError:
+      output_names = generic_output_names(outputs_list)
+    return {name: output for name, output in zip(output_names, outputs_list)}
+
+  return _wrapped_model
+
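
For reference, a hedged usage sketch of the helper added above (assuming a built single-input model, e.g. after fit() or _set_inputs(); module names follow the internal imports used in this diff):

    from tensorflow.python.keras.engine import training_utils
    from tensorflow.python.ops import array_ops

    fn = training_utils.trace_model_call(model)
    # The traced function returns a dict keyed by the model's output names,
    # e.g. {'output_1': <tf.Tensor>} for a single-output model.
    signature_outputs = fn(array_ops.ones((8, 5)))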
--- a/tensorflow/python/keras/engine/training_utils_test.py
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -18,13 +18,25 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 import numpy as np
 
+from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import save as save_lib
+from tensorflow.python.saved_model import save_test
 
 
 class ModelInputsTest(test.TestCase):
@@ -85,5 +97,150 @@ class ModelInputsTest(test.TestCase):
     self.assertTrue(tf_utils.is_symbolic_tensor(vals['b']))
 
 
+class TraceModelCallTest(keras_parameterized.TestCase):
+
+  def _assert_all_close(self, expected, actual):
+    if not context.executing_eagerly():
+      with self.cached_session() as sess:
+        K._initialize_variables(sess)
+        self.assertAllClose(expected, actual)
+    else:
+      self.assertAllClose(expected, actual)
+
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_trace_model_outputs(self):
+    input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
+    model = testing_utils.get_small_mlp(10, 3, input_dim)
+    inputs = array_ops.ones((8, 5))
+
+    if input_dim is None:
+      with self.assertRaisesRegexp(ValueError,
+                                   'input shapes have not been set'):
+        training_utils.trace_model_call(model)
+      model._set_inputs(inputs)
+
+    fn = training_utils.trace_model_call(model)
+    signature_outputs = fn(inputs)
+    expected_outputs = {model.output_names[0]: model(inputs)}
+
+    self._assert_all_close(expected_outputs, signature_outputs)
+
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_trace_model_outputs_after_fitting(self):
+    input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
+    model = testing_utils.get_small_mlp(10, 3, input_dim)
+    model.compile(optimizer='sgd', loss='mse')
+    model.fit(x=np.random.random((8, 5)),
+              y=np.random.random((8, 3)), epochs=2)
+
+    inputs = array_ops.ones((8, 5))
+
+    fn = training_utils.trace_model_call(model)
+    signature_outputs = fn(inputs)
+    expected_outputs = {model.output_names[0]: model(inputs)}
+
+    self._assert_all_close(expected_outputs, signature_outputs)
+
+  @keras_parameterized.run_with_all_model_types(exclude_models='sequential')
+  @keras_parameterized.run_all_keras_modes
+  def test_trace_multi_io_model_outputs(self):
+    input_dim = 5
+    num_classes = 3
+    num_classes_b = 4
+    input_a = keras.layers.Input(shape=(input_dim,), name='input_a')
+    input_b = keras.layers.Input(shape=(input_dim,), name='input_b')
+
+    dense = keras.layers.Dense(num_classes, name='dense')
+    dense2 = keras.layers.Dense(num_classes_b, name='dense2')
+    dropout = keras.layers.Dropout(0.5, name='dropout')
+    branch_a = [input_a, dense]
+    branch_b = [input_b, dense, dense2, dropout]
+
+    model = testing_utils.get_multi_io_model(branch_a, branch_b)
+
+    input_a_np = np.random.random((10, input_dim)).astype(np.float32)
+    input_b_np = np.random.random((10, input_dim)).astype(np.float32)
+
+    if testing_utils.get_model_type() == 'subclass':
+      with self.assertRaisesRegexp(ValueError,
+                                   'input shapes have not been set'):
+        training_utils.trace_model_call(model)
+
+    model.compile(optimizer='sgd', loss='mse')
+    model.fit(x=[np.random.random((8, input_dim)).astype(np.float32),
+                 np.random.random((8, input_dim)).astype(np.float32)],
+              y=[np.random.random((8, num_classes)).astype(np.float32),
+                 np.random.random((8, num_classes_b)).astype(np.float32)],
+              epochs=2)
+
+    fn = training_utils.trace_model_call(model)
+    signature_outputs = fn([input_a_np, input_b_np])
+    outputs = model([input_a_np, input_b_np])
+    expected_outputs = {model.output_names[0]: outputs[0],
+                        model.output_names[1]: outputs[1]}
+
+    self._assert_all_close(expected_outputs, signature_outputs)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_specify_input_signature(self):
+    model = testing_utils.get_small_sequential_mlp(10, 3, None)
+    inputs = array_ops.ones((8, 5))
+
+    with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'):
+      training_utils.trace_model_call(model)
+
+    fn = training_utils.trace_model_call(
+        model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
+    signature_outputs = fn(inputs)
+    expected_outputs = {model.output_names[0]: model(inputs)}
+    self._assert_all_close(expected_outputs, signature_outputs)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_subclassed_model_with_input_signature(self):
+
+    class Model(keras.Model):
+
+      def __init__(self):
+        super(Model, self).__init__()
+        self.dense = keras.layers.Dense(3, name='dense')
+
+      @def_function.function(
+          input_signature=[[tensor_spec.TensorSpec([None, 5], dtypes.float32),
+                            tensor_spec.TensorSpec([None], dtypes.float32)]],)
+      def call(self, inputs, *args):
+        x, y = inputs
+        return self.dense(x) + y
+
+    model = Model()
+    fn = training_utils.trace_model_call(model)
+    x = array_ops.ones((8, 5), dtype=dtypes.float32)
+    y = array_ops.ones((3,), dtype=dtypes.float32)
+    expected_outputs = {'output_1': model([x, y])}
+    signature_outputs = fn([x, y])
+    self._assert_all_close(expected_outputs, signature_outputs)
+
+
+class ModelSaveTest(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_model_save(self):
+    input_dim = 5
+    model = testing_utils.get_small_mlp(10, 3, input_dim)
+    inputs = array_ops.ones((8, 5))
+
+    if testing_utils.get_model_type() == 'subclass':
+      model._set_inputs(inputs)
+
+    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
+    save_lib.save(model, save_dir)
+
+    self.assertAllClose(
+        {model.output_names[0]: model.predict_on_batch(inputs)},
+        save_test._import_and_infer(save_dir,
+                                    {model.input_names[0]: np.ones((8, 5))}))
+
 if __name__ == '__main__':
   test.main()
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_spec
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -50,28 +49,7 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
-def _check_for_functional_keras_model(root):
-  """Makes an export signature for `root` if it's a functional Keras Model."""
-  # If nothing is decorated yet but this is a functional Keras Model (duck
-  # typed), we'll try to make a signature ourselves.
-  try:
-    inputs = root.inputs
-    input_names = root.input_names
-  except AttributeError:
-    return None
-  input_signature = []
-  for input_tensor, input_name in zip(inputs, input_names):
-    input_signature.append(tensor_spec.TensorSpec(
-        shape=input_tensor.shape, dtype=input_tensor.dtype,
-        name=input_name))
-
-  @def_function.function(input_signature=input_signature)
-  def _wrapped_model(*args):
-    outputs_list = nest.flatten(root(inputs=list(args)))
-    return {name: output for name, output
-            in zip(root.output_names, outputs_list)}
-  return _wrapped_model
+DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
 
 
 def _find_function_to_export(root):
@@ -93,7 +71,7 @@ def _find_function_to_export(root):
       exported_function = attribute_value
       previous_attribute_name = attribute_name
   if exported_function is None:
-    exported_function = _check_for_functional_keras_model(root)
+    exported_function = getattr(root, DEFAULT_SIGNATURE_ATTR, None)
   if exported_function is None:
     raise ValueError(
         ("Exporting an object with no tf.saved_model.save(..., signatures=...) "
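
With _check_for_functional_keras_model deleted, the fallback in _find_function_to_export reduces to a plain attribute probe. Conceptually (a simplified sketch, not the library code; _scan_for_tf_function is a hypothetical stand-in for the attribute loop above):

    def _export_candidate(root):
      fn = _scan_for_tf_function(root)  # hypothetical: the attribute scan above
      if fn is None:
        # Keras Models now expose a default signature under this attribute.
        fn = getattr(root, "_default_save_signature", None)
      if fn is None:
        raise ValueError("no @tf.function-decorated methods to export")
      return fn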
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -21,8 +21,6 @@ from __future__ import print_function
 import os
 import sys
 
-import numpy
-
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
@@ -32,12 +30,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import test_util
-from tensorflow.python.keras.engine import input_layer
-from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import core
-from tensorflow.python.keras.layers import merge
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.checkpointable import util
|
||||
|
||||
|
||||
class _ModelWithOptimizer(training.Model):
|
||||
class _ModelWithOptimizer(util.Checkpoint):
|
||||
|
||||
def __init__(self):
|
||||
super(_ModelWithOptimizer, self).__init__()
|
||||
self.dense = core.Dense(1)
|
||||
self.optimizer = adam.AdamOptimizer(0.01)
|
||||
|
||||
@@ -63,7 +56,7 @@ class _ModelWithOptimizer(training.Model):
   def call(self, x, y):
     with backprop.GradientTape() as tape:
       loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
-    trainable_variables = self.trainable_variables
+    trainable_variables = self.dense.trainable_variables
     gradients = tape.gradient(loss, trainable_variables)
     self.optimizer.apply_gradients(zip(gradients, trainable_variables))
     return {"loss": loss}
@@ -179,10 +172,10 @@ class SaveTest(test.TestCase):
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
     model = _ModelWithOptimizer()
-    first_loss = model(x, y)
+    first_loss = model.call(x, y)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(model, save_dir, model.call)
-    second_loss = model(x, y)
+    second_loss = model.call(x, y)
     self.assertNotEqual(first_loss, second_loss)
     self.assertAllClose(
         second_loss,
@@ -197,7 +190,7 @@ class SaveTest(test.TestCase):
     model = _ModelWithOptimizer()
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
-    model(x, y)
+    model.call(x, y)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(model, save_dir)
     self.assertIn("loss",
@@ -217,25 +210,40 @@ class SaveTest(test.TestCase):
     model = _ModelWithOptimizer()
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
-    model(x, y)
+    model.call(x, y)
     model.second_function = def_function.function(lambda: 1.)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     with self.assertRaisesRegexp(ValueError, "call.*second_function"):
       save.save(model, save_dir)
 
-  def test_subclassed_no_signature(self):
+  def test_no_signature(self):
 
-    class Subclassed(training.Model):
+    class Model(util.Checkpoint):
 
       def call(self, inputs):
         return inputs * 2.
 
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    model = Subclassed()
+    model = Model()
     with self.assertRaisesRegexp(
         ValueError, "no @tf.function-decorated methods"):
       save.save(model, save_dir)
 
+  def test_find_default_save_function(self):
+
+    class ObjWithDefaultSignature(util.Checkpoint):
+
+      @def_function.function(input_signature=[tensor_spec.TensorSpec(
+          shape=None, dtype=dtypes.float32)])
+      def _default_save_signature(self, x):
+        return x + x + 1
+
+    obj = ObjWithDefaultSignature()
+    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+    save.save(obj, save_dir)
+    self.assertAllClose(
+        {"output_0": 7.}, _import_and_infer(save_dir, {"x": 3.}))
+
   def test_docstring(self):
 
     class Adder(util.Checkpoint):
@@ -276,46 +284,6 @@ class SaveTest(test.TestCase):
     self.assertNotIn("T", complex_node.attr)
     self.assertNotIn("Tout", complex_node.attr)
 
-  def test_export_functional_keras_model(self):
-    x = input_layer.Input((4,), name="x")
-    y = core.Dense(4, name="out")(x)
-    model = training.Model(x, y)
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(model, save_dir)
-    self.assertAllClose(
-        {"out": model(array_ops.ones([1, 4]))},
-        _import_and_infer(save_dir, {"x": [[1., 1., 1., 1.]]}))
-
-  @test_util.run_v1_only("b/120545219")
-  def test_export_functional_keras_model_after_fit(self):
-    x = input_layer.Input((1,))
-    y = core.Dense(1, name="y")(x)
-    model = training.Model(x, y)
-    model.compile(optimizer="sgd", loss="mse")
-    model.fit(x=numpy.array([[1.]]),
-              y=numpy.array([2.]), epochs=2)
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(model, save_dir)
-    self.assertAllClose(
-        {"y": model(constant_op.constant([[1.], [2.]]))},
-        _import_and_infer(save_dir, {"input_1": [[1.], [2.]]}))
-
-  def test_export_multi_input_functional_keras_model(self):
-    x1 = input_layer.Input((2,), name="x1")
-    x2 = input_layer.Input((2,), name="x2")
-    y1 = core.Dense(4)(merge.Add()([x1, x2]))
-    y2 = core.Dense(4)(merge.Multiply()([x1, x2]))
-    model = training.Model([x1, x2], [y1, y2])
-    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
-    save.save(model, save_dir)
-    outputs = model([array_ops.ones([1, 2]), 2. * array_ops.ones([1, 2])])
-    self.assertAllClose(
-        {"dense": outputs[0], "dense_1": outputs[1]},
-        _import_and_infer(
-            save_dir,
-            {"x1": [[1., 1.]],
-             "x2": [[2., 2.]]}))
-
 
 class AssetTests(test.TestCase):
 
@@ -376,7 +344,7 @@ class MemoryTests(test.TestCase):
   def test_no_reference_cycles(self):
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
-    self._model(x, y)
+    self._model.call(x, y)
     if sys.version_info[0] < 3:
       # TODO(allenl): debug reference cycles in Python 2.x
       self.skipTest("This test only works in Python 3+. Reference cycles are "