Always wrap non-sequence types like maps in a list when tracing Model call

Avoids input_signature and unwrapping issues.

Fixes #26591 (again).

PiperOrigin-RevId: 238725001
This commit is contained in:
Allen Lavoie 2019-03-15 15:47:51 -07:00 committed by TensorFlower Gardener
parent fed677e9dd
commit 4f086f4c0f
3 changed files with 45 additions and 2 deletions

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import nest
@ -82,8 +84,13 @@ def trace_model_call(model, input_signature=None):
input_specs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_input_specs)
# The input signature of the call function is a list with one element, since
# all tensor inputs must be passed in as the first argument.
input_signature = [input_specs] if len(input_specs) > 1 else input_specs
# all tensor inputs must be passed in as the first argument. Single-element
# dictionaries and other non-sequence types must also be wrapped.
if (len(input_specs) > 1
or not isinstance(input_specs, collections.Sequence)):
input_signature = [input_specs]
else:
input_signature = input_specs
# TODO(mdan): Should the model's call be autographed by default?
@def_function.function(input_signature=input_signature, autograph=False)

View File

@ -28,12 +28,15 @@ from tensorflow.python import tf2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -130,6 +133,27 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self._assert_all_close(expected_outputs, signature_outputs)
@keras_parameterized.run_all_keras_modes
def test_trace_features_layer(self):
columns = [feature_column_v2.numeric_column('x')]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns)])
model_input = {'x': constant_op.constant([[1.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))
columns = [feature_column_v2.numeric_column('x'),
feature_column_v2.numeric_column('y')]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns)])
model_input = {'x': constant_op.constant([[1.]]),
'y': constant_op.constant([[2.]])}
model.predict(model_input, steps=1)
fn = saving_utils.trace_model_call(model)
self.assertAllClose({'output_1': [[1., 2.]]},
fn({'x': [[1.]], 'y': [[2.]]}))
@keras_parameterized.run_all_keras_modes
def test_specify_input_signature(self):
model = testing_utils.get_small_sequential_mlp(10, 3, None)

View File

@ -1186,6 +1186,18 @@ class LoadTest(test.TestCase, parameterized.TestCase):
**model_input).values()
self.assertAllClose([[1., 2.]], signature_output)
def test_dense_features_layer_fit(self, cycles):
columns = [feature_column_v2.numeric_column("x")]
model = sequential.Sequential(
[feature_column_v2.DenseFeatures(columns),
core.Dense(1)])
model_input = {"x": constant_op.constant([[1.]])}
model.compile(optimizer="adam", loss="mse")
model.fit(model_input, constant_op.constant([[3.]]))
loaded = self.cycle(model, cycles)
loaded._default_save_signature(model_input)
loaded.signatures["serving_default"](**model_input)
class SingleCycleTests(test.TestCase, parameterized.TestCase):