Store TypeSpec in Keras input layer, and use it when tracing the model.
PiperOrigin-RevId: 313714149 Change-Id: I893d7fecda2ac41568a6bc658251a4be14c2211d
This commit is contained in:
parent
e9ad6196a6
commit
618ff4c618
@ -20,7 +20,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.distribute import distributed_training_utils
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
@ -170,6 +172,13 @@ class InputLayer(base_layer.Layer):
|
||||
input_tensor._keras_mask = None
|
||||
node_module.Node(layer=self, outputs=input_tensor)
|
||||
|
||||
# Store type spec
|
||||
if isinstance(input_tensor, composite_tensor.CompositeTensor):
|
||||
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
|
||||
else:
|
||||
self._type_spec = tensor_spec.TensorSpec(
|
||||
shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'batch_input_shape': self._batch_input_shape,
|
||||
|
@ -57,6 +57,7 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
@ -730,6 +731,42 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.assertAllClose(layer.states, loaded_layer.states)
|
||||
self.assertAllClose(model(input_arr), loaded(input_arr))
|
||||
|
||||
def testSaveWithRaggedInputs(self):
|
||||
|
||||
class EmbeddingMerger(keras.layers.Layer):
|
||||
|
||||
def __init__(self, list_features, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._supports_ragged_inputs = True
|
||||
self.embeddings = {
|
||||
feature: keras.layers.Embedding(10, 3) for feature in list_features}
|
||||
self.mean = keras.layers.Lambda(
|
||||
math_ops.reduce_mean, arguments=dict(axis=1))
|
||||
|
||||
def call(self, inputs):
|
||||
tensors = [self.embeddings[col](inputs[col]) for col in inputs]
|
||||
tensors = [self.mean(inp) for inp in tensors]
|
||||
return keras.layers.Add()(tensors)
|
||||
|
||||
list_features = ['feature_1', 'feature_2']
|
||||
feature_1 = ragged_factory_ops.constant([[0.], [1, 3]])
|
||||
feature_2 = ragged_factory_ops.constant([[1., 2], [4]])
|
||||
f = {'feature_1': feature_1,
|
||||
'feature_2': feature_2}
|
||||
f_inputs = {
|
||||
'feature_1': keras.Input(shape=(None,), name='feature_1', ragged=True),
|
||||
'feature_2': keras.Input(shape=(None,), name='feature_2', ragged=True)}
|
||||
|
||||
out = EmbeddingMerger(list_features)(f_inputs)
|
||||
model = keras.Model(f_inputs, out)
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(model.predict(f), loaded.predict(f))
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -481,11 +481,15 @@ def dataset_is_infinite(dataset):
|
||||
|
||||
def get_tensor_spec(t, dynamic_batch=False, name=None):
|
||||
"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(t, type_spec.TypeSpec):
|
||||
spec = t
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
# TODO(b/148821952): Should these specs have a name attr?
|
||||
spec = t._type_spec # pylint: disable=protected-access
|
||||
spec = t._type_spec
|
||||
elif (hasattr(t, '_keras_history') and
|
||||
hasattr(t._keras_history[0], '_type_spec')):
|
||||
return t._keras_history[0]._type_spec
|
||||
elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
|
||||
spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
|
||||
else:
|
||||
@ -496,11 +500,12 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
|
||||
|
||||
dynamic_batch_spec = copy.deepcopy(spec)
|
||||
# RaggedTensorSpec only has a private _shape.
|
||||
shape = dynamic_batch_spec._shape.as_list() # pylint: disable=protected-access
|
||||
shape = dynamic_batch_spec._shape.as_list()
|
||||
if shape:
|
||||
shape[0] = None
|
||||
dynamic_batch_spec._shape = tensor_shape.TensorShape(shape) # pylint: disable=protected-access
|
||||
dynamic_batch_spec._shape = tensor_shape.TensorShape(shape)
|
||||
return dynamic_batch_spec
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def to_numpy_or_python_type(tensors):
|
||||
|
Loading…
Reference in New Issue
Block a user