Add TensorSpec support for CompositeTensors (such as SparseTensor and RaggedTensor).

PiperOrigin-RevId: 245272144
This commit is contained in:
Edward Loper 2019-04-25 11:18:25 -07:00 committed by TensorFlower Gardener
parent 6b8c6cb57f
commit ccd35296e8
11 changed files with 305 additions and 27 deletions

View File

@ -699,11 +699,13 @@ py_library(
":framework_for_generated_wrappers", ":framework_for_generated_wrappers",
":function", ":function",
":graph_util", ":graph_util",
":indexed_slices_tensor_spec",
":lib", ":lib",
":platform", ":platform",
":pywrap_tensorflow", ":pywrap_tensorflow",
":random_seed", ":random_seed",
":sparse_tensor", ":sparse_tensor",
":sparse_tensor_spec",
":tensor_spec", ":tensor_spec",
":tensor_util", ":tensor_util",
":util", ":util",
@ -1191,6 +1193,30 @@ py_library(
], ],
) )
# Library target for framework/sparse_tensor_spec.py. The deps listed here
# correspond to the framework modules that source file imports (dtypes,
# sparse_tensor, tensor_shape, tensor_spec).
py_library(
name = "sparse_tensor_spec",
srcs = ["framework/sparse_tensor_spec.py"],
srcs_version = "PY2AND3",
deps = [
":dtypes",
":sparse_tensor",
":tensor_shape",
":tensor_spec",
],
)
# Library target for framework/indexed_slices_tensor_spec.py.
# That source file imports dtypes, ops, tensor_shape and tensor_spec, so
# ":tensor_spec" must be declared here as a direct dependency.
py_library(
    name = "indexed_slices_tensor_spec",
    srcs = ["framework/indexed_slices_tensor_spec.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":dtypes",
        ":framework_ops",
        ":sparse_tensor",
        ":tensor_shape",
        ":tensor_spec",
    ],
)
py_library( py_library(
name = "tensor_util", name = "tensor_util",
srcs = ["framework/tensor_util.py"], srcs = ["framework/tensor_util.py"],

View File

@ -1039,7 +1039,7 @@ class FunctionSpec(object):
self._input_signature = tuple(input_signature) self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature, self._flat_input_signature = tuple(nest.flatten(input_signature,
expand_composites=False)) expand_composites=True))
@property @property
def fullargspec(self): def fullargspec(self):
@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs):
def validate_signature(signature): def validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec) if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature, expand_composites=False)): for arg in nest.flatten(signature, expand_composites=True)):
raise TypeError("Invalid input_signature %s; input_signature must be " raise TypeError("Invalid input_signature %s; input_signature must be "
"a possibly nested sequence of TensorSpec objects.") "a possibly nested sequence of TensorSpec objects.")

View File

