Always wrap non-sequence types like maps in a list when tracing Model call
Avoids input_signature and unwrapping issues. Fixes #26591 (again). PiperOrigin-RevId: 238725001
This commit is contained in:
parent
fed677e9dd
commit
4f086f4c0f
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -82,8 +84,13 @@ def trace_model_call(model, input_signature=None):
|
|||||||
input_specs = nest.pack_sequence_as(structure=inputs,
|
input_specs = nest.pack_sequence_as(structure=inputs,
|
||||||
flat_sequence=flat_input_specs)
|
flat_sequence=flat_input_specs)
|
||||||
# The input signature of the call function is a list with one element, since
|
# The input signature of the call function is a list with one element, since
|
||||||
# all tensor inputs must be passed in as the first argument.
|
# all tensor inputs must be passed in as the first argument. Single-element
|
||||||
input_signature = [input_specs] if len(input_specs) > 1 else input_specs
|
# dictionaries and other non-sequence types must also be wrapped.
|
||||||
|
if (len(input_specs) > 1
|
||||||
|
or not isinstance(input_specs, collections.Sequence)):
|
||||||
|
input_signature = [input_specs]
|
||||||
|
else:
|
||||||
|
input_signature = input_specs
|
||||||
|
|
||||||
# TODO(mdan): Should the model's call be autographed by default?
|
# TODO(mdan): Should the model's call be autographed by default?
|
||||||
@def_function.function(input_signature=input_signature, autograph=False)
|
@def_function.function(input_signature=input_signature, autograph=False)
|
||||||
|
@ -28,12 +28,15 @@ from tensorflow.python import tf2
|
|||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.feature_column import feature_column_v2
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.saving import saving_utils
|
from tensorflow.python.keras.saving import saving_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -130,6 +133,27 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
self._assert_all_close(expected_outputs, signature_outputs)
|
self._assert_all_close(expected_outputs, signature_outputs)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
def test_trace_features_layer(self):
|
||||||
|
columns = [feature_column_v2.numeric_column('x')]
|
||||||
|
model = sequential.Sequential(
|
||||||
|
[feature_column_v2.DenseFeatures(columns)])
|
||||||
|
model_input = {'x': constant_op.constant([[1.]])}
|
||||||
|
model.predict(model_input, steps=1)
|
||||||
|
fn = saving_utils.trace_model_call(model)
|
||||||
|
self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))
|
||||||
|
|
||||||
|
columns = [feature_column_v2.numeric_column('x'),
|
||||||
|
feature_column_v2.numeric_column('y')]
|
||||||
|
model = sequential.Sequential(
|
||||||
|
[feature_column_v2.DenseFeatures(columns)])
|
||||||
|
model_input = {'x': constant_op.constant([[1.]]),
|
||||||
|
'y': constant_op.constant([[2.]])}
|
||||||
|
model.predict(model_input, steps=1)
|
||||||
|
fn = saving_utils.trace_model_call(model)
|
||||||
|
self.assertAllClose({'output_1': [[1., 2.]]},
|
||||||
|
fn({'x': [[1.]], 'y': [[2.]]}))
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test_specify_input_signature(self):
|
def test_specify_input_signature(self):
|
||||||
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
model = testing_utils.get_small_sequential_mlp(10, 3, None)
|
||||||
|
@ -1186,6 +1186,18 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||||||
**model_input).values()
|
**model_input).values()
|
||||||
self.assertAllClose([[1., 2.]], signature_output)
|
self.assertAllClose([[1., 2.]], signature_output)
|
||||||
|
|
||||||
|
def test_dense_features_layer_fit(self, cycles):
|
||||||
|
columns = [feature_column_v2.numeric_column("x")]
|
||||||
|
model = sequential.Sequential(
|
||||||
|
[feature_column_v2.DenseFeatures(columns),
|
||||||
|
core.Dense(1)])
|
||||||
|
model_input = {"x": constant_op.constant([[1.]])}
|
||||||
|
model.compile(optimizer="adam", loss="mse")
|
||||||
|
model.fit(model_input, constant_op.constant([[3.]]))
|
||||||
|
loaded = self.cycle(model, cycles)
|
||||||
|
loaded._default_save_signature(model_input)
|
||||||
|
loaded.signatures["serving_default"](**model_input)
|
||||||
|
|
||||||
|
|
||||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user