diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 02e43110697..8075cc3fd0f 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -20,7 +20,9 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.engine import base_layer
@@ -170,6 +172,13 @@ class InputLayer(base_layer.Layer):
       input_tensor._keras_mask = None
       node_module.Node(layer=self, outputs=input_tensor)
 
+    # Store type spec
+    if isinstance(input_tensor, composite_tensor.CompositeTensor):
+      self._type_spec = input_tensor._type_spec  # pylint: disable=protected-access
+    else:
+      self._type_spec = tensor_spec.TensorSpec(
+          shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name)
+
   def get_config(self):
     config = {
         'batch_input_shape': self._batch_input_shape,
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index 4ada84191dc..c6cc2f7a1d5 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -57,6 +57,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import save as tf_save
@@ -730,6 +731,42 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
     self.assertAllClose(layer.states, loaded_layer.states)
     self.assertAllClose(model(input_arr), loaded(input_arr))
 
+  def testSaveWithRaggedInputs(self):
+
+    class EmbeddingMerger(keras.layers.Layer):
+
+      def __init__(self, list_features, **kwargs):
+        super().__init__(**kwargs)
+        self._supports_ragged_inputs = True
+        self.embeddings = {
+            feature: keras.layers.Embedding(10, 3) for feature in list_features}
+        self.mean = keras.layers.Lambda(
+            math_ops.reduce_mean, arguments=dict(axis=1))
+
+      def call(self, inputs):
+        tensors = [self.embeddings[col](inputs[col]) for col in inputs]
+        tensors = [self.mean(inp) for inp in tensors]
+        return keras.layers.Add()(tensors)
+
+    list_features = ['feature_1', 'feature_2']
+    feature_1 = ragged_factory_ops.constant([[0.], [1, 3]])
+    feature_2 = ragged_factory_ops.constant([[1., 2], [4]])
+    f = {'feature_1': feature_1,
+         'feature_2': feature_2}
+    f_inputs = {
+        'feature_1': keras.Input(shape=(None,), name='feature_1', ragged=True),
+        'feature_2': keras.Input(shape=(None,), name='feature_2', ragged=True)}
+
+    out = EmbeddingMerger(list_features)(f_inputs)
+    model = keras.Model(f_inputs, out)
+    self.evaluate(variables.variables_initializer(model.variables))
+    saved_model_dir = self._save_model_dir()
+    tf_save.save(model, saved_model_dir)
+
+    loaded = keras_load.load(saved_model_dir)
+    self.evaluate(variables.variables_initializer(loaded.variables))
+    self.assertAllClose(model.predict(f), loaded.predict(f))
+
 
 class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
 
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
index b87ca1623b0..2c8f0f58f6c 100644
--- a/tensorflow/python/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -481,11 +481,15 @@ def dataset_is_infinite(dataset):
 
 def get_tensor_spec(t, dynamic_batch=False, name=None):
   """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
+  # pylint: disable=protected-access
   if isinstance(t, type_spec.TypeSpec):
     spec = t
   elif isinstance(t, composite_tensor.CompositeTensor):
     # TODO(b/148821952): Should these specs have a name attr?
-    spec = t._type_spec  # pylint: disable=protected-access
+    spec = t._type_spec
+  elif (hasattr(t, '_keras_history') and
+        hasattr(t._keras_history[0], '_type_spec')):
+    return t._keras_history[0]._type_spec
   elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
     spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
   else:
@@ -496,11 +500,12 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
 
   dynamic_batch_spec = copy.deepcopy(spec)
   # RaggedTensorSpec only has a private _shape.
-  shape = dynamic_batch_spec._shape.as_list()  # pylint: disable=protected-access
+  shape = dynamic_batch_spec._shape.as_list()
   if shape:
     shape[0] = None
-  dynamic_batch_spec._shape = tensor_shape.TensorShape(shape)  # pylint: disable=protected-access
+  dynamic_batch_spec._shape = tensor_shape.TensorShape(shape)
   return dynamic_batch_spec
+  # pylint: enable=protected-access
 
 
 def to_numpy_or_python_type(tensors):
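
For context, the new test exercises end-to-end saving and reloading of a model built from ragged keras.Input tensors. A rough public-API equivalent is sketched below; this is a minimal sketch and not part of the patch. It assumes a TF build that includes this change, substitutes tf.keras.models.load_model for the internal keras_load.load used in the test, and the EmbeddingMerger layer, data values, and save location are illustrative.

import tempfile

import tensorflow as tf


class EmbeddingMerger(tf.keras.layers.Layer):
  """Embeds each ragged feature, mean-pools over the ragged axis, sums them."""

  def __init__(self, list_features, **kwargs):
    super().__init__(**kwargs)
    # Private opt-in flag for ragged inputs, mirrored from the test in the patch.
    self._supports_ragged_inputs = True
    self.embeddings = {
        feature: tf.keras.layers.Embedding(10, 3) for feature in list_features}

  def call(self, inputs):
    # Look up embeddings for each ragged feature, reduce over the ragged axis,
    # then sum the per-feature results into a single dense output.
    pooled = [tf.reduce_mean(self.embeddings[col](inputs[col]), axis=1)
              for col in inputs]
    return tf.add_n(pooled)


features = ['feature_1', 'feature_2']
inputs = {name: tf.keras.Input(shape=(None,), name=name, ragged=True)
          for name in features}
model = tf.keras.Model(inputs, EmbeddingMerger(features)(inputs))

data = {'feature_1': tf.ragged.constant([[0], [1, 3]]),
        'feature_2': tf.ragged.constant([[1, 2], [4]])}

saved_model_dir = tempfile.mkdtemp()  # Illustrative save location.
tf.saved_model.save(model, saved_model_dir)
loaded = tf.keras.models.load_model(saved_model_dir)
print(model.predict(data))
print(loaded.predict(data))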