Add TypeSpec to JSON encoder/Decoder
This change adds an argument to `register_type_spec_from_value_converter` so that the type spec class is also saved to the registry. PiperOrigin-RevId: 345335194 Change-Id: I2561e285ad2643cc720533bf3a47b8abffa11608
This commit is contained in:
parent
fc52e64328
commit
bf2209a095
@ -492,6 +492,7 @@ class NoneTensor(composite_tensor.CompositeTensor):
|
||||
|
||||
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||
# functionality.
|
||||
@type_spec.register("tf.NoneTensorSpec")
|
||||
class NoneTensorSpec(type_spec.BatchableTypeSpec):
|
||||
"""Type specification for `None` value."""
|
||||
|
||||
|
@ -292,6 +292,7 @@ _pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
|
||||
|
||||
|
||||
@tf_export("SparseTensorSpec")
|
||||
@type_spec.register("tf.SparseTensorSpec")
|
||||
class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
||||
"""Type specification for a `tf.sparse.SparseTensor`."""
|
||||
|
||||
|
@ -114,6 +114,7 @@ class DenseSpec(type_spec.TypeSpec):
|
||||
|
||||
|
||||
@tf_export("TensorSpec")
|
||||
@type_spec.register("tf.TensorSpec")
|
||||
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
|
||||
"""Describes a tf.Tensor.
|
||||
|
||||
|
@ -32,12 +32,14 @@ import wrapt
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import type_spec
|
||||
|
||||
|
||||
class Encoder(json.JSONEncoder):
|
||||
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||
|
||||
def default(self, obj):
|
||||
"""Encodes objects for types that aren't handled by the default encoder."""
|
||||
if isinstance(obj, tensor_shape.TensorShape):
|
||||
items = obj.as_list() if obj.rank is not None else None
|
||||
return {'class_name': 'TensorShape', 'items': items}
|
||||
@ -68,6 +70,9 @@ def _decode_helper(obj):
|
||||
if isinstance(obj, dict) and 'class_name' in obj:
|
||||
if obj['class_name'] == 'TensorShape':
|
||||
return tensor_shape.TensorShape(obj['items'])
|
||||
elif obj['class_name'] == 'TypeSpec':
|
||||
return type_spec.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access
|
||||
_decode_helper(obj['serialized']))
|
||||
elif obj['class_name'] == '__tuple__':
|
||||
return tuple(_decode_helper(i) for i in obj['items'])
|
||||
elif obj['class_name'] == '__ellipsis__':
|
||||
@ -125,4 +130,15 @@ def get_json_type(obj):
|
||||
if isinstance(obj, wrapt.ObjectProxy):
|
||||
return obj.__wrapped__
|
||||
|
||||
if isinstance(obj, type_spec.TypeSpec):
|
||||
try:
|
||||
type_spec_name = type_spec.get_name(type(obj))
|
||||
return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
|
||||
'serialized': obj._serialize()} # pylint: disable=protected-access
|
||||
except ValueError:
|
||||
raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
|
||||
'class {} has not been registered.'
|
||||
.format(obj, type(obj)))
|
||||
|
||||
raise TypeError('Not JSON Serializable:', obj)
|
||||
|
||||
|
@ -19,7 +19,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -50,6 +52,18 @@ class JsonUtilsTest(test.TestCase):
|
||||
self.assertAllEqual(loaded['key1'], (3, 5))
|
||||
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
|
||||
|
||||
def test_encode_decode_type_spec(self):
|
||||
spec = tensor_spec.TensorSpec((1, 5), dtypes.float32)
|
||||
string = json_utils.Encoder().encode(spec)
|
||||
loaded = json_utils.decode(string)
|
||||
self.assertEqual(spec, loaded)
|
||||
|
||||
invalid_type_spec = {'class_name': 'TypeSpec', 'type_spec': 'Invalid Type',
|
||||
'serialized': None}
|
||||
string = json_utils.Encoder().encode(invalid_type_spec)
|
||||
with self.assertRaisesRegexp(ValueError, 'No TypeSpec has been registered'):
|
||||
loaded = json_utils.decode(string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -2178,6 +2178,7 @@ def match_row_splits_dtypes(*tensors, **kwargs):
|
||||
# RaggedTensorSpec
|
||||
#===============================================================================
|
||||
@tf_export("RaggedTensorSpec")
|
||||
@type_spec.register("tf.RaggedTensorSpec")
|
||||
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
||||
"""Type specification for a `tf.RaggedTensor`."""
|
||||
|
||||
|
@ -1313,6 +1313,7 @@ def _check_dtypes(value, dtype):
|
||||
|
||||
|
||||
@tf_export("TensorArraySpec")
|
||||
@type_spec.register("tf.TensorArraySpec")
|
||||
class TensorArraySpec(type_spec.TypeSpec):
|
||||
"""Type specification for a `tf.TensorArray`."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user