@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import sparse_tensor_spec
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_ops
@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_spec
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import training_ops from tensorflow.python.training import training_ops
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
(sparse_tensor.SparseTensor, (sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
(ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
[ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]),
(ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
[ragged_tensor_spec.ragged_tensor_spec([None, None, None],
dtypes.int32)]),
(sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
[sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]),
]) # pyformat: disable ]) # pyformat: disable
def testCompositeAsArgumentTensorWithDefun(self, def testCompositeAsArgumentTensorWithDefun(self,
factory_fn, factory_fn,

View File

@ -27,7 +27,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import execute from tensorflow.python.eager import execute
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None):
Identical structure that has TensorSpec objects instead of Tensors and Identical structure that has TensorSpec objects instead of Tensors and
UnknownArgument instead of any unsupported types. UnknownArgument instead of any unsupported types.
""" """
structure = composite_tensor.replace_composites_with_components(structure)
def encode_arg(arg, path): def encode_arg(arg, path):
"""A representation for this argument, for converting into signatures.""" """A representation for this argument, for converting into signatures."""
if isinstance(arg, ops.Tensor): if isinstance(arg, ops.Tensor):

View File

@ -0,0 +1,65 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for sparse tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
def indexed_slices_tensor_spec(shape=None,
                               dtype=dtypes.float32,
                               num_slices=None,
                               has_dense_shape=True,
                               name=None):
  """Returns a tensor specification for an IndexedSlices.

  Returns an object which can be passed to `tf.function` (or other
  functions that expect `TensorSpec`s) to specify shape constraints
  for an `IndexedSlices` argument.

  Args:
    shape: The shape of the IndexedSlices, or `None` to allow any shape.
      Only `shape[1:]` (for the `values` component) and the rank of `shape`
      (for the `dense_shape` component) are used.
    dtype: Data type of values in the IndexedSlices.
    num_slices: Number of slices.  Default allows for any number of slices.
    has_dense_shape: Indicates whether the IndexedSlices is expected to have a
      `dense_shape` component.
    name: Optional name prefix for the `TensorSpec`s.

  Returns:
    An `IndexedSlices` whose `values`, `indices` and (when `has_dense_shape`
    is true) `dense_shape` components are `TensorSpec` objects; `dense_shape`
    is `None` otherwise.
  """
  dtype = dtypes.as_dtype(dtype)
  shape = tensor_shape.TensorShape(shape)
  # The leading (slice) dimension is shared by `values` and `indices`.
  # Note: the shape class is `TensorShape`; `tensor_shape` has no `Shape`.
  num_slices = tensor_shape.TensorShape([num_slices])

  # values: [num_slices] + shape[1:]
  values = tensor_spec.TensorSpec(
      num_slices.concatenate(shape[1:]), dtype, name)
  # indices: [num_slices], always int64.
  indices = tensor_spec.TensorSpec(num_slices, dtypes.int64,
                                   ("%s.indices" % name) if name else None)
  if has_dense_shape:
    # dense_shape: a vector with one entry per dimension of `shape`
    # (unknown length when the rank of `shape` is unknown).
    dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64,
                                         ("%s.dense_shape" %
                                          name) if name else None)
  else:
    dense_shape = None
  return ops.IndexedSlices(values, indices, dense_shape)

View File

