Store TypeSpec in Keras input layer, and use it when tracing the model.

PiperOrigin-RevId: 313714149
Change-Id: I893d7fecda2ac41568a6bc658251a4be14c2211d
This commit is contained in:
Katherine Wu 2020-05-28 21:02:58 -07:00 committed by TensorFlower Gardener
parent e9ad6196a6
commit 618ff4c618
3 changed files with 54 additions and 3 deletions

View File

@ -20,7 +20,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.engine import base_layer
@ -170,6 +172,13 @@ class InputLayer(base_layer.Layer):
input_tensor._keras_mask = None
node_module.Node(layer=self, outputs=input_tensor)
# Store type spec
if isinstance(input_tensor, composite_tensor.CompositeTensor):
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
else:
self._type_spec = tensor_spec.TensorSpec(
shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name)
def get_config(self):
config = {
'batch_input_shape': self._batch_input_shape,

View File

@ -57,6 +57,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save
@ -730,6 +731,42 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(layer.states, loaded_layer.states)
self.assertAllClose(model(input_arr), loaded(input_arr))
def testSaveWithRaggedInputs(self):
class EmbeddingMerger(keras.layers.Layer):
def __init__(self, list_features, **kwargs):
super().__init__(**kwargs)
self._supports_ragged_inputs = True
self.embeddings = {
feature: keras.layers.Embedding(10, 3) for feature in list_features}
self.mean = keras.layers.Lambda(
math_ops.reduce_mean, arguments=dict(axis=1))
def call(self, inputs):
tensors = [self.embeddings[col](inputs[col]) for col in inputs]
tensors = [self.mean(inp) for inp in tensors]
return keras.layers.Add()(tensors)
list_features = ['feature_1', 'feature_2']
feature_1 = ragged_factory_ops.constant([[0.], [1, 3]])
feature_2 = ragged_factory_ops.constant([[1., 2], [4]])
f = {'feature_1': feature_1,
'feature_2': feature_2}
f_inputs = {
'feature_1': keras.Input(shape=(None,), name='feature_1', ragged=True),
'feature_2': keras.Input(shape=(None,), name='feature_2', ragged=True)}
out = EmbeddingMerger(list_features)(f_inputs)
model = keras.Model(f_inputs, out)
self.evaluate(variables.variables_initializer(model.variables))
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
loaded = keras_load.load(saved_model_dir)
self.evaluate(variables.variables_initializer(loaded.variables))
self.assertAllClose(model.predict(f), loaded.predict(f))
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):

View File

@ -481,11 +481,15 @@ def dataset_is_infinite(dataset):
def get_tensor_spec(t, dynamic_batch=False, name=None):
"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
# pylint: disable=protected-access
if isinstance(t, type_spec.TypeSpec):
spec = t
elif isinstance(t, composite_tensor.CompositeTensor):
# TODO(b/148821952): Should these specs have a name attr?
spec = t._type_spec # pylint: disable=protected-access
spec = t._type_spec
elif (hasattr(t, '_keras_history') and
hasattr(t._keras_history[0], '_type_spec')):
return t._keras_history[0]._type_spec
elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
else:
@ -496,11 +500,12 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
dynamic_batch_spec = copy.deepcopy(spec)
# RaggedTensorSpec only has a private _shape.
shape = dynamic_batch_spec._shape.as_list() # pylint: disable=protected-access
shape = dynamic_batch_spec._shape.as_list()
if shape:
shape[0] = None
dynamic_batch_spec._shape = tensor_shape.TensorShape(shape) # pylint: disable=protected-access
dynamic_batch_spec._shape = tensor_shape.TensorShape(shape)
return dynamic_batch_spec
# pylint: enable=protected-access
def to_numpy_or_python_type(tensors):