From 4f086f4c0fc5547bf91d0036ccecf33d69c1303a Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 15 Mar 2019 15:47:51 -0700 Subject: [PATCH] Always wrap non-sequence types like maps in a list when tracing Model call Avoids input_signature and unwrapping issues. Fixes #26591 (again). PiperOrigin-RevId: 238725001 --- .../python/keras/saving/saving_utils.py | 11 +++++++-- .../python/keras/saving/saving_utils_test.py | 24 +++++++++++++++++++ tensorflow/python/saved_model/load_test.py | 12 ++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py index 55ddde5c7c5..2751fbc1954 100644 --- a/tensorflow/python/keras/saving/saving_utils.py +++ b/tensorflow/python/keras/saving/saving_utils.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.python.eager import def_function from tensorflow.python.framework import tensor_spec from tensorflow.python.util import nest @@ -82,8 +84,13 @@ def trace_model_call(model, input_signature=None): input_specs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input_specs) # The input signature of the call function is a list with one element, since - # all tensor inputs must be passed in as the first argument. - input_signature = [input_specs] if len(input_specs) > 1 else input_specs + # all tensor inputs must be passed in as the first argument. Single-element + # dictionaries and other non-sequence types must also be wrapped. + if (len(input_specs) > 1 + or not isinstance(input_specs, collections.Sequence)): + input_signature = [input_specs] + else: + input_signature = input_specs # TODO(mdan): Should the model's call be autographed by default? @def_function.function(input_signature=input_signature, autograph=False) diff --git a/tensorflow/python/keras/saving/saving_utils_test.py b/tensorflow/python/keras/saving/saving_utils_test.py index 5952a4d7638..65105f80f12 100644 --- a/tensorflow/python/keras/saving/saving_utils_test.py +++ b/tensorflow/python/keras/saving/saving_utils_test.py @@ -28,12 +28,15 @@ from tensorflow.python import tf2 from tensorflow.python.client import session as session_lib from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.feature_column import feature_column_v2 +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.saving import saving_utils from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -130,6 +133,27 @@ class TraceModelCallTest(keras_parameterized.TestCase): self._assert_all_close(expected_outputs, signature_outputs) + @keras_parameterized.run_all_keras_modes + def test_trace_features_layer(self): + columns = [feature_column_v2.numeric_column('x')] + model = sequential.Sequential( + [feature_column_v2.DenseFeatures(columns)]) + model_input = {'x': constant_op.constant([[1.]])} + model.predict(model_input, steps=1) + fn = saving_utils.trace_model_call(model) + self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]})) + + columns = [feature_column_v2.numeric_column('x'), + feature_column_v2.numeric_column('y')] + model = sequential.Sequential( + [feature_column_v2.DenseFeatures(columns)]) + model_input = {'x': constant_op.constant([[1.]]), + 'y': constant_op.constant([[2.]])} + model.predict(model_input, steps=1) + fn = saving_utils.trace_model_call(model) + self.assertAllClose({'output_1': [[1., 2.]]}, + fn({'x': [[1.]], 'y': [[2.]]})) + @keras_parameterized.run_all_keras_modes def test_specify_input_signature(self): model = testing_utils.get_small_sequential_mlp(10, 3, None) diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index d80de3c87e3..311685c0565 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -1186,6 +1186,18 @@ class LoadTest(test.TestCase, parameterized.TestCase): **model_input).values() self.assertAllClose([[1., 2.]], signature_output) + def test_dense_features_layer_fit(self, cycles): + columns = [feature_column_v2.numeric_column("x")] + model = sequential.Sequential( + [feature_column_v2.DenseFeatures(columns), + core.Dense(1)]) + model_input = {"x": constant_op.constant([[1.]])} + model.compile(optimizer="adam", loss="mse") + model.fit(model_input, constant_op.constant([[3.]])) + loaded = self.cycle(model, cycles) + loaded._default_save_signature(model_input) + loaded.signatures["serving_default"](**model_input) + class SingleCycleTests(test.TestCase, parameterized.TestCase):