Add TypeSpec to JSON encoder/Decoder

This change adds an argument to `register_type_spec_from_value_converter` so that the type spec class is also saved to the registry.

PiperOrigin-RevId: 345335194
Change-Id: I2561e285ad2643cc720533bf3a47b8abffa11608
This commit is contained in:
Katherine Wu 2020-12-02 16:17:50 -08:00 committed by TensorFlower Gardener
parent fc52e64328
commit bf2209a095
7 changed files with 35 additions and 0 deletions

View File

@ -492,6 +492,7 @@ class NoneTensor(composite_tensor.CompositeTensor):
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
@type_spec.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for `None` value."""

View File

@ -292,6 +292,7 @@ _pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("SparseTensorSpec")
@type_spec.register("tf.SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.sparse.SparseTensor`."""

View File

@ -114,6 +114,7 @@ class DenseSpec(type_spec.TypeSpec):
@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
"""Describes a tf.Tensor.

View File

@ -32,12 +32,14 @@ import wrapt
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
class Encoder(json.JSONEncoder):
"""JSON encoder and decoder that handles TensorShapes and tuples."""
def default(self, obj):
"""Encodes objects for types that aren't handled by the default encoder."""
if isinstance(obj, tensor_shape.TensorShape):
items = obj.as_list() if obj.rank is not None else None
return {'class_name': 'TensorShape', 'items': items}
@ -68,6 +70,9 @@ def _decode_helper(obj):
if isinstance(obj, dict) and 'class_name' in obj:
if obj['class_name'] == 'TensorShape':
return tensor_shape.TensorShape(obj['items'])
elif obj['class_name'] == 'TypeSpec':
return type_spec.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access
_decode_helper(obj['serialized']))
elif obj['class_name'] == '__tuple__':
return tuple(_decode_helper(i) for i in obj['items'])
elif obj['class_name'] == '__ellipsis__':
@ -125,4 +130,15 @@ def get_json_type(obj):
if isinstance(obj, wrapt.ObjectProxy):
return obj.__wrapped__
if isinstance(obj, type_spec.TypeSpec):
try:
type_spec_name = type_spec.get_name(type(obj))
return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
'serialized': obj._serialize()} # pylint: disable=protected-access
except ValueError:
raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
'class {} has not been registered.'
.format(obj, type(obj)))
raise TypeError('Not JSON Serializable:', obj)

View File

@ -19,7 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.platform import test
@ -50,6 +52,18 @@ class JsonUtilsTest(test.TestCase):
self.assertAllEqual(loaded['key1'], (3, 5))
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
def test_encode_decode_type_spec(self):
spec = tensor_spec.TensorSpec((1, 5), dtypes.float32)
string = json_utils.Encoder().encode(spec)
loaded = json_utils.decode(string)
self.assertEqual(spec, loaded)
invalid_type_spec = {'class_name': 'TypeSpec', 'type_spec': 'Invalid Type',
'serialized': None}
string = json_utils.Encoder().encode(invalid_type_spec)
with self.assertRaisesRegexp(ValueError, 'No TypeSpec has been registered'):
loaded = json_utils.decode(string)
if __name__ == '__main__':
test.main()

View File

@ -2178,6 +2178,7 @@ def match_row_splits_dtypes(*tensors, **kwargs):
# RaggedTensorSpec
#===============================================================================
@tf_export("RaggedTensorSpec")
@type_spec.register("tf.RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.RaggedTensor`."""

View File

@ -1313,6 +1313,7 @@ def _check_dtypes(value, dtype):
@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
"""Type specification for a `tf.TensorArray`."""