Add TensorSpec support for CompositeTensors (such as SparseTensor and RaggedTensor).
PiperOrigin-RevId: 245272144
This commit is contained in:
parent
6b8c6cb57f
commit
ccd35296e8
@ -699,11 +699,13 @@ py_library(
|
|||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":function",
|
":function",
|
||||||
":graph_util",
|
":graph_util",
|
||||||
|
":indexed_slices_tensor_spec",
|
||||||
":lib",
|
":lib",
|
||||||
":platform",
|
":platform",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
":random_seed",
|
":random_seed",
|
||||||
":sparse_tensor",
|
":sparse_tensor",
|
||||||
|
":sparse_tensor_spec",
|
||||||
":tensor_spec",
|
":tensor_spec",
|
||||||
":tensor_util",
|
":tensor_util",
|
||||||
":util",
|
":util",
|
||||||
@ -1191,6 +1193,30 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "sparse_tensor_spec",
|
||||||
|
srcs = ["framework/sparse_tensor_spec.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":dtypes",
|
||||||
|
":sparse_tensor",
|
||||||
|
":tensor_shape",
|
||||||
|
":tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "indexed_slices_tensor_spec",
|
||||||
|
srcs = ["framework/indexed_slices_tensor_spec.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":dtypes",
|
||||||
|
":framework_ops",
|
||||||
|
":sparse_tensor",
|
||||||
|
":tensor_shape",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "tensor_util",
|
name = "tensor_util",
|
||||||
srcs = ["framework/tensor_util.py"],
|
srcs = ["framework/tensor_util.py"],
|
||||||
|
@ -1039,7 +1039,7 @@ class FunctionSpec(object):
|
|||||||
|
|
||||||
self._input_signature = tuple(input_signature)
|
self._input_signature = tuple(input_signature)
|
||||||
self._flat_input_signature = tuple(nest.flatten(input_signature,
|
self._flat_input_signature = tuple(nest.flatten(input_signature,
|
||||||
expand_composites=False))
|
expand_composites=True))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fullargspec(self):
|
def fullargspec(self):
|
||||||
@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs):
|
|||||||
|
|
||||||
def validate_signature(signature):
|
def validate_signature(signature):
|
||||||
if any(not isinstance(arg, tensor_spec.TensorSpec)
|
if any(not isinstance(arg, tensor_spec.TensorSpec)
|
||||||
for arg in nest.flatten(signature, expand_composites=False)):
|
for arg in nest.flatten(signature, expand_composites=True)):
|
||||||
raise TypeError("Invalid input_signature %s; input_signature must be "
|
raise TypeError("Invalid input_signature %s; input_signature must be "
|
||||||
"a possibly nested sequence of TensorSpec objects.")
|
"a possibly nested sequence of TensorSpec objects.")
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import sparse_tensor_spec
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_ops
|
from tensorflow.python.framework import test_ops
|
||||||
@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor_spec
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import training_ops
|
from tensorflow.python.training import training_ops
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
|
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
|
||||||
(sparse_tensor.SparseTensor,
|
(sparse_tensor.SparseTensor,
|
||||||
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
|
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
|
||||||
|
(ragged_tensor.RaggedTensor.from_row_lengths,
|
||||||
|
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
|
||||||
|
[ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]),
|
||||||
|
(ragged_tensor.RaggedTensor.from_nested_row_lengths,
|
||||||
|
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
|
||||||
|
[ragged_tensor_spec.ragged_tensor_spec([None, None, None],
|
||||||
|
dtypes.int32)]),
|
||||||
|
(sparse_tensor.SparseTensor,
|
||||||
|
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
|
||||||
|
[sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]),
|
||||||
]) # pyformat: disable
|
]) # pyformat: disable
|
||||||
def testCompositeAsArgumentTensorWithDefun(self,
|
def testCompositeAsArgumentTensorWithDefun(self,
|
||||||
factory_fn,
|
factory_fn,
|
||||||
|
@ -27,7 +27,6 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None):
|
|||||||
Identical structure that has TensorSpec objects instead of Tensors and
|
Identical structure that has TensorSpec objects instead of Tensors and
|
||||||
UknownArgument instead of any unsupported types.
|
UknownArgument instead of any unsupported types.
|
||||||
"""
|
"""
|
||||||
structure = composite_tensor.replace_composites_with_components(structure)
|
|
||||||
def encode_arg(arg, path):
|
def encode_arg(arg, path):
|
||||||
"""A representation for this argument, for converting into signatures."""
|
"""A representation for this argument, for converting into signatures."""
|
||||||
if isinstance(arg, ops.Tensor):
|
if isinstance(arg, ops.Tensor):
|
||||||
|
65
tensorflow/python/framework/indexed_slices_tensor_spec.py
Normal file
65
tensorflow/python/framework/indexed_slices_tensor_spec.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TensorSpec factory for sparse tensors."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
|
||||||
|
|
||||||
|
def indexed_slices_tensor_spec(shape=None,
|
||||||
|
dtype=dtypes.float32,
|
||||||
|
num_slices=None,
|
||||||
|
has_dense_shape=True,
|
||||||
|
name=None):
|
||||||
|
"""Returns a tensor specification for a IndexedSlices.
|
||||||
|
|
||||||
|
Returns an object which can be passed to `tf.function` (or other
|
||||||
|
functions that expect `TensorSpec`s) to specify shape constraints
|
||||||
|
for a `IndexedSlices` argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the IndexedSlices, or `None` to allow any shape.
|
||||||
|
The returned specification object depends only on `shape[1:]`.
|
||||||
|
dtype: Data type of values in the IndexedSlices.
|
||||||
|
num_slices: Number of slices. Default allows for any number of slices.
|
||||||
|
has_dense_shape: Indicates whether the IndexedSlices is expected to have a
|
||||||
|
`dense_shape` component.
|
||||||
|
name: Optional name prefix for the `TensorSpec`s.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An object describing the `values`, `indices` and `dense_shape` tensors
|
||||||
|
that comprise the `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
dtype = dtypes.as_dtype(dtype)
|
||||||
|
shape = tensor_shape.TensorShape(shape)
|
||||||
|
num_slices = tensor_shape.Shape([num_slices])
|
||||||
|
|
||||||
|
values = tensor_spec.TensorSpec(
|
||||||
|
num_slices.concatenate(shape[1:]), dtype, name)
|
||||||
|
indices = tensor_spec.TensorSpec(num_slices, dtypes.int64,
|
||||||
|
("%s.indices" % name) if name else None)
|
||||||
|
if has_dense_shape:
|
||||||
|
dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64,
|
||||||
|
("%s.dense_shape" %
|
||||||
|
name) if name else None)
|
||||||
|
else:
|
||||||
|
dense_shape = None
|
||||||
|
return ops.IndexedSlices(values, indices, dense_shape)
|
@ -62,8 +62,14 @@ from tensorflow.python.util import memory
|
|||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util import tf_stack
|
from tensorflow.python.util import tf_stack
|
||||||
from tensorflow.python.util.deprecation import deprecated_args
|
from tensorflow.python.util.deprecation import deprecated_args
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
# This is to avoid a circular dependency: ops -> tensor_spec -> ops
|
||||||
|
tensor_spec = LazyLoader(
|
||||||
|
"tensor_spec", globals(),
|
||||||
|
"tensorflow.python.framework.tensor_spec")
|
||||||
|
|
||||||
# Temporary global switches determining if we should enable the work-in-progress
|
# Temporary global switches determining if we should enable the work-in-progress
|
||||||
# calls to the C API. These will be removed once all functionality is supported.
|
# calls to the C API. These will be removed once all functionality is supported.
|
||||||
_USE_C_API = True
|
_USE_C_API = True
|
||||||
@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
def __init__(self, values, indices, dense_shape=None):
|
def __init__(self, values, indices, dense_shape=None):
|
||||||
"""Creates an `IndexedSlices`."""
|
"""Creates an `IndexedSlices`."""
|
||||||
_get_graph_from_inputs([values, indices, dense_shape])
|
if not isinstance(values, tensor_spec.TensorSpec):
|
||||||
|
_get_graph_from_inputs([values, indices, dense_shape])
|
||||||
self._values = values
|
self._values = values
|
||||||
self._indices = indices
|
self._indices = indices
|
||||||
self._dense_shape = dense_shape
|
self._dense_shape = dense_shape
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
values: A 1-D tensor of any type and shape `[N]`.
|
values: A 1-D tensor of any type and shape `[N]`.
|
||||||
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
|
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
|
if isinstance(indices, tensor_spec.TensorSpec):
|
||||||
indices = ops.convert_to_tensor(
|
if not isinstance(values, tensor_spec.TensorSpec):
|
||||||
indices, name="indices", dtype=dtypes.int64)
|
raise TypeError("Expected values to be a TensorSpec")
|
||||||
# TODO(touts): Consider adding mutable_values() when 'values'
|
if not isinstance(dense_shape, tensor_spec.TensorSpec):
|
||||||
# is a VariableOp and updating users of SparseTensor.
|
raise TypeError("Expected dense_shape to be a TensorSpec")
|
||||||
values = ops.internal_convert_to_tensor(values, name="values")
|
if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64:
|
||||||
dense_shape = ops.convert_to_tensor(
|
raise TypeError("indices and dense_shape must have dtype=int64")
|
||||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
else:
|
||||||
|
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
|
||||||
|
indices = ops.convert_to_tensor(
|
||||||
|
indices, name="indices", dtype=dtypes.int64)
|
||||||
|
# TODO(touts): Consider adding mutable_values() when 'values'
|
||||||
|
# is a VariableOp and updating users of SparseTensor.
|
||||||
|
values = ops.internal_convert_to_tensor(values, name="values")
|
||||||
|
dense_shape = ops.convert_to_tensor(
|
||||||
|
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||||
self._indices = indices
|
self._indices = indices
|
||||||
self._values = values
|
self._values = values
|
||||||
self._dense_shape = dense_shape
|
self._dense_shape = dense_shape
|
||||||
|
|
||||||
indices_shape = indices.get_shape().with_rank(2)
|
indices_shape = indices.shape.with_rank(2)
|
||||||
values_shape = values.get_shape().with_rank(1)
|
values_shape = values.shape.with_rank(1)
|
||||||
dense_shape_shape = dense_shape.get_shape().with_rank(1)
|
dense_shape_shape = dense_shape.shape.with_rank(1)
|
||||||
|
|
||||||
# Assert number of rows in indices match the number of elements in values.
|
# Assert number of rows in indices match the number of elements in values.
|
||||||
indices_shape.dims[0].merge_with(values_shape.dims[0])
|
indices_shape.dims[0].merge_with(values_shape.dims[0])
|
||||||
@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
return [
|
return [
|
||||||
tensor_shape.TensorShape([None, rank]), # indices
|
tensor_shape.TensorShape([None, rank]), # indices
|
||||||
tensor_shape.TensorShape([None]), # values
|
tensor_shape.TensorShape([None]), # values
|
||||||
tensor_shape.TensorShape([rank])
|
tensor_shape.TensorShape([rank]) # dense_shape
|
||||||
] # dense_shape
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _is_graph_tensor(self):
|
def _is_graph_tensor(self):
|
||||||
|
56
tensorflow/python/framework/sparse_tensor_spec.py
Normal file
56
tensorflow/python/framework/sparse_tensor_spec.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TensorSpec factory for sparse tensors."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_tensor_spec(shape=None,
|
||||||
|
dtype=dtypes.float32,
|
||||||
|
num_values=None,
|
||||||
|
name=None):
|
||||||
|
"""Returns a tensor specification for a SparseTensor.
|
||||||
|
|
||||||
|
Returns an object which can be passed to `tf.function` (or other
|
||||||
|
functions that expect `TensorSpec`s) to specify shape constraints
|
||||||
|
for a `SparseTensor` argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the SparseTensor, or `None` to allow any shape. The
|
||||||
|
returned specification object depends only on `shape.ndims`.
|
||||||
|
dtype: Data type of values in the SparseTensor.
|
||||||
|
num_values: The number of values in the SparseTensor, or `None` to allow any
|
||||||
|
number of values.
|
||||||
|
name: Optional name prefix for the `TensorSpec`s.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An object describing the `values`, `indices` and `dense_shape` tensors
|
||||||
|
that comprise the `SparseTensor`.
|
||||||
|
"""
|
||||||
|
dtype = dtypes.as_dtype(dtype)
|
||||||
|
rank = tensor_shape.TensorShape(shape).rank
|
||||||
|
indices = tensor_spec.TensorSpec([num_values, rank], dtypes.int64,
|
||||||
|
("%s.indices" % name) if name else None)
|
||||||
|
values = tensor_spec.TensorSpec([num_values], dtype, name)
|
||||||
|
dense_shape = tensor_spec.TensorSpec(
|
||||||
|
[rank], dtypes.int64, ("%s.dense_shape" % name) if name else None)
|
||||||
|
return sparse_tensor.SparseTensor(indices, values, dense_shape)
|
@ -40,6 +40,7 @@ py_library(
|
|||||||
":ragged_string_ops",
|
":ragged_string_ops",
|
||||||
":ragged_tensor",
|
":ragged_tensor",
|
||||||
":ragged_tensor_shape",
|
":ragged_tensor_shape",
|
||||||
|
":ragged_tensor_spec",
|
||||||
":ragged_tensor_value",
|
":ragged_tensor_value",
|
||||||
":ragged_util",
|
":ragged_util",
|
||||||
":ragged_where_op",
|
":ragged_where_op",
|
||||||
@ -434,6 +435,18 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "ragged_tensor_spec",
|
||||||
|
srcs = ["ragged_tensor_spec.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":ragged_tensor",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:tensor_shape",
|
||||||
|
"//tensorflow/python:tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
# RaggedTensor Tests
|
# RaggedTensor Tests
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
|||||||
"of the factory methods instead (e.g., "
|
"of the factory methods instead (e.g., "
|
||||||
"RaggedTensor.from_row_lengths())")
|
"RaggedTensor.from_row_lengths())")
|
||||||
|
|
||||||
# Validate the arguments.
|
is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
|
||||||
if not isinstance(values, (RaggedTensor, ops.Tensor)):
|
if is_tensor_spec:
|
||||||
raise TypeError("values must be a Tensor or RaggedTensor.")
|
if not (isinstance(values, tensor_spec.TensorSpec) or
|
||||||
if not isinstance(row_splits, ops.Tensor):
|
(isinstance(values, RaggedTensor) and
|
||||||
raise TypeError("Row-partitioning argument must be a Tensor.")
|
isinstance(values.row_splits, tensor_spec.TensorSpec))):
|
||||||
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
raise TypeError("Expected values to be a TensorSpec, got %r" % values)
|
||||||
raise ValueError("Row-partitioning argument must be int32 or int64")
|
else:
|
||||||
values.shape.with_rank_at_least(1)
|
# Validate the arguments.
|
||||||
|
if not isinstance(row_splits, ops.Tensor):
|
||||||
|
raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
|
||||||
|
row_splits)
|
||||||
|
if not isinstance(values, (RaggedTensor, ops.Tensor)):
|
||||||
|
raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
|
||||||
|
values)
|
||||||
|
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
||||||
|
raise ValueError("Row-partitioning argument must be int32 or int64")
|
||||||
|
|
||||||
|
# Validate shapes & dtypes.
|
||||||
row_splits.shape.assert_has_rank(1)
|
row_splits.shape.assert_has_rank(1)
|
||||||
row_splits.set_shape([None])
|
values.shape.with_rank_at_least(1)
|
||||||
|
if not is_tensor_spec:
|
||||||
|
row_splits.set_shape([None])
|
||||||
if isinstance(values, RaggedTensor):
|
if isinstance(values, RaggedTensor):
|
||||||
assert row_splits.dtype == values.row_splits.dtype
|
assert row_splits.dtype == values.row_splits.dtype
|
||||||
|
|
||||||
@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
|||||||
raise TypeError("validate must have type bool")
|
raise TypeError("validate must have type bool")
|
||||||
if isinstance(row_splits, (list, tuple)) and not row_splits:
|
if isinstance(row_splits, (list, tuple)) and not row_splits:
|
||||||
raise ValueError("row_splits tensor may not be empty.")
|
raise ValueError("row_splits tensor may not be empty.")
|
||||||
|
if isinstance(row_splits, tensor_spec.TensorSpec):
|
||||||
|
return cls(values=values, row_splits=row_splits, internal=True)
|
||||||
|
|
||||||
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
|
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
|
||||||
values, row_splits = cls._convert_values_and_row_partition(
|
values, row_splits = cls._convert_values_and_row_partition(
|
||||||
|
77
tensorflow/python/ops/ragged/ragged_tensor_spec.py
Normal file
77
tensorflow/python/ops/ragged/ragged_tensor_spec.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TensorSpec factory for ragged tensors."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
|
||||||
|
ragged_rank=None, row_splits_dtype=dtypes.int64,
|
||||||
|
name=None):
|
||||||
|
"""Returns a tensor specification for a RaggedTensor.
|
||||||
|
|
||||||
|
Returns an object which can be passed to `tf.function` (or other
|
||||||
|
functions that expect `TensorSpec`s) to specify shape constraints
|
||||||
|
for a `RaggedTensor` argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the RaggedTensor, or `None` to allow any shape.
|
||||||
|
dtype: Data type of values in the RaggedTensor.
|
||||||
|
ragged_rank: Python integer, the ragged rank of the RaggedTensor
|
||||||
|
to be described. Defaults to `shape.ndims - 1`.
|
||||||
|
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
|
||||||
|
One of `tf.int32` or `tf.int64`.
|
||||||
|
name: Optional name prefix for the `TensorSpec`s.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An object describing the `flat_values` and `nested_row_splits` tensors
|
||||||
|
that comprise the `RaggedTensor`.
|
||||||
|
"""
|
||||||
|
dtype = dtypes.as_dtype(dtype)
|
||||||
|
shape = tensor_shape.TensorShape(shape)
|
||||||
|
if ragged_rank is None:
|
||||||
|
if shape.ndims is None:
|
||||||
|
raise ValueError("Must specify ragged_rank or a shape with known rank.")
|
||||||
|
ragged_rank = shape.ndims - 1
|
||||||
|
elif not isinstance(ragged_rank, int):
|
||||||
|
raise TypeError("ragged_rank must be an int")
|
||||||
|
if ragged_rank == 0:
|
||||||
|
return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
|
||||||
|
|
||||||
|
result = tensor_spec.TensorSpec(
|
||||||
|
tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
|
||||||
|
dtype, name)
|
||||||
|
|
||||||
|
for i in range(ragged_rank - 1, 0, -1):
|
||||||
|
splits = tensor_spec.TensorSpec(
|
||||||
|
[None], row_splits_dtype,
|
||||||
|
"%s.row_splits_%d" % (name, i) if name else None)
|
||||||
|
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
|
||||||
|
|
||||||
|
outer_dim = tensor_shape.dimension_at_index(shape, 0)
|
||||||
|
splits_shape = [None if outer_dim is None else outer_dim + 1]
|
||||||
|
splits = tensor_spec.TensorSpec(
|
||||||
|
splits_shape, row_splits_dtype,
|
||||||
|
"%s.row_splits_0" % name if name else None)
|
||||||
|
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
|
||||||
|
|
||||||
|
return result
|
Loading…
x
Reference in New Issue
Block a user