From ccd35296e8d7569cd2bc9d2ab01541ddb6903308 Mon Sep 17 00:00:00 2001 From: Edward Loper <edloper@google.com> Date: Thu, 25 Apr 2019 11:18:25 -0700 Subject: [PATCH] Add TensorSpec support for CompositeTensors (such as SparseTensor and RaggedTensor). PiperOrigin-RevId: 245272144 --- tensorflow/python/BUILD | 26 +++++++ tensorflow/python/eager/function.py | 4 +- tensorflow/python/eager/function_test.py | 12 +++ tensorflow/python/framework/func_graph.py | 2 - .../framework/indexed_slices_tensor_spec.py | 65 ++++++++++++++++ tensorflow/python/framework/ops.py | 9 ++- tensorflow/python/framework/sparse_tensor.py | 35 +++++---- .../python/framework/sparse_tensor_spec.py | 56 ++++++++++++++ tensorflow/python/ops/ragged/BUILD | 13 ++++ tensorflow/python/ops/ragged/ragged_tensor.py | 33 +++++--- .../python/ops/ragged/ragged_tensor_spec.py | 77 +++++++++++++++++++ 11 files changed, 305 insertions(+), 27 deletions(-) create mode 100644 tensorflow/python/framework/indexed_slices_tensor_spec.py create mode 100644 tensorflow/python/framework/sparse_tensor_spec.py create mode 100644 tensorflow/python/ops/ragged/ragged_tensor_spec.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c64ba8617a3..ed1e33b7ae9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -699,11 +699,13 @@ py_library( ":framework_for_generated_wrappers", ":function", ":graph_util", + ":indexed_slices_tensor_spec", ":lib", ":platform", ":pywrap_tensorflow", ":random_seed", ":sparse_tensor", + ":sparse_tensor_spec", ":tensor_spec", ":tensor_util", ":util", @@ -1191,6 +1193,30 @@ py_library( ], ) +py_library( + name = "sparse_tensor_spec", + srcs = ["framework/sparse_tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":sparse_tensor", + ":tensor_shape", + ":tensor_spec", + ], +) + +py_library( + name = "indexed_slices_tensor_spec", + srcs = ["framework/indexed_slices_tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":framework_ops", + ":sparse_tensor", + ":tensor_shape", + ], +) + py_library( name = "tensor_util", srcs = ["framework/tensor_util.py"], diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 4f0b3dec4e5..8574fa1382e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1039,7 +1039,7 @@ class FunctionSpec(object): self._input_signature = tuple(input_signature) self._flat_input_signature = tuple(nest.flatten(input_signature, - expand_composites=False)) + expand_composites=True)) @property def fullargspec(self): @@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs): def validate_signature(signature): if any(not isinstance(arg, tensor_spec.TensorSpec) - for arg in nest.flatten(signature, expand_composites=False)): + for arg in nest.flatten(signature, expand_composites=True)): raise TypeError("Invalid input_signature %s; input_signature must be " "a possibly nested sequence of TensorSpec objects.") diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 02130b053c5..9b8dc930946 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import sparse_tensor_spec from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_ops @@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_spec from tensorflow.python.platform import test from tensorflow.python.training import training_ops from tensorflow.python.util import compat @@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase): {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), (sparse_tensor.SparseTensor, {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), + (ragged_tensor.RaggedTensor.from_row_lengths, + {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}, + [ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]), + (ragged_tensor.RaggedTensor.from_nested_row_lengths, + {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}, + [ragged_tensor_spec.ragged_tensor_spec([None, None, None], + dtypes.int32)]), + (sparse_tensor.SparseTensor, + {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}, + [sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]), ]) # pyformat: disable def testCompositeAsArgumentTensorWithDefun(self, factory_fn, diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index db33c5472e4..5311f6601ac 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -27,7 +27,6 @@ from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager.graph_only_ops import graph_placeholder -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec @@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None): Identical structure that has TensorSpec objects instead of Tensors and UknownArgument instead of any unsupported types. """ - structure = composite_tensor.replace_composites_with_components(structure) def encode_arg(arg, path): """A representation for this argument, for converting into signatures.""" if isinstance(arg, ops.Tensor): diff --git a/tensorflow/python/framework/indexed_slices_tensor_spec.py b/tensorflow/python/framework/indexed_slices_tensor_spec.py new file mode 100644 index 00000000000..965e09231a9 --- /dev/null +++ b/tensorflow/python/framework/indexed_slices_tensor_spec.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorSpec factory for sparse tensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec + + +def indexed_slices_tensor_spec(shape=None, + dtype=dtypes.float32, + num_slices=None, + has_dense_shape=True, + name=None): + """Returns a tensor specification for a IndexedSlices. + + Returns an object which can be passed to `tf.function` (or other + functions that expect `TensorSpec`s) to specify shape constraints + for a `IndexedSlices` argument. + + Args: + shape: The shape of the IndexedSlices, or `None` to allow any shape. + The returned specification object depends only on `shape[1:]`. + dtype: Data type of values in the IndexedSlices. + num_slices: Number of slices. Default allows for any number of slices. + has_dense_shape: Indicates whether the IndexedSlices is expected to have a + `dense_shape` component. + name: Optional name prefix for the `TensorSpec`s. + + Returns: + An object describing the `values`, `indices` and `dense_shape` tensors + that comprise the `IndexedSlices`. + """ + dtype = dtypes.as_dtype(dtype) + shape = tensor_shape.TensorShape(shape) + num_slices = tensor_shape.Shape([num_slices]) + + values = tensor_spec.TensorSpec( + num_slices.concatenate(shape[1:]), dtype, name) + indices = tensor_spec.TensorSpec(num_slices, dtypes.int64, + ("%s.indices" % name) if name else None) + if has_dense_shape: + dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64, + ("%s.dense_shape" % + name) if name else None) + else: + dense_shape = None + return ops.IndexedSlices(values, indices, dense_shape) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 1b1c19e9511..d23cfc77c94 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -62,8 +62,14 @@ from tensorflow.python.util import memory from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_stack from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export +# This is to avoid a circular dependency: ops -> tensor_spec -> ops +tensor_spec = LazyLoader( + "tensor_spec", globals(), + "tensorflow.python.framework.tensor_spec") + # Temporary global switches determining if we should enable the work-in-progress # calls to the C API. These will be removed once all functionality is supported. _USE_C_API = True @@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor): def __init__(self, values, indices, dense_shape=None): """Creates an `IndexedSlices`.""" - _get_graph_from_inputs([values, indices, dense_shape]) + if not isinstance(values, tensor_spec.TensorSpec): + _get_graph_from_inputs([values, indices, dense_shape]) self._values = values self._indices = indices self._dense_shape = dense_shape diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 834fa3338ba..e9199b1e661 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.util.tf_export import tf_export @@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor): values: A 1-D tensor of any type and shape `[N]`. dense_shape: A 1-D int64 tensor of shape `[ndims]`. """ - with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]): - indices = ops.convert_to_tensor( - indices, name="indices", dtype=dtypes.int64) - # TODO(touts): Consider adding mutable_values() when 'values' - # is a VariableOp and updating users of SparseTensor. - values = ops.internal_convert_to_tensor(values, name="values") - dense_shape = ops.convert_to_tensor( - dense_shape, name="dense_shape", dtype=dtypes.int64) + if isinstance(indices, tensor_spec.TensorSpec): + if not isinstance(values, tensor_spec.TensorSpec): + raise TypeError("Expected values to be a TensorSpec") + if not isinstance(dense_shape, tensor_spec.TensorSpec): + raise TypeError("Expected dense_shape to be a TensorSpec") + if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64: + raise TypeError("indices and dense_shape must have dtype=int64") + else: + with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]): + indices = ops.convert_to_tensor( + indices, name="indices", dtype=dtypes.int64) + # TODO(touts): Consider adding mutable_values() when 'values' + # is a VariableOp and updating users of SparseTensor. + values = ops.internal_convert_to_tensor(values, name="values") + dense_shape = ops.convert_to_tensor( + dense_shape, name="dense_shape", dtype=dtypes.int64) self._indices = indices self._values = values self._dense_shape = dense_shape - indices_shape = indices.get_shape().with_rank(2) - values_shape = values.get_shape().with_rank(1) - dense_shape_shape = dense_shape.get_shape().with_rank(1) + indices_shape = indices.shape.with_rank(2) + values_shape = values.shape.with_rank(1) + dense_shape_shape = dense_shape.shape.with_rank(1) # Assert number of rows in indices match the number of elements in values. indices_shape.dims[0].merge_with(values_shape.dims[0]) @@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor): return [ tensor_shape.TensorShape([None, rank]), # indices tensor_shape.TensorShape([None]), # values - tensor_shape.TensorShape([rank]) - ] # dense_shape + tensor_shape.TensorShape([rank]) # dense_shape + ] @property def _is_graph_tensor(self): diff --git a/tensorflow/python/framework/sparse_tensor_spec.py b/tensorflow/python/framework/sparse_tensor_spec.py new file mode 100644 index 00000000000..4c9f163384a --- /dev/null +++ b/tensorflow/python/framework/sparse_tensor_spec.py @@ -0,0 +1,56 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorSpec factory for sparse tensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec + + +def sparse_tensor_spec(shape=None, + dtype=dtypes.float32, + num_values=None, + name=None): + """Returns a tensor specification for a SparseTensor. + + Returns an object which can be passed to `tf.function` (or other + functions that expect `TensorSpec`s) to specify shape constraints + for a `SparseTensor` argument. + + Args: + shape: The shape of the SparseTensor, or `None` to allow any shape. The + returned specification object depends only on `shape.ndims`. + dtype: Data type of values in the SparseTensor. + num_values: The number of values in the SparseTensor, or `None` to allow any + number of values. + name: Optional name prefix for the `TensorSpec`s. + + Returns: + An object describing the `values`, `indices` and `dense_shape` tensors + that comprise the `SparseTensor`. + """ + dtype = dtypes.as_dtype(dtype) + rank = tensor_shape.TensorShape(shape).rank + indices = tensor_spec.TensorSpec([num_values, rank], dtypes.int64, + ("%s.indices" % name) if name else None) + values = tensor_spec.TensorSpec([num_values], dtype, name) + dense_shape = tensor_spec.TensorSpec( + [rank], dtypes.int64, ("%s.dense_shape" % name) if name else None) + return sparse_tensor.SparseTensor(indices, values, dense_shape) diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 3b7ea580e0f..d8fb74fae68 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -40,6 +40,7 @@ py_library( ":ragged_string_ops", ":ragged_tensor", ":ragged_tensor_shape", + ":ragged_tensor_spec", ":ragged_tensor_value", ":ragged_util", ":ragged_where_op", @@ -434,6 +435,18 @@ py_library( ], ) +py_library( + name = "ragged_tensor_spec", + srcs = ["ragged_tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":ragged_tensor", + "//tensorflow/python:dtypes", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + ], +) + #------------------------------------------------------------------------------- # RaggedTensor Tests #------------------------------------------------------------------------------- diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 561fd3bedc3..a30e6c13a77 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor): "of the factory methods instead (e.g., " "RaggedTensor.from_row_lengths())") - # Validate the arguments. - if not isinstance(values, (RaggedTensor, ops.Tensor)): - raise TypeError("values must be a Tensor or RaggedTensor.") - if not isinstance(row_splits, ops.Tensor): - raise TypeError("Row-partitioning argument must be a Tensor.") - if row_splits.dtype not in (dtypes.int32, dtypes.int64): - raise ValueError("Row-partitioning argument must be int32 or int64") - values.shape.with_rank_at_least(1) + is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec) + if is_tensor_spec: + if not (isinstance(values, tensor_spec.TensorSpec) or + (isinstance(values, RaggedTensor) and + isinstance(values.row_splits, tensor_spec.TensorSpec))): + raise TypeError("Expected values to be a TensorSpec, got %r" % values) + else: + # Validate the arguments. + if not isinstance(row_splits, ops.Tensor): + raise TypeError("Row-partitioning argument must be a Tensor, got %r" % + row_splits) + if not isinstance(values, (RaggedTensor, ops.Tensor)): + raise TypeError("values must be a Tensor or RaggedTensor, got %r" % + values) + if row_splits.dtype not in (dtypes.int32, dtypes.int64): + raise ValueError("Row-partitioning argument must be int32 or int64") + + # Validate shapes & dtypes. row_splits.shape.assert_has_rank(1) - row_splits.set_shape([None]) + values.shape.with_rank_at_least(1) + if not is_tensor_spec: + row_splits.set_shape([None]) if isinstance(values, RaggedTensor): assert row_splits.dtype == values.row_splits.dtype @@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): raise TypeError("validate must have type bool") if isinstance(row_splits, (list, tuple)) and not row_splits: raise ValueError("row_splits tensor may not be empty.") + if isinstance(row_splits, tensor_spec.TensorSpec): + return cls(values=values, row_splits=row_splits, internal=True) with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): values, row_splits = cls._convert_values_and_row_partition( diff --git a/tensorflow/python/ops/ragged/ragged_tensor_spec.py b/tensorflow/python/ops/ragged/ragged_tensor_spec.py new file mode 100644 index 00000000000..9da282cea3c --- /dev/null +++ b/tensorflow/python/ops/ragged/ragged_tensor_spec.py @@ -0,0 +1,77 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorSpec factory for ragged tensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops.ragged import ragged_tensor + + +def ragged_tensor_spec(shape=None, dtype=dtypes.float32, + ragged_rank=None, row_splits_dtype=dtypes.int64, + name=None): + """Returns a tensor specification for a RaggedTensor. + + Returns an object which can be passed to `tf.function` (or other + functions that expect `TensorSpec`s) to specify shape constraints + for a `RaggedTensor` argument. + + Args: + shape: The shape of the RaggedTensor, or `None` to allow any shape. + dtype: Data type of values in the RaggedTensor. + ragged_rank: Python integer, the ragged rank of the RaggedTensor + to be described. Defaults to `shape.ndims - 1`. + row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. + One of `tf.int32` or `tf.int64`. + name: Optional name prefix for the `TensorSpec`s. + + Returns: + An object describing the `flat_values` and `nested_row_splits` tensors + that comprise the `RaggedTensor`. + """ + dtype = dtypes.as_dtype(dtype) + shape = tensor_shape.TensorShape(shape) + if ragged_rank is None: + if shape.ndims is None: + raise ValueError("Must specify ragged_rank or a shape with known rank.") + ragged_rank = shape.ndims - 1 + elif not isinstance(ragged_rank, int): + raise TypeError("ragged_rank must be an int") + if ragged_rank == 0: + return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name) + + result = tensor_spec.TensorSpec( + tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]), + dtype, name) + + for i in range(ragged_rank - 1, 0, -1): + splits = tensor_spec.TensorSpec( + [None], row_splits_dtype, + "%s.row_splits_%d" % (name, i) if name else None) + result = ragged_tensor.RaggedTensor.from_row_splits(result, splits) + + outer_dim = tensor_shape.dimension_at_index(shape, 0) + splits_shape = [None if outer_dim is None else outer_dim + 1] + splits = tensor_spec.TensorSpec( + splits_shape, row_splits_dtype, + "%s.row_splits_0" % name if name else None) + result = ragged_tensor.RaggedTensor.from_row_splits(result, splits) + + return result