diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 6bacd7a962f..3663d729996 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -342,6 +342,10 @@ class PolymorphicFunction(object): """The python function wrapped in this tf.function.""" return self._python_function + @property + def input_signature(self): + return self._input_signature + def get_initialization_function(self, *args, **kwargs): """Returns a `Function` object which initializes this function's variables. diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 36fea36389d..faf58e0d93f 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -848,6 +848,7 @@ py_test( deps = [ ":keras", "//tensorflow/python:client_testlib", + "//tensorflow/python/saved_model:save_test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 462694fda69..fe44bc20a1c 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1539,8 +1539,7 @@ class Model(Network): outputs = nest.flatten(outputs) self.outputs = outputs - self.output_names = [ - 'output_%d' % (i + 1) for i in range(len(self.outputs))] + self.output_names = training_utils.generic_output_names(outputs) self.built = True def fit(self, @@ -2580,6 +2579,10 @@ class Model(Network): batch_size = 32 return batch_size + @property + def _default_save_signature(self): + return training_utils.trace_model_call(self) + class DistributedCallbackModel(Model): """Model that is used for callbacks with DistributionStrategy.""" diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 01a09eb031e..ec6b39704a0 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -27,9 +27,11 @@ import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks @@ -1191,3 +1193,61 @@ def get_static_batch_size(layer): if batch_input_shape is not None: return tensor_shape.as_dimension(batch_input_shape[0]).value return None + + +def generic_output_names(outputs_list): + return ['output_%d' % (i + 1) for i in range(len(outputs_list))] + + +def trace_model_call(model, input_signature=None): + """Trace the model call to create a tf.function for exporting a Keras model. + + Args: + model: A Keras model. + input_signature: optional, a list of tf.TensorSpec objects specifying the + inputs to the model. + + Returns: + A tf.function wrapping the model's call function with input signatures set. + + Raises: + ValueError: if input signature cannot be inferred from the model. + """ + if input_signature is None: + if isinstance(model.call, def_function.PolymorphicFunction): + input_signature = model.call.input_signature + + if input_signature is None: + try: + inputs = model.inputs + input_names = model.input_names + except AttributeError: + raise ValueError( + 'Model {} cannot be saved because the input shapes have not been ' + 'set. Usually, input shapes are automatically determined from calling' + ' .fit() or .predict(). To manually set the shapes, call ' + 'model._set_inputs(inputs).'.format(model)) + input_specs = [] + for input_tensor, input_name in zip(inputs, input_names): + input_specs.append(tensor_spec.TensorSpec( + shape=input_tensor.shape, dtype=input_tensor.dtype, + name=input_name)) + # The input signature of the call function is a list with one element, since + # all tensor inputs must be passed in as the first argument. + input_signature = [input_specs] if len(input_specs) > 1 else input_specs + + @def_function.function(input_signature=input_signature) + def _wrapped_model(*args): + """A concrete tf.function that wraps the model's call function.""" + # When given a single input, Keras models will call the model on the tensor + # rather than a list consisting of the single tensor. + inputs = args[0] if len(input_signature) == 1 else list(args) + outputs_list = nest.flatten(model(inputs=inputs)) + try: + output_names = model.output_names + except AttributeError: + output_names = generic_output_names(outputs_list) + return {name: output for name, output in zip(output_names, outputs_list)} + + return _wrapped_model + diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py index 44ea23998fe..0250e604266 100644 --- a/tensorflow/python/keras/engine/training_utils_test.py +++ b/tensorflow/python/keras/engine/training_utils_test.py @@ -18,13 +18,25 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np +from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +from tensorflow.python.saved_model import save as save_lib +from tensorflow.python.saved_model import save_test class ModelInputsTest(test.TestCase): @@ -85,5 +97,150 @@ class ModelInputsTest(test.TestCase): self.assertTrue(tf_utils.is_symbolic_tensor(vals['b'])) +class TraceModelCallTest(keras_parameterized.TestCase): + + def _assert_all_close(self, expected, actual): + if not context.executing_eagerly(): + with self.cached_session() as sess: + K._initialize_variables(sess) + self.assertAllClose(expected, actual) + else: + self.assertAllClose(expected, actual) + + @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes + def test_trace_model_outputs(self): + input_dim = 5 if testing_utils.get_model_type() == 'functional' else None + model = testing_utils.get_small_mlp(10, 3, input_dim) + inputs = array_ops.ones((8, 5)) + + if input_dim is None: + with self.assertRaisesRegexp(ValueError, + 'input shapes have not been set'): + training_utils.trace_model_call(model) + model._set_inputs(inputs) + + fn = training_utils.trace_model_call(model) + signature_outputs = fn(inputs) + expected_outputs = {model.output_names[0]: model(inputs)} + + self._assert_all_close(expected_outputs, signature_outputs) + + @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes + def test_trace_model_outputs_after_fitting(self): + input_dim = 5 if testing_utils.get_model_type() == 'functional' else None + model = testing_utils.get_small_mlp(10, 3, input_dim) + model.compile(optimizer='sgd', loss='mse') + model.fit(x=np.random.random((8, 5)), + y=np.random.random((8, 3)), epochs=2) + + inputs = array_ops.ones((8, 5)) + + fn = training_utils.trace_model_call(model) + signature_outputs = fn(inputs) + expected_outputs = {model.output_names[0]: model(inputs)} + + self._assert_all_close(expected_outputs, signature_outputs) + + @keras_parameterized.run_with_all_model_types(exclude_models='sequential') + @keras_parameterized.run_all_keras_modes + def test_trace_multi_io_model_outputs(self): + input_dim = 5 + num_classes = 3 + num_classes_b = 4 + input_a = keras.layers.Input(shape=(input_dim,), name='input_a') + input_b = keras.layers.Input(shape=(input_dim,), name='input_b') + + dense = keras.layers.Dense(num_classes, name='dense') + dense2 = keras.layers.Dense(num_classes_b, name='dense2') + dropout = keras.layers.Dropout(0.5, name='dropout') + branch_a = [input_a, dense] + branch_b = [input_b, dense, dense2, dropout] + + model = testing_utils.get_multi_io_model(branch_a, branch_b) + + input_a_np = np.random.random((10, input_dim)).astype(np.float32) + input_b_np = np.random.random((10, input_dim)).astype(np.float32) + + if testing_utils.get_model_type() == 'subclass': + with self.assertRaisesRegexp(ValueError, + 'input shapes have not been set'): + training_utils.trace_model_call(model) + + model.compile(optimizer='sgd', loss='mse') + model.fit(x=[np.random.random((8, input_dim)).astype(np.float32), + np.random.random((8, input_dim)).astype(np.float32)], + y=[np.random.random((8, num_classes)).astype(np.float32), + np.random.random((8, num_classes_b)).astype(np.float32)], + epochs=2) + + fn = training_utils.trace_model_call(model) + signature_outputs = fn([input_a_np, input_b_np]) + outputs = model([input_a_np, input_b_np]) + expected_outputs = {model.output_names[0]: outputs[0], + model.output_names[1]: outputs[1]} + + self._assert_all_close(expected_outputs, signature_outputs) + + @keras_parameterized.run_all_keras_modes + def test_specify_input_signature(self): + model = testing_utils.get_small_sequential_mlp(10, 3, None) + inputs = array_ops.ones((8, 5)) + + with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'): + training_utils.trace_model_call(model) + + fn = training_utils.trace_model_call( + model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)]) + signature_outputs = fn(inputs) + expected_outputs = {model.output_names[0]: model(inputs)} + self._assert_all_close(expected_outputs, signature_outputs) + + @keras_parameterized.run_all_keras_modes + def test_subclassed_model_with_input_signature(self): + + class Model(keras.Model): + + def __init__(self): + super(Model, self).__init__() + self.dense = keras.layers.Dense(3, name='dense') + + @def_function.function( + input_signature=[[tensor_spec.TensorSpec([None, 5], dtypes.float32), + tensor_spec.TensorSpec([None], dtypes.float32)]],) + def call(self, inputs, *args): + x, y = inputs + return self.dense(x) + y + + model = Model() + fn = training_utils.trace_model_call(model) + x = array_ops.ones((8, 5), dtype=dtypes.float32) + y = array_ops.ones((3,), dtype=dtypes.float32) + expected_outputs = {'output_1': model([x, y])} + signature_outputs = fn([x, y]) + self._assert_all_close(expected_outputs, signature_outputs) + + +class ModelSaveTest(keras_parameterized.TestCase): + + @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_model_save(self): + input_dim = 5 + model = testing_utils.get_small_mlp(10, 3, input_dim) + inputs = array_ops.ones((8, 5)) + + if testing_utils.get_model_type() == 'subclass': + model._set_inputs(inputs) + + save_dir = os.path.join(self.get_temp_dir(), 'saved_model') + save_lib.save(model, save_dir) + + self.assertAllClose( + {model.output_names[0]: model.predict_on_batch(inputs)}, + save_test._import_and_infer(save_dir, + {model.input_names[0]: np.ones((8, 5))})) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index ab6fcb7196f..e2726087a5c 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -31,7 +31,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -50,28 +49,7 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export - -def _check_for_functional_keras_model(root): - """Makes an export signature for `root` if it's a functional Keras Model.""" - # If nothing is decorated yet but this is a functional Keras Model (duck - # typed), we'll try to make a signature ourselves. - try: - inputs = root.inputs - input_names = root.input_names - except AttributeError: - return None - input_signature = [] - for input_tensor, input_name in zip(inputs, input_names): - input_signature.append(tensor_spec.TensorSpec( - shape=input_tensor.shape, dtype=input_tensor.dtype, - name=input_name)) - - @def_function.function(input_signature=input_signature) - def _wrapped_model(*args): - outputs_list = nest.flatten(root(inputs=list(args))) - return {name: output for name, output - in zip(root.output_names, outputs_list)} - return _wrapped_model +DEFAULT_SIGNATURE_ATTR = "_default_save_signature" def _find_function_to_export(root): @@ -93,7 +71,7 @@ def _find_function_to_export(root): exported_function = attribute_value previous_attribute_name = attribute_name if exported_function is None: - exported_function = _check_for_functional_keras_model(root) + exported_function = getattr(root, DEFAULT_SIGNATURE_ATTR, None) if exported_function is None: raise ValueError( ("Exporting an object with no tf.saved_model.save(..., signatures=...) " diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 97218a98eae..1c6eb1b538a 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -21,8 +21,6 @@ from __future__ import print_function import os import sys -import numpy - from tensorflow.python.client import session as session_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function @@ -32,12 +30,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import input_layer -from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core -from tensorflow.python.keras.layers import merge from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -50,10 +44,9 @@ from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util -class _ModelWithOptimizer(training.Model): +class _ModelWithOptimizer(util.Checkpoint): def __init__(self): - super(_ModelWithOptimizer, self).__init__() self.dense = core.Dense(1) self.optimizer = adam.AdamOptimizer(0.01) @@ -63,7 +56,7 @@ class _ModelWithOptimizer(training.Model): def call(self, x, y): with backprop.GradientTape() as tape: loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.) - trainable_variables = self.trainable_variables + trainable_variables = self.dense.trainable_variables gradients = tape.gradient(loss, trainable_variables) self.optimizer.apply_gradients(zip(gradients, trainable_variables)) return {"loss": loss} @@ -179,10 +172,10 @@ class SaveTest(test.TestCase): x = constant_op.constant([[3., 4.]]) y = constant_op.constant([2.]) model = _ModelWithOptimizer() - first_loss = model(x, y) + first_loss = model.call(x, y) save_dir = os.path.join(self.get_temp_dir(), "saved_model") save.save(model, save_dir, model.call) - second_loss = model(x, y) + second_loss = model.call(x, y) self.assertNotEqual(first_loss, second_loss) self.assertAllClose( second_loss, @@ -197,7 +190,7 @@ class SaveTest(test.TestCase): model = _ModelWithOptimizer() x = constant_op.constant([[3., 4.]]) y = constant_op.constant([2.]) - model(x, y) + model.call(x, y) save_dir = os.path.join(self.get_temp_dir(), "saved_model") save.save(model, save_dir) self.assertIn("loss", @@ -217,25 +210,40 @@ class SaveTest(test.TestCase): model = _ModelWithOptimizer() x = constant_op.constant([[3., 4.]]) y = constant_op.constant([2.]) - model(x, y) + model.call(x, y) model.second_function = def_function.function(lambda: 1.) save_dir = os.path.join(self.get_temp_dir(), "saved_model") with self.assertRaisesRegexp(ValueError, "call.*second_function"): save.save(model, save_dir) - def test_subclassed_no_signature(self): + def test_no_signature(self): - class Subclassed(training.Model): + class Model(util.Checkpoint): def call(self, inputs): return inputs * 2. save_dir = os.path.join(self.get_temp_dir(), "saved_model") - model = Subclassed() + model = Model() with self.assertRaisesRegexp( ValueError, "no @tf.function-decorated methods"): save.save(model, save_dir) + def test_find_default_save_function(self): + + class ObjWithDefaultSignature(util.Checkpoint): + + @def_function.function(input_signature=[tensor_spec.TensorSpec( + shape=None, dtype=dtypes.float32)]) + def _default_save_signature(self, x): + return x + x + 1 + + obj = ObjWithDefaultSignature() + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(obj, save_dir) + self.assertAllClose( + {"output_0": 7.}, _import_and_infer(save_dir, {"x": 3.})) + def test_docstring(self): class Adder(util.Checkpoint): @@ -276,46 +284,6 @@ class SaveTest(test.TestCase): self.assertNotIn("T", complex_node.attr) self.assertNotIn("Tout", complex_node.attr) - def test_export_functional_keras_model(self): - x = input_layer.Input((4,), name="x") - y = core.Dense(4, name="out")(x) - model = training.Model(x, y) - save_dir = os.path.join(self.get_temp_dir(), "saved_model") - save.save(model, save_dir) - self.assertAllClose( - {"out": model(array_ops.ones([1, 4]))}, - _import_and_infer(save_dir, {"x": [[1., 1., 1., 1.]]})) - - @test_util.run_v1_only("b/120545219") - def test_export_functional_keras_model_after_fit(self): - x = input_layer.Input((1,)) - y = core.Dense(1, name="y")(x) - model = training.Model(x, y) - model.compile(optimizer="sgd", loss="mse") - model.fit(x=numpy.array([[1.]]), - y=numpy.array([2.]), epochs=2) - save_dir = os.path.join(self.get_temp_dir(), "saved_model") - save.save(model, save_dir) - self.assertAllClose( - {"y": model(constant_op.constant([[1.], [2.]]))}, - _import_and_infer(save_dir, {"input_1": [[1.], [2.]]})) - - def test_export_multi_input_functional_keras_model(self): - x1 = input_layer.Input((2,), name="x1") - x2 = input_layer.Input((2,), name="x2") - y1 = core.Dense(4)(merge.Add()([x1, x2])) - y2 = core.Dense(4)(merge.Multiply()([x1, x2])) - model = training.Model([x1, x2], [y1, y2]) - save_dir = os.path.join(self.get_temp_dir(), "saved_model") - save.save(model, save_dir) - outputs = model([array_ops.ones([1, 2]), 2. * array_ops.ones([1, 2])]) - self.assertAllClose( - {"dense": outputs[0], "dense_1": outputs[1]}, - _import_and_infer( - save_dir, - {"x1": [[1., 1.]], - "x2": [[2., 2.]]})) - class AssetTests(test.TestCase): @@ -376,7 +344,7 @@ class MemoryTests(test.TestCase): def test_no_reference_cycles(self): x = constant_op.constant([[3., 4.]]) y = constant_op.constant([2.]) - self._model(x, y) + self._model.call(x, y) if sys.version_info[0] < 3: # TODO(allenl): debug reference cycles in Python 2.x self.skipTest("This test only works in Python 3+. Reference cycles are "