Add new "TypeSpec" classes, which can be used to specify types for objects accepted or returned by TensorFlow APIs. For example, TypeSpec subclasses may be used in tf.function's input_signature argument, and in tf.while_loop's shape_invariant argument.

PiperOrigin-RevId: 253120113
This commit is contained in:
Edward Loper 2019-06-13 15:41:12 -07:00 committed by TensorFlower Gardener
parent 0d44856773
commit da392ebd7c
25 changed files with 426 additions and 21 deletions

View File

@ -2371,7 +2371,7 @@ def to_variant(dataset):
# TODO(b/133606651) Rename this class to DatasetSpec
@tf_export("data.experimental.DatasetStructure")
@tf_export("data.DatasetSpec", "data.experimental.DatasetStructure")
class DatasetStructure(type_spec.BatchableTypeSpec):
"""Type specification for `tf.data.Dataset`."""

View File

@ -154,7 +154,7 @@ class _OptionalImpl(Optional):
# TODO(b/133606651) Rename this class to OptionalSpec
@tf_export("data.experimental.OptionalStructure")
@tf_export("OptionalSpec", "data.experimental.OptionalStructure")
class OptionalStructure(type_spec.TypeSpec):
"""Represents an optional potentially containing a structured value."""

View File

@ -178,7 +178,7 @@ IndexedSlicesValue = collections.namedtuple(
"IndexedSlicesValue", ["values", "indices", "dense_shape"])
# TODO(b/133606651) Export this as tf.IndexedSlicesSpec.
@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
"""Type specification for a `tf.IndexedSlices`."""

View File

