Add TypeSpec to JSON encoder/Decoder
This change adds an argument to `register_type_spec_from_value_converter` so that the type spec class is also saved to the registry. PiperOrigin-RevId: 345335194 Change-Id: I2561e285ad2643cc720533bf3a47b8abffa11608
This commit is contained in:
parent
fc52e64328
commit
bf2209a095
@ -492,6 +492,7 @@ class NoneTensor(composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||||
# functionality.
|
# functionality.
|
||||||
|
@type_spec.register("tf.NoneTensorSpec")
|
||||||
class NoneTensorSpec(type_spec.BatchableTypeSpec):
|
class NoneTensorSpec(type_spec.BatchableTypeSpec):
|
||||||
"""Type specification for `None` value."""
|
"""Type specification for `None` value."""
|
||||||
|
|
||||||
|
@ -292,6 +292,7 @@ _pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("SparseTensorSpec")
|
@tf_export("SparseTensorSpec")
|
||||||
|
@type_spec.register("tf.SparseTensorSpec")
|
||||||
class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
||||||
"""Type specification for a `tf.sparse.SparseTensor`."""
|
"""Type specification for a `tf.sparse.SparseTensor`."""
|
||||||
|
|
||||||
|
@ -114,6 +114,7 @@ class DenseSpec(type_spec.TypeSpec):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("TensorSpec")
|
@tf_export("TensorSpec")
|
||||||
|
@type_spec.register("tf.TensorSpec")
|
||||||
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
|
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
|
||||||
"""Describes a tf.Tensor.
|
"""Describes a tf.Tensor.
|
||||||
|
|
||||||
|
@ -32,12 +32,14 @@ import wrapt
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
|
|
||||||
|
|
||||||
class Encoder(json.JSONEncoder):
|
class Encoder(json.JSONEncoder):
|
||||||
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||||
|
|
||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
|
"""Encodes objects for types that aren't handled by the default encoder."""
|
||||||
if isinstance(obj, tensor_shape.TensorShape):
|
if isinstance(obj, tensor_shape.TensorShape):
|
||||||
items = obj.as_list() if obj.rank is not None else None
|
items = obj.as_list() if obj.rank is not None else None
|
||||||
return {'class_name': 'TensorShape', 'items': items}
|
return {'class_name': 'TensorShape', 'items': items}
|
||||||
@ -68,6 +70,9 @@ def _decode_helper(obj):
|
|||||||
if isinstance(obj, dict) and 'class_name' in obj:
|
if isinstance(obj, dict) and 'class_name' in obj:
|
||||||
if obj['class_name'] == 'TensorShape':
|
if obj['class_name'] == 'TensorShape':
|
||||||
return tensor_shape.TensorShape(obj['items'])
|
return tensor_shape.TensorShape(obj['items'])
|
||||||
|
elif obj['class_name'] == 'TypeSpec':
|
||||||
|
return type_spec.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access
|
||||||
|
_decode_helper(obj['serialized']))
|
||||||
elif obj['class_name'] == '__tuple__':
|
elif obj['class_name'] == '__tuple__':
|
||||||
return tuple(_decode_helper(i) for i in obj['items'])
|
return tuple(_decode_helper(i) for i in obj['items'])
|
||||||
elif obj['class_name'] == '__ellipsis__':
|
elif obj['class_name'] == '__ellipsis__':
|
||||||
@ -125,4 +130,15 @@ def get_json_type(obj):
|
|||||||
if isinstance(obj, wrapt.ObjectProxy):
|
if isinstance(obj, wrapt.ObjectProxy):
|
||||||
return obj.__wrapped__
|
return obj.__wrapped__
|
||||||
|
|
||||||
|
if isinstance(obj, type_spec.TypeSpec):
|
||||||
|
try:
|
||||||
|
type_spec_name = type_spec.get_name(type(obj))
|
||||||
|
return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
|
||||||
|
'serialized': obj._serialize()} # pylint: disable=protected-access
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
|
||||||
|
'class {} has not been registered.'
|
||||||
|
.format(obj, type(obj)))
|
||||||
|
|
||||||
raise TypeError('Not JSON Serializable:', obj)
|
raise TypeError('Not JSON Serializable:', obj)
|
||||||
|
|
||||||
|
@ -19,7 +19,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -50,6 +52,18 @@ class JsonUtilsTest(test.TestCase):
|
|||||||
self.assertAllEqual(loaded['key1'], (3, 5))
|
self.assertAllEqual(loaded['key1'], (3, 5))
|
||||||
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
|
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
|
||||||
|
|
||||||
|
def test_encode_decode_type_spec(self):
|
||||||
|
spec = tensor_spec.TensorSpec((1, 5), dtypes.float32)
|
||||||
|
string = json_utils.Encoder().encode(spec)
|
||||||
|
loaded = json_utils.decode(string)
|
||||||
|
self.assertEqual(spec, loaded)
|
||||||
|
|
||||||
|
invalid_type_spec = {'class_name': 'TypeSpec', 'type_spec': 'Invalid Type',
|
||||||
|
'serialized': None}
|
||||||
|
string = json_utils.Encoder().encode(invalid_type_spec)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'No TypeSpec has been registered'):
|
||||||
|
loaded = json_utils.decode(string)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -2178,6 +2178,7 @@ def match_row_splits_dtypes(*tensors, **kwargs):
|
|||||||
# RaggedTensorSpec
|
# RaggedTensorSpec
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
@tf_export("RaggedTensorSpec")
|
@tf_export("RaggedTensorSpec")
|
||||||
|
@type_spec.register("tf.RaggedTensorSpec")
|
||||||
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
||||||
"""Type specification for a `tf.RaggedTensor`."""
|
"""Type specification for a `tf.RaggedTensor`."""
|
||||||
|
|
||||||
|
@ -1313,6 +1313,7 @@ def _check_dtypes(value, dtype):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("TensorArraySpec")
|
@tf_export("TensorArraySpec")
|
||||||
|
@type_spec.register("tf.TensorArraySpec")
|
||||||
class TensorArraySpec(type_spec.TypeSpec):
|
class TensorArraySpec(type_spec.TypeSpec):
|
||||||
"""Type specification for a `tf.TensorArray`."""
|
"""Type specification for a `tf.TensorArray`."""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user