Add TensorSpec support for CompositeTensors (such as SparseTensor and RaggedTensor).

PiperOrigin-RevId: 245272144
This commit is contained in:
Edward Loper 2019-04-25 11:18:25 -07:00 committed by TensorFlower Gardener
parent 6b8c6cb57f
commit ccd35296e8
11 changed files with 305 additions and 27 deletions

View File

@ -699,11 +699,13 @@ py_library(
":framework_for_generated_wrappers",
":function",
":graph_util",
":indexed_slices_tensor_spec",
":lib",
":platform",
":pywrap_tensorflow",
":random_seed",
":sparse_tensor",
":sparse_tensor_spec",
":tensor_spec",
":tensor_util",
":util",
@ -1191,6 +1193,30 @@ py_library(
],
)
py_library(
name = "sparse_tensor_spec",
srcs = ["framework/sparse_tensor_spec.py"],
srcs_version = "PY2AND3",
deps = [
":dtypes",
":sparse_tensor",
":tensor_shape",
":tensor_spec",
],
)
py_library(
name = "indexed_slices_tensor_spec",
srcs = ["framework/indexed_slices_tensor_spec.py"],
srcs_version = "PY2AND3",
deps = [
":dtypes",
":framework_ops",
":sparse_tensor",
":tensor_shape",
],
)
py_library(
name = "tensor_util",
srcs = ["framework/tensor_util.py"],

View File

@ -1039,7 +1039,7 @@ class FunctionSpec(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature,
expand_composites=False))
expand_composites=True))
@property
def fullargspec(self):
@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs):
def validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature, expand_composites=False)):
for arg in nest.flatten(signature, expand_composites=True)):
raise TypeError("Invalid input_signature %s; input_signature must be "
"a possibly nested sequence of TensorSpec objects.")

View File