@ -62,8 +62,14 @@ from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# This is to avoid a circular dependency: ops -> tensor_spec -> ops
tensor_spec = LazyLoader(
"tensor_spec", globals(),
"tensorflow.python.framework.tensor_spec")
# Temporary global switches determining if we should enable the work-in-progress # Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported. # calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True _USE_C_API = True
@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
def __init__(self, values, indices, dense_shape=None): def __init__(self, values, indices, dense_shape=None):
"""Creates an `IndexedSlices`.""" """Creates an `IndexedSlices`."""
_get_graph_from_inputs([values, indices, dense_shape]) if not isinstance(values, tensor_spec.TensorSpec):
_get_graph_from_inputs([values, indices, dense_shape])
self._values = values self._values = values
self._indices = indices self._indices = indices
self._dense_shape = dense_shape self._dense_shape = dense_shape

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
values: A 1-D tensor of any type and shape `[N]`. values: A 1-D tensor of any type and shape `[N]`.
dense_shape: A 1-D int64 tensor of shape `[ndims]`. dense_shape: A 1-D int64 tensor of shape `[ndims]`.
""" """
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]): if isinstance(indices, tensor_spec.TensorSpec):
indices = ops.convert_to_tensor( if not isinstance(values, tensor_spec.TensorSpec):
indices, name="indices", dtype=dtypes.int64) raise TypeError("Expected values to be a TensorSpec")
# TODO(touts): Consider adding mutable_values() when 'values' if not isinstance(dense_shape, tensor_spec.TensorSpec):
# is a VariableOp and updating users of SparseTensor. raise TypeError("Expected dense_shape to be a TensorSpec")
values = ops.internal_convert_to_tensor(values, name="values") if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64:
dense_shape = ops.convert_to_tensor( raise TypeError("indices and dense_shape must have dtype=int64")
dense_shape, name="dense_shape", dtype=dtypes.int64) else:
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
indices = ops.convert_to_tensor(
indices, name="indices", dtype=dtypes.int64)
# TODO(touts): Consider adding mutable_values() when 'values'
# is a VariableOp and updating users of SparseTensor.
values = ops.internal_convert_to_tensor(values, name="values")
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
self._indices = indices self._indices = indices
self._values = values self._values = values
self._dense_shape = dense_shape self._dense_shape = dense_shape
indices_shape = indices.get_shape().with_rank(2) indices_shape = indices.shape.with_rank(2)
values_shape = values.get_shape().with_rank(1) values_shape = values.shape.with_rank(1)
dense_shape_shape = dense_shape.get_shape().with_rank(1) dense_shape_shape = dense_shape.shape.with_rank(1)
# Assert number of rows in indices match the number of elements in values. # Assert number of rows in indices match the number of elements in values.
indices_shape.dims[0].merge_with(values_shape.dims[0]) indices_shape.dims[0].merge_with(values_shape.dims[0])
@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
return [ return [
tensor_shape.TensorShape([None, rank]), # indices tensor_shape.TensorShape([None, rank]), # indices
tensor_shape.TensorShape([None]), # values tensor_shape.TensorShape([None]), # values
tensor_shape.TensorShape([rank]) tensor_shape.TensorShape([rank]) # dense_shape
] # dense_shape ]
@property @property
def _is_graph_tensor(self): def _is_graph_tensor(self):

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for sparse tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
def sparse_tensor_spec(shape=None,
                       dtype=dtypes.float32,
                       num_values=None,
                       name=None):
  """Builds a `SparseTensor` whose components are `TensorSpec`s.

  The result can be handed to `tf.function` (or any other API that expects
  `TensorSpec`s) to describe the shape constraints of a `SparseTensor`
  argument.

  Args:
    shape: Shape of the SparseTensor; `None` permits any shape.  Only the
      rank of `shape` influences the returned specification.
    dtype: Data type of the SparseTensor's values.
    num_values: Number of values stored in the SparseTensor; `None` permits
      any number of values.
    name: Optional name prefix applied to the component `TensorSpec`s.

  Returns:
    A `SparseTensor` whose `indices`, `values` and `dense_shape` components
    are `TensorSpec` objects.
  """
  values_dtype = dtypes.as_dtype(dtype)
  ndims = tensor_shape.TensorShape(shape).rank

  # Component names are derived from `name` only when one was supplied.
  if name:
    indices_name = "%s.indices" % name
    dense_shape_name = "%s.dense_shape" % name
  else:
    indices_name = None
    dense_shape_name = None

  indices_spec = tensor_spec.TensorSpec(
      [num_values, ndims], dtypes.int64, indices_name)
  values_spec = tensor_spec.TensorSpec([num_values], values_dtype, name)
  dense_shape_spec = tensor_spec.TensorSpec(
      [ndims], dtypes.int64, dense_shape_name)
  return sparse_tensor.SparseTensor(
      indices_spec, values_spec, dense_shape_spec)

View File

@ -40,6 +40,7 @@ py_library(
":ragged_string_ops", ":ragged_string_ops",
":ragged_tensor", ":ragged_tensor",
":ragged_tensor_shape", ":ragged_tensor_shape",
":ragged_tensor_spec",
":ragged_tensor_value", ":ragged_tensor_value",
":ragged_util", ":ragged_util",
":ragged_where_op", ":ragged_where_op",
@ -434,6 +435,18 @@ py_library(
], ],
) )
# Library target for ragged_tensor_spec.py. The deps listed here correspond
# to the modules that source file imports (ragged_tensor, dtypes,
# tensor_shape, tensor_spec).
py_library(
name = "ragged_tensor_spec",
srcs = ["ragged_tensor_spec.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_tensor",
"//tensorflow/python:dtypes",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
],
)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
# RaggedTensor Tests # RaggedTensor Tests
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"of the factory methods instead (e.g., " "of the factory methods instead (e.g., "
"RaggedTensor.from_row_lengths())") "RaggedTensor.from_row_lengths())")
# Validate the arguments. is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
if not isinstance(values, (RaggedTensor, ops.Tensor)): if is_tensor_spec:
raise TypeError("values must be a Tensor or RaggedTensor.") if not (isinstance(values, tensor_spec.TensorSpec) or
if not isinstance(row_splits, ops.Tensor): (isinstance(values, RaggedTensor) and
raise TypeError("Row-partitioning argument must be a Tensor.") isinstance(values.row_splits, tensor_spec.TensorSpec))):
if row_splits.dtype not in (dtypes.int32, dtypes.int64): raise TypeError("Expected values to be a TensorSpec, got %r" % values)
raise ValueError("Row-partitioning argument must be int32 or int64") else:
values.shape.with_rank_at_least(1) # Validate the arguments.
if not isinstance(row_splits, ops.Tensor):
raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
row_splits)
if not isinstance(values, (RaggedTensor, ops.Tensor)):
raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
values)
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("Row-partitioning argument must be int32 or int64")
# Validate shapes & dtypes.
row_splits.shape.assert_has_rank(1) row_splits.shape.assert_has_rank(1)
row_splits.set_shape([None]) values.shape.with_rank_at_least(1)
if not is_tensor_spec:
row_splits.set_shape([None])
if isinstance(values, RaggedTensor): if isinstance(values, RaggedTensor):
assert row_splits.dtype == values.row_splits.dtype assert row_splits.dtype == values.row_splits.dtype
@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
raise TypeError("validate must have type bool") raise TypeError("validate must have type bool")
if isinstance(row_splits, (list, tuple)) and not row_splits: if isinstance(row_splits, (list, tuple)) and not row_splits:
raise ValueError("row_splits tensor may not be empty.") raise ValueError("row_splits tensor may not be empty.")
if isinstance(row_splits, tensor_spec.TensorSpec):
return cls(values=values, row_splits=row_splits, internal=True)
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
values, row_splits = cls._convert_values_and_row_partition( values, row_splits = cls._convert_values_and_row_partition(

View File

@ -0,0 +1,77 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorSpec factory for ragged tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops.ragged import ragged_tensor
def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
                       ragged_rank=None, row_splits_dtype=dtypes.int64,
                       name=None):
  """Returns a tensor specification for a RaggedTensor.

  Returns an object which can be passed to `tf.function` (or other
  functions that expect `TensorSpec`s) to specify shape constraints
  for a `RaggedTensor` argument.

  Args:
    shape: The shape of the RaggedTensor, or `None` to allow any shape.
    dtype: Data type of values in the RaggedTensor.
    ragged_rank: Python integer, the ragged rank of the RaggedTensor
      to be described.  Defaults to `shape.ndims - 1`.
    row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
      One of `tf.int32` or `tf.int64`.
    name: Optional name prefix for the `TensorSpec`s.

  Returns:
    An object describing the `flat_values` and `nested_row_splits` tensors
    that comprise the `RaggedTensor`.  When `ragged_rank` is 0 this is a
    plain `TensorSpec`.

  Raises:
    ValueError: If `ragged_rank` is omitted and `shape` has unknown rank, if
      `ragged_rank` is negative, or if a non-zero `ragged_rank` is not
      smaller than the known rank of `shape`.
    TypeError: If `ragged_rank` is given but is not an `int`.
  """
  dtype = dtypes.as_dtype(dtype)
  shape = tensor_shape.TensorShape(shape)
  if ragged_rank is None:
    if shape.ndims is None:
      raise ValueError("Must specify ragged_rank or a shape with known rank.")
    ragged_rank = shape.ndims - 1
  elif not isinstance(ragged_rank, int):
    raise TypeError("ragged_rank must be an int")

  # Reject values that would otherwise silently produce a specification that
  # is inconsistent with `shape` (or fail later with an unrelated error).
  if ragged_rank < 0:
    raise ValueError("ragged_rank must be non-negative; got %d" % ragged_rank)

  if ragged_rank == 0:
    return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)

  if shape.ndims is not None and ragged_rank >= shape.ndims:
    raise ValueError(
        "ragged_rank (%d) must be less than the rank of shape (%d)" %
        (ragged_rank, shape.ndims))

  # flat_values: an unknown-length leading dimension followed by the dense
  # (non-ragged) inner dimensions of `shape`.
  result = tensor_spec.TensorSpec(
      tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
      dtype, name)

  # Wrap with the inner row_splits (innermost first); their lengths are
  # left unknown.
  for i in range(ragged_rank - 1, 0, -1):
    splits = tensor_spec.TensorSpec(
        [None], row_splits_dtype,
        "%s.row_splits_%d" % (name, i) if name else None)
    result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)

  # The outermost row_splits has one more entry than the outer dimension
  # of `shape` when that dimension is known.
  outer_dim = tensor_shape.dimension_at_index(shape, 0)
  splits_shape = [None if outer_dim is None else outer_dim + 1]
  splits = tensor_spec.TensorSpec(
      splits_shape, row_splits_dtype,
      "%s.row_splits_0" % name if name else None)
  result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
  return result