@ -256,31 +256,31 @@ tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
# TODO(b/133606651) Export this as tf.SparseTensorSpec.
@tf_export("SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.SparseTensor`."""
__slots__ = ["_dense_shape", "_dtype"]
__slots__ = ["_shape", "_dtype"]
value_type = property(lambda self: SparseTensor)
def __init__(self, dense_shape=None, dtype=dtypes.float32):
def __init__(self, shape=None, dtype=dtypes.float32):
"""Constructs a type specification for a `tf.SparseTensor`.
Args:
dense_shape: The dense shape of the `SparseTensor`, or `None` to allow
shape: The dense shape of the `SparseTensor`, or `None` to allow
any dense shape.
dtype: `tf.DType` of values in the `SparseTensor`.
"""
self._dense_shape = tensor_shape.as_shape(dense_shape)
self._shape = tensor_shape.as_shape(shape)
self._dtype = dtypes.as_dtype(dtype)
def _serialize(self):
return (self._dense_shape, self._dtype)
return (self._shape, self._dtype)
@property
def _component_specs(self):
rank = self._dense_shape.ndims
rank = self._shape.ndims
num_values = None
return [
tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
@ -314,7 +314,7 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
def _to_batched_tensor_list(self, value):
dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
if self._dense_shape.merge_with(dense_shape).ndims == 0:
if self._shape.merge_with(dense_shape).ndims == 0:
raise ValueError(
"Unbatching a sparse tensor is only supported for rank >= 1")
return [gen_sparse_ops.serialize_many_sparse(
@ -324,26 +324,26 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
def _from_compatible_tensor_list(self, tensor_list):
tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
result = SparseTensor(*tensor_list)
rank = self._dense_shape.ndims
rank = self._shape.ndims
result.indices.set_shape([None, rank])
result.dense_shape.set_shape([rank])
return result
def _batch(self, batch_size):
return SparseTensorSpec(
tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape),
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._dtype)
def _unbatch(self):
if self._dense_shape.ndims == 0:
if self._shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorSpec(self._dense_shape[1:], self._dtype)
return SparseTensorSpec(self._shape[1:], self._dtype)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._dense_shape
return self._shape
def _to_legacy_output_classes(self):
return SparseTensor

View File

@ -41,9 +41,8 @@ ops = LazyLoader(
"tensorflow.python.framework.ops")
# TODO(b/133606651) Export this as "TypeSpec" (or experimental.TypeSpec?) and
# deprecate the tf.data.experimental.Structure endpoint.
@tf_export("data.experimental.Structure")
# TODO(b/133606651) Deprecate the tf.data.experimental.Structure endpoint.
@tf_export("TypeSpec", "data.experimental.Structure")
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
"""Specifies a TensorFlow value type.

View File

@ -1879,7 +1879,7 @@ def match_row_splits_dtypes(*tensors, **kwargs):
#===============================================================================
# RaggedTensorSpec
#===============================================================================
# TODO(b/133606651) Export this as tf.RaggedTensorSpec.
@tf_export("RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.RaggedTensor`."""

View File

@ -1346,7 +1346,7 @@ def _check_dtypes(value, dtype):
"".join(traceback.format_stack())))
# TODO(b/133606651) Export this as tf.TensorArraySpec.
@tf_export("TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
"""Type specification for a `tf.TensorArray`."""

View File

@ -0,0 +1,22 @@
path: "tensorflow.IndexedSlicesSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlicesSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'indices_dtype\', \'dense_shape_dtype\', \'indices_shape\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \"<dtype: \'int64\'>\", \'True\', \'None\'], "
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,26 @@
path: "tensorflow.OptionalSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.OptionalStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.RaggedTensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "from_value"
argspec: "args=[\'cls\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.SparseTensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_value"
argspec: "args=[\'cls\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,26 @@
path: "tensorflow.TensorArraySpec"
tf_class {
is_instance: "<class \'tensorflow.python.ops.tensor_array_ops.TensorArraySpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_shape\', \'dtype\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'False\', \'True\'], "
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,20 @@
path: "tensorflow.TypeSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.data.DatasetSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "Dataset"
mtype: "<type \'type\'>"
}
member {
name: "DatasetSpec"
mtype: "<type \'type\'>"
}
member {
name: "FixedLengthRecordDataset"
mtype: "<type \'type\'>"

View File

@ -120,6 +120,10 @@ tf_module {
name: "IndexedSlices"
mtype: "<type \'type\'>"
}
member {
name: "IndexedSlicesSpec"
mtype: "<type \'type\'>"
}
member {
name: "InteractiveSession"
mtype: "<type \'type\'>"
@ -164,6 +168,10 @@ tf_module {
name: "OptimizerOptions"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
name: "OptionalSpec"
mtype: "<type \'type\'>"
}
member {
name: "PaddingFIFOQueue"
mtype: "<type \'type\'>"
@ -184,6 +192,10 @@ tf_module {
name: "RaggedTensor"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensorSpec"
mtype: "<type \'type\'>"
}
member {
name: "RandomShuffleQueue"
mtype: "<type \'type\'>"
@ -224,6 +236,10 @@ tf_module {
name: "SparseTensor"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorSpec"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorValue"
mtype: "<type \'type\'>"
@ -248,6 +264,10 @@ tf_module {
name: "TensorArray"
mtype: "<type \'type\'>"
}
member {
name: "TensorArraySpec"
mtype: "<type \'type\'>"
}
member {
name: "TensorInfo"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
@ -264,6 +284,10 @@ tf_module {
name: "TextLineReader"
mtype: "<type \'type\'>"
}
member {
name: "TypeSpec"
mtype: "<type \'type\'>"
}
member {
name: "UnconnectedGradients"
mtype: "<class \'enum.EnumMeta\'>"

View File

@ -0,0 +1,22 @@
path: "tensorflow.IndexedSlicesSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlicesSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'indices_dtype\', \'dense_shape_dtype\', \'indices_shape\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \"<dtype: \'int64\'>\", \'True\', \'None\'], "
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,26 @@
path: "tensorflow.OptionalSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.OptionalStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.RaggedTensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "from_value"
argspec: "args=[\'cls\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.SparseTensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_value"
argspec: "args=[\'cls\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,26 @@
path: "tensorflow.TensorArraySpec"
tf_class {
is_instance: "<class \'tensorflow.python.ops.tensor_array_ops.TensorArraySpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_shape\', \'dtype\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'False\', \'True\'], "
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,20 @@
path: "tensorflow.TypeSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,27 @@
path: "tensorflow.data.DatasetSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "Dataset"
mtype: "<type \'type\'>"
}
member {
name: "DatasetSpec"
mtype: "<type \'type\'>"
}
member {
name: "FixedLengthRecordDataset"
mtype: "<type \'type\'>"

View File

@ -28,6 +28,10 @@ tf_module {
name: "IndexedSlices"
mtype: "<type \'type\'>"
}
member {
name: "IndexedSlicesSpec"
mtype: "<type \'type\'>"
}
member {
name: "Module"
mtype: "<type \'type\'>"
@ -36,10 +40,18 @@ tf_module {
name: "Operation"
mtype: "<type \'type\'>"
}
member {
name: "OptionalSpec"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensor"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensorSpec"
mtype: "<type \'type\'>"
}
member {
name: "RegisterGradient"
mtype: "<type \'type\'>"
@ -48,6 +60,10 @@ tf_module {
name: "SparseTensor"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorSpec"
mtype: "<type \'type\'>"
}
member {
name: "Tensor"
mtype: "<type \'type\'>"
@ -56,6 +72,10 @@ tf_module {
name: "TensorArray"
mtype: "<type \'type\'>"
}
member {
name: "TensorArraySpec"
mtype: "<type \'type\'>"
}
member {
name: "TensorShape"
mtype: "<type \'type\'>"
@ -64,6 +84,10 @@ tf_module {
name: "TensorSpec"
mtype: "<type \'type\'>"
}
member {
name: "TypeSpec"
mtype: "<type \'type\'>"
}
member {
name: "UnconnectedGradients"
mtype: "<class \'enum.EnumMeta\'>"