@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import sparse_tensor_spec
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_spec
from tensorflow.python.platform import test
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
(sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
(ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
[ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]),
(ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
[ragged_tensor_spec.ragged_tensor_spec([None, None, None],
dtypes.int32)]),
(sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
[sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]),
]) # pyformat: disable
def testCompositeAsArgumentTensorWithDefun(self,
factory_fn,

View File

@ -27,7 +27,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None):
Identical structure that has TensorSpec objects instead of Tensors and
UknownArgument instead of any unsupported types.
"""
structure = composite_tensor.replace_composites_with_components(structure)
def encode_arg(arg, path):
"""A representation for this argument, for converting into signatures."""
if isinstance(arg, ops.Tensor):

View File

@ -0,0 +1,65 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for sparse tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
def indexed_slices_tensor_spec(shape=None,
dtype=dtypes.float32,
num_slices=None,
has_dense_shape=True,
name=None):
"""Returns a tensor specification for a IndexedSlices.
Returns an object which can be passed to `tf.function` (or other
functions that expect `TensorSpec`s) to specify shape constraints
for a `IndexedSlices` argument.
Args:
shape: The shape of the IndexedSlices, or `None` to allow any shape.
The returned specification object depends only on `shape[1:]`.
dtype: Data type of values in the IndexedSlices.
num_slices: Number of slices. Default allows for any number of slices.
has_dense_shape: Indicates whether the IndexedSlices is expected to have a
`dense_shape` component.
name: Optional name prefix for the `TensorSpec`s.
Returns:
An object describing the `values`, `indices` and `dense_shape` tensors
that comprise the `IndexedSlices`.
"""
dtype = dtypes.as_dtype(dtype)
shape = tensor_shape.TensorShape(shape)
num_slices = tensor_shape.Shape([num_slices])
values = tensor_spec.TensorSpec(
num_slices.concatenate(shape[1:]), dtype, name)
indices = tensor_spec.TensorSpec(num_slices, dtypes.int64,
("%s.indices" % name) if name else None)
if has_dense_shape:
dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64,
("%s.dense_shape" %
name) if name else None)
else:
dense_shape = None
return ops.IndexedSlices(values, indices, dense_shape)

View File

@ -62,8 +62,14 @@ from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# This is to avoid a circular dependency: ops -> tensor_spec -> ops
tensor_spec = LazyLoader(
"tensor_spec", globals(),
"tensorflow.python.framework.tensor_spec")
# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
def __init__(self, values, indices, dense_shape=None):
"""Creates an `IndexedSlices`."""
_get_graph_from_inputs([values, indices, dense_shape])
if not isinstance(values, tensor_spec.TensorSpec):
_get_graph_from_inputs([values, indices, dense_shape])
self._values = values
self._indices = indices
self._dense_shape = dense_shape

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.util.tf_export import tf_export
@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
values: A 1-D tensor of any type and shape `[N]`.
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
"""
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
indices = ops.convert_to_tensor(
indices, name="indices", dtype=dtypes.int64)
# TODO(touts): Consider adding mutable_values() when 'values'
# is a VariableOp and updating users of SparseTensor.
values = ops.internal_convert_to_tensor(values, name="values")
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
if isinstance(indices, tensor_spec.TensorSpec):
if not isinstance(values, tensor_spec.TensorSpec):
raise TypeError("Expected values to be a TensorSpec")
if not isinstance(dense_shape, tensor_spec.TensorSpec):
raise TypeError("Expected dense_shape to be a TensorSpec")
if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64:
raise TypeError("indices and dense_shape must have dtype=int64")
else:
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
indices = ops.convert_to_tensor(
indices, name="indices", dtype=dtypes.int64)
# TODO(touts): Consider adding mutable_values() when 'values'
# is a VariableOp and updating users of SparseTensor.
values = ops.internal_convert_to_tensor(values, name="values")
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
self._indices = indices
self._values = values
self._dense_shape = dense_shape
indices_shape = indices.get_shape().with_rank(2)
values_shape = values.get_shape().with_rank(1)
dense_shape_shape = dense_shape.get_shape().with_rank(1)
indices_shape = indices.shape.with_rank(2)
values_shape = values.shape.with_rank(1)
dense_shape_shape = dense_shape.shape.with_rank(1)
# Assert number of rows in indices match the number of elements in values.
indices_shape.dims[0].merge_with(values_shape.dims[0])
@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
return [
tensor_shape.TensorShape([None, rank]), # indices
tensor_shape.TensorShape([None]), # values
tensor_shape.TensorShape([rank])
] # dense_shape
tensor_shape.TensorShape([rank]) # dense_shape
]
@property
def _is_graph_tensor(self):

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for sparse tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
def sparse_tensor_spec(shape=None,
dtype=dtypes.float32,
num_values=None,
name=None):
"""Returns a tensor specification for a SparseTensor.
Returns an object which can be passed to `tf.function` (or other
functions that expect `TensorSpec`s) to specify shape constraints
for a `SparseTensor` argument.
Args:
shape: The shape of the SparseTensor, or `None` to allow any shape. The
returned specification object depends only on `shape.ndims`.
dtype: Data type of values in the SparseTensor.
num_values: The number of values in the SparseTensor, or `None` to allow any
number of values.
name: Optional name prefix for the `TensorSpec`s.
Returns:
An object describing the `values`, `indices` and `dense_shape` tensors
that comprise the `SparseTensor`.
"""
dtype = dtypes.as_dtype(dtype)
rank = tensor_shape.TensorShape(shape).rank
indices = tensor_spec.TensorSpec([num_values, rank], dtypes.int64,
("%s.indices" % name) if name else None)
values = tensor_spec.TensorSpec([num_values], dtype, name)
dense_shape = tensor_spec.TensorSpec(
[rank], dtypes.int64, ("%s.dense_shape" % name) if name else None)
return sparse_tensor.SparseTensor(indices, values, dense_shape)

View File

