Always wrap non-sequence types like maps in a list when tracing Model call
Avoids input_signature and unwrapping issues. Fixes #26591 (again). PiperOrigin-RevId: 238725001
This commit is contained in:
parent
fed677e9dd
commit
4f086f4c0f
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.util import nest
|
||||
@ -82,8 +84,13 @@ def trace_model_call(model, input_signature=None):
|
||||
input_specs = nest.pack_sequence_as(structure=inputs,
|
||||
flat_sequence=flat_input_specs)
|
||||
# The input signature of the call function is a list with one element, since
|
||||
# all tensor inputs must be passed in as the first argument.
|
||||
input_signature = [input_specs] if len(input_specs) > 1 else input_specs
|
||||
# all tensor inputs must be passed in as the first argument. Single-element
|
||||
# dictionaries and other non-sequence types must also be wrapped.
|
||||
if (len(input_specs) > 1
|
||||
or not isinstance(input_specs, collections.Sequence)):
|
||||
input_signature = [input_specs]
|
||||
else:
|
||||
input_signature = input_specs
|
||||
|
||||
# TODO(mdan): Should the model's call be autographed by default?
|
||||
@def_function.function(input_signature=input_signature, autograph=False)
|
||||
|
@ -28,12 +28,15 @@ from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.feature_column import feature_column_v2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -130,6 +133,27 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
|
||||
self._assert_all_close(expected_outputs, signature_outputs)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_trace_features_layer(self):
|
||||
columns = [feature_column_v2.numeric_column('x')]
|
||||
model = sequential.Sequential(
|
||||
[feature_column_v2.DenseFeatures(columns)])
|
||||
model_input = {'x': constant_op.constant([[1.]])}
|
||||
model.predict(model_input, steps=1)
|
||||
fn = saving_utils.trace_model_call(model)
|
||||
self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))
|
||||
|
||||
columns = [feature_column_v2.numeric_column('x'),
|
||||
feature_column_v2.numeric_column('y')]
|
||||
model = sequential.Sequential(
|
||||
[feature_column_v2.DenseFeatures(columns)])
|
||||
model_input = {'x': constant_op.constant([[1.]]),
|
||||
'y': constant_op.constant([[2.]])}
|
||||
model.predict(model_input, steps=1)
|
||||
fn = saving_utils.trace_model_call(model)
|
||||
self.assertAllClose({'output_1': [[1., 2.]]},
|
||||
fn({'x': [[1.]], 'y': [[2.]]}))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_specify_input_signature(self):
|
||||
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
||||
|
@ -1186,6 +1186,18 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
**model_input).values()
|
||||
self.assertAllClose([[1., 2.]], signature_output)
|
||||
|
||||
def test_dense_features_layer_fit(self, cycles):
|
||||
columns = [feature_column_v2.numeric_column("x")]
|
||||
model = sequential.Sequential(
|
||||
[feature_column_v2.DenseFeatures(columns),
|
||||
core.Dense(1)])
|
||||
model_input = {"x": constant_op.constant([[1.]])}
|
||||
model.compile(optimizer="adam", loss="mse")
|
||||
model.fit(model_input, constant_op.constant([[3.]]))
|
||||
loaded = self.cycle(model, cycles)
|
||||
loaded._default_save_signature(model_input)
|
||||
loaded.signatures["serving_default"](**model_input)
|
||||
|
||||
|
||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user