Support for encoding/decoding BoundedTensorSpec objects.
PiperOrigin-RevId: 303036023 Change-Id: I38691dc2ced5162b77964f44bc17a81684c34923
This commit is contained in:
parent
97099b3610
commit
3f36b5d5bf
@ -2,6 +2,7 @@ syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
@ -60,6 +61,8 @@ message StructuredValue {
|
||||
TensorSpecProto tensor_spec_value = 33;
|
||||
// Represents a value for tf.TypeSpec.
|
||||
TypeSpecProto type_spec_value = 34;
|
||||
// Represents a value for tf.BoundedTensorSpec.
|
||||
BoundedTensorSpecProto bounded_tensor_spec_value = 35;
|
||||
|
||||
// Represents a list of `Value`.
|
||||
ListValue list_value = 51;
|
||||
@ -103,13 +106,22 @@ message NamedTupleValue {
|
||||
repeated PairValue values = 2;
|
||||
}
|
||||
|
||||
// A protobuf to tf.TensorSpec.
|
||||
// A protobuf to represent tf.TensorSpec.
|
||||
message TensorSpecProto {
|
||||
string name = 1;
|
||||
tensorflow.TensorShapeProto shape = 2;
|
||||
tensorflow.DataType dtype = 3;
|
||||
}
|
||||
|
||||
// A protobuf to represent tf.BoundedTensorSpec.
|
||||
message BoundedTensorSpecProto {
|
||||
string name = 1;
|
||||
tensorflow.TensorShapeProto shape = 2;
|
||||
tensorflow.DataType dtype = 3;
|
||||
tensorflow.TensorProto minimum = 4;
|
||||
tensorflow.TensorProto maximum = 5;
|
||||
}
|
||||
|
||||
// Represents a tf.TypeSpec
|
||||
message TypeSpecProto {
|
||||
enum TypeSpecClass {
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -420,7 +421,9 @@ class _TensorSpecCodec(object):
|
||||
"""Codec for `TensorSpec`."""
|
||||
|
||||
def can_encode(self, pyobj):
|
||||
return isinstance(pyobj, tensor_spec.TensorSpec)
|
||||
# BoundedTensorSpec has its own decoder.
|
||||
return (isinstance(pyobj, tensor_spec.TensorSpec) and
|
||||
not isinstance(pyobj, tensor_spec.BoundedTensorSpec))
|
||||
|
||||
def do_encode(self, tensor_spec_value, encode_fn):
|
||||
encoded_tensor_spec = struct_pb2.StructuredValue()
|
||||
@ -449,6 +452,45 @@ class _TensorSpecCodec(object):
|
||||
StructureCoder.register_codec(_TensorSpecCodec())
|
||||
|
||||
|
||||
class _BoundedTensorSpecCodec(object):
|
||||
"""Codec for `BoundedTensorSpec`."""
|
||||
|
||||
def can_encode(self, pyobj):
|
||||
return isinstance(pyobj, tensor_spec.BoundedTensorSpec)
|
||||
|
||||
def do_encode(self, bounded_tensor_spec_value, encode_fn):
|
||||
"""Returns an encoded proto for the given `tf.BoundedTensorSpec`."""
|
||||
encoded_bounded_tensor_spec = struct_pb2.StructuredValue()
|
||||
encoded_bounded_tensor_spec.bounded_tensor_spec_value.CopyFrom(
|
||||
struct_pb2.BoundedTensorSpecProto(
|
||||
shape=encode_fn(bounded_tensor_spec_value.shape).tensor_shape_value,
|
||||
dtype=encode_fn(bounded_tensor_spec_value.dtype).tensor_dtype_value,
|
||||
name=bounded_tensor_spec_value.name,
|
||||
minimum=tensor_util.make_tensor_proto(
|
||||
bounded_tensor_spec_value.minimum),
|
||||
maximum=tensor_util.make_tensor_proto(
|
||||
bounded_tensor_spec_value.maximum)))
|
||||
return encoded_bounded_tensor_spec
|
||||
|
||||
def can_decode(self, value):
|
||||
return value.HasField("bounded_tensor_spec_value")
|
||||
|
||||
def do_decode(self, value, decode_fn):
|
||||
btsv = value.bounded_tensor_spec_value
|
||||
name = btsv.name
|
||||
return tensor_spec.BoundedTensorSpec(
|
||||
shape=decode_fn(
|
||||
struct_pb2.StructuredValue(tensor_shape_value=btsv.shape)),
|
||||
dtype=decode_fn(
|
||||
struct_pb2.StructuredValue(tensor_dtype_value=btsv.dtype)),
|
||||
minimum=tensor_util.MakeNdarray(btsv.minimum),
|
||||
maximum=tensor_util.MakeNdarray(btsv.maximum),
|
||||
name=(name if name else None))
|
||||
|
||||
|
||||
StructureCoder.register_codec(_BoundedTensorSpecCodec())
|
||||
|
||||
|
||||
class _TypeSpecCodec(object):
|
||||
"""Codec for `tf.TypeSpec`."""
|
||||
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
@ -35,6 +36,7 @@ from tensorflow.python.saved_model import nested_structure_coder
|
||||
class NestedStructureTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(NestedStructureTest, self).setUp()
|
||||
self._coder = nested_structure_coder.StructureCoder()
|
||||
|
||||
def testEncodeDecodeList(self):
|
||||
@ -271,6 +273,54 @@ class NestedStructureTest(test.TestCase):
|
||||
ValueError, "The type 'FutureTensorSpec' is not supported"):
|
||||
self._coder.decode_proto(encoded)
|
||||
|
||||
def testEncodeDecodeBoundedTensorSpec(self):
|
||||
structure = [
|
||||
tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10,
|
||||
"hello-0-10")
|
||||
]
|
||||
self.assertTrue(self._coder.can_encode(structure))
|
||||
encoded = self._coder.encode_structure(structure)
|
||||
expected = struct_pb2.StructuredValue()
|
||||
expected_list = expected.list_value
|
||||
expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value
|
||||
expected_tensor_spec.shape.dim.add().size = 1
|
||||
expected_tensor_spec.shape.dim.add().size = 2
|
||||
expected_tensor_spec.shape.dim.add().size = 3
|
||||
expected_tensor_spec.name = "hello-0-10"
|
||||
expected_tensor_spec.dtype = dtypes.int64.as_datatype_enum
|
||||
expected_tensor_spec.minimum.CopyFrom(
|
||||
tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[]))
|
||||
expected_tensor_spec.maximum.CopyFrom(
|
||||
tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[]))
|
||||
self.assertEqual(expected, encoded)
|
||||
decoded = self._coder.decode_proto(encoded)
|
||||
self.assertEqual(structure, decoded)
|
||||
|
||||
def testEncodeDecodeBoundedTensorSpecNoName(self):
|
||||
structure = [
|
||||
tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2,
|
||||
(1, 1, 20))
|
||||
]
|
||||
self.assertTrue(self._coder.can_encode(structure))
|
||||
encoded = self._coder.encode_structure(structure)
|
||||
expected = struct_pb2.StructuredValue()
|
||||
expected_list = expected.list_value
|
||||
expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value
|
||||
expected_tensor_spec.shape.dim.add().size = 28
|
||||
expected_tensor_spec.shape.dim.add().size = 28
|
||||
expected_tensor_spec.shape.dim.add().size = 3
|
||||
expected_tensor_spec.name = ""
|
||||
expected_tensor_spec.dtype = dtypes.float64.as_datatype_enum
|
||||
expected_tensor_spec.minimum.CopyFrom(
|
||||
tensor_util.make_tensor_proto([-2], dtype=dtypes.float64, shape=[]))
|
||||
expected_tensor_spec.maximum.CopyFrom(
|
||||
tensor_util.make_tensor_proto([1, 1, 20],
|
||||
dtype=dtypes.float64,
|
||||
shape=[3]))
|
||||
self.assertEqual(expected, encoded)
|
||||
decoded = self._coder.decode_proto(encoded)
|
||||
self.assertEqual(structure, decoded)
|
||||
|
||||
def testEncodeDataSetSpec(self):
|
||||
structure = [dataset_ops.DatasetSpec(
|
||||
{"rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
|
||||
|
Loading…
Reference in New Issue
Block a user