@ -40,6 +40,7 @@ py_library(
":ragged_string_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_tensor_spec",
":ragged_tensor_value",
":ragged_util",
":ragged_where_op",
@ -434,6 +435,18 @@ py_library(
],
)
py_library(
name = "ragged_tensor_spec",
srcs = ["ragged_tensor_spec.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_tensor",
"//tensorflow/python:dtypes",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
],
)
#-------------------------------------------------------------------------------
# RaggedTensor Tests
#-------------------------------------------------------------------------------

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"of the factory methods instead (e.g., "
"RaggedTensor.from_row_lengths())")
# Validate the arguments.
if not isinstance(values, (RaggedTensor, ops.Tensor)):
raise TypeError("values must be a Tensor or RaggedTensor.")
if not isinstance(row_splits, ops.Tensor):
raise TypeError("Row-partitioning argument must be a Tensor.")
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("Row-partitioning argument must be int32 or int64")
values.shape.with_rank_at_least(1)
is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
if is_tensor_spec:
if not (isinstance(values, tensor_spec.TensorSpec) or
(isinstance(values, RaggedTensor) and
isinstance(values.row_splits, tensor_spec.TensorSpec))):
raise TypeError("Expected values to be a TensorSpec, got %r" % values)
else:
# Validate the arguments.
if not isinstance(row_splits, ops.Tensor):
raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
row_splits)
if not isinstance(values, (RaggedTensor, ops.Tensor)):
raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
values)
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("Row-partitioning argument must be int32 or int64")
# Validate shapes & dtypes.
row_splits.shape.assert_has_rank(1)
row_splits.set_shape([None])
values.shape.with_rank_at_least(1)
if not is_tensor_spec:
row_splits.set_shape([None])
if isinstance(values, RaggedTensor):
assert row_splits.dtype == values.row_splits.dtype
@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
raise TypeError("validate must have type bool")
if isinstance(row_splits, (list, tuple)) and not row_splits:
raise ValueError("row_splits tensor may not be empty.")
if isinstance(row_splits, tensor_spec.TensorSpec):
return cls(values=values, row_splits=row_splits, internal=True)
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
values, row_splits = cls._convert_values_and_row_partition(

View File

@ -0,0 +1,77 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for ragged tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops.ragged import ragged_tensor
def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
ragged_rank=None, row_splits_dtype=dtypes.int64,
name=None):
"""Returns a tensor specification for a RaggedTensor.
Returns an object which can be passed to `tf.function` (or other
functions that expect `TensorSpec`s) to specify shape constraints
for a `RaggedTensor` argument.
Args:
shape: The shape of the RaggedTensor, or `None` to allow any shape.
dtype: Data type of values in the RaggedTensor.
ragged_rank: Python integer, the ragged rank of the RaggedTensor
to be described. Defaults to `shape.ndims - 1`.
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
One of `tf.int32` or `tf.int64`.
name: Optional name prefix for the `TensorSpec`s.
Returns:
An object describing the `flat_values` and `nested_row_splits` tensors
that comprise the `RaggedTensor`.
"""
dtype = dtypes.as_dtype(dtype)
shape = tensor_shape.TensorShape(shape)
if ragged_rank is None:
if shape.ndims is None:
raise ValueError("Must specify ragged_rank or a shape with known rank.")
ragged_rank = shape.ndims - 1
elif not isinstance(ragged_rank, int):
raise TypeError("ragged_rank must be an int")
if ragged_rank == 0:
return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
result = tensor_spec.TensorSpec(
tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
dtype, name)
for i in range(ragged_rank - 1, 0, -1):
splits = tensor_spec.TensorSpec(
[None], row_splits_dtype,
"%s.row_splits_%d" % (name, i) if name else None)
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
outer_dim = tensor_shape.dimension_at_index(shape, 0)
splits_shape = [None if outer_dim is None else outer_dim + 1]
splits = tensor_spec.TensorSpec(
splits_shape, row_splits_dtype,
"%s.row_splits_0" % name if name else None)
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
return result