From 9f2b3620f6d4c553da662c772831278b4fa3126e Mon Sep 17 00:00:00 2001 From: Patrick Nguyen Date: Wed, 21 Dec 2016 17:16:43 -0800 Subject: [PATCH] Seal tf.sets. Change: 142717732 --- tensorflow/python/BUILD | 5 +- tensorflow/python/ops/sets.py | 297 +------------------------- tensorflow/python/ops/sets_impl.py | 323 +++++++++++++++++++++++++++++ 3 files changed, 334 insertions(+), 291 deletions(-) create mode 100644 tensorflow/python/ops/sets_impl.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 806fa895da9..bad0e0a0d79 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -944,7 +944,10 @@ py_library( py_library( name = "sets", - srcs = ["ops/sets.py"], + srcs = [ + "ops/sets.py", + "ops/sets_impl.py", + ], srcs_version = "PY2AND3", deps = [ ":framework", diff --git a/tensorflow/python/ops/sets.py b/tensorflow/python/ops/sets.py index 3b1f756481c..1eff9033559 100644 --- a/tensorflow/python/ops/sets.py +++ b/tensorflow/python/ops/sets.py @@ -24,295 +24,12 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_set_ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.sets_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented -_VALID_DTYPES = set([ - dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16, dtypes.string]) - - -def set_size(a, validate_indices=True): - """Compute number of unique elements along last dimension of `a`. - - Args: - a: `SparseTensor`, with indices sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a`. - - Returns: - `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with - rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the - number of unique elements in the corresponding `[0...n-1]` dimension of `a`. - - Raises: - TypeError: If `a` is an invalid types. - """ - a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") - if not isinstance(a, sparse_tensor.SparseTensor): - raise TypeError("Expected `SparseTensor`, got %s." % a) - if a.values.dtype.base_dtype not in _VALID_DTYPES: - raise TypeError("Invalid dtype %s." % a.values.dtype) - # pylint: disable=protected-access - return gen_set_ops.set_size( - a.indices, a.values, a.dense_shape, validate_indices) - -ops.NotDifferentiable("SetSize") - - -ops.NotDifferentiable("DenseToDenseSetOperation") -ops.NotDifferentiable("DenseToSparseSetOperation") -ops.NotDifferentiable("SparseToSparseSetOperation") - - -def _convert_to_tensors_or_sparse_tensors(a, b): - """Convert to tensor types, and flip order if necessary. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. - b: `Tensor` or `SparseTensor` of the same type as `a`. - - Returns: - Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to - `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has - been flipped to make it dense,sparse instead of sparse,dense (since the set - ops do not support the latter). - """ - a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") - if a.dtype.base_dtype not in _VALID_DTYPES: - raise TypeError("'a' invalid dtype %s." % a.dtype) - b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b") - if b.dtype.base_dtype != a.dtype.base_dtype: - raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype)) - if (isinstance(a, sparse_tensor.SparseTensor) and - not isinstance(b, sparse_tensor.SparseTensor)): - return b, a, True - return a, b, False - - -def _set_operation(a, b, set_operation, validate_indices=True): - """Compute set operation of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. Must be - `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be - sorted in row-major order. - set_operation: String indicating set operaiton. See - SetOperationOp::SetOperationFromContext for valid values. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` with the same rank as `a` and `b`, and all but the last - dimension the same. Elements along the last dimension contain the results - of the set operation. - - Raises: - TypeError: If inputs are invalid types. - ValueError: If `a` is sparse and `b` is dense. - """ - if isinstance(a, sparse_tensor.SparseTensor): - if isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation( - a.indices, a.values, a.dense_shape, - b.indices, b.values, b.dense_shape, - set_operation, validate_indices) - else: - raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. " - "Please flip the order of your inputs.") - elif isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = gen_set_ops.dense_to_sparse_set_operation( - a, b.indices, b.values, b.dense_shape, set_operation, validate_indices) - else: - indices, values, shape = gen_set_ops.dense_to_dense_set_operation( - a, b, set_operation, validate_indices) - return sparse_tensor.SparseTensor(indices, values, shape) - - -def set_intersection(a, b, validate_indices=True): - """Compute set intersection of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Example: - a = [ - [ - [ - [1, 2], - [3], - ], - [ - [4], - [5, 6], - ], - ], - ] - b = [ - [ - [ - [1, 3], - [2], - ], - [ - [4, 5], - [5, 6, 7, 8], - ], - ], - ] - set_intersection(a, b) = [ - [ - [ - [1], - [], - ], - [ - [4], - [5, 6], - ], - ], - ] - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices - must be sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but - the last dimension the same. Elements along the last dimension contain the - intersections. - """ - a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b) - return _set_operation(a, b, "intersection", validate_indices) - - -def set_difference(a, b, aminusb=True, validate_indices=True): - """Compute set difference of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Example: - a = [ - [ - [ - [1, 2], - [3], - ], - [ - [4], - [5, 6], - ], - ], - ] - b = [ - [ - [ - [1, 3], - [2], - ], - [ - [4, 5], - [5, 6, 7, 8], - ], - ], - ] - set_difference(a, b, aminusb=True) = [ - [ - [ - [2], - [3], - ], - [ - [], - [], - ], - ], - ] - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices - must be sorted in row-major order. - aminusb: Whether to subtract `b` from `a`, vs vice versa. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but - the last dimension the same. Elements along the last dimension contain the - differences. - """ - a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b) - if flipped: - aminusb = not aminusb - return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices) - - -def set_union(a, b, validate_indices=True): - """Compute set union of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Example: - a = [ - [ - [ - [1, 2], - [3], - ], - [ - [4], - [5, 6], - ], - ], - ] - b = [ - [ - [ - [1, 3], - [2], - ], - [ - [4, 5], - [5, 6, 7, 8], - ], - ], - ] - set_union(a, b) = [ - [ - [ - [1, 2, 3], - [2, 3], - ], - [ - [4, 5], - [5, 6, 7, 8], - ], - ], - ] - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices - must be sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but - the last dimension the same. Elements along the last dimension contain the - unions. - """ - a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b) - return _set_operation(a, b, "union", validate_indices) +_allowed_symbols = [] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py new file mode 100644 index 00000000000..f3b52636d4f --- /dev/null +++ b/tensorflow/python/ops/sets_impl.py @@ -0,0 +1,323 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of tf.sets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gen_set_ops + + +_VALID_DTYPES = set([ + dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.uint16, dtypes.string]) + + +def set_size(a, validate_indices=True): + """Compute number of unique elements along last dimension of `a`. + + Args: + a: `SparseTensor`, with indices sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a`. + + Returns: + `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with + rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the + number of unique elements in the corresponding `[0...n-1]` dimension of `a`. + + Raises: + TypeError: If `a` is an invalid types. + """ + a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") + if not isinstance(a, sparse_tensor.SparseTensor): + raise TypeError("Expected `SparseTensor`, got %s." % a) + if a.values.dtype.base_dtype not in _VALID_DTYPES: + raise TypeError("Invalid dtype %s." % a.values.dtype) + # pylint: disable=protected-access + return gen_set_ops.set_size( + a.indices, a.values, a.dense_shape, validate_indices) + +ops.NotDifferentiable("SetSize") + + +ops.NotDifferentiable("DenseToDenseSetOperation") +ops.NotDifferentiable("DenseToSparseSetOperation") +ops.NotDifferentiable("SparseToSparseSetOperation") + + +def _convert_to_tensors_or_sparse_tensors(a, b): + """Convert to tensor types, and flip order if necessary. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. + b: `Tensor` or `SparseTensor` of the same type as `a`. + + Returns: + Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to + `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has + been flipped to make it dense,sparse instead of sparse,dense (since the set + ops do not support the latter). + """ + a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") + if a.dtype.base_dtype not in _VALID_DTYPES: + raise TypeError("'a' invalid dtype %s." % a.dtype) + b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b") + if b.dtype.base_dtype != a.dtype.base_dtype: + raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype)) + if (isinstance(a, sparse_tensor.SparseTensor) and + not isinstance(b, sparse_tensor.SparseTensor)): + return b, a, True + return a, b, False + + +def _set_operation(a, b, set_operation, validate_indices=True): + """Compute set operation of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. Must be + `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be + sorted in row-major order. + set_operation: String indicating set operaiton. See + SetOperationOp::SetOperationFromContext for valid values. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` with the same rank as `a` and `b`, and all but the last + dimension the same. Elements along the last dimension contain the results + of the set operation. + + Raises: + TypeError: If inputs are invalid types. + ValueError: If `a` is sparse and `b` is dense. + """ + if isinstance(a, sparse_tensor.SparseTensor): + if isinstance(b, sparse_tensor.SparseTensor): + indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation( + a.indices, a.values, a.dense_shape, + b.indices, b.values, b.dense_shape, + set_operation, validate_indices) + else: + raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. " + "Please flip the order of your inputs.") + elif isinstance(b, sparse_tensor.SparseTensor): + indices, values, shape = gen_set_ops.dense_to_sparse_set_operation( + a, b.indices, b.values, b.dense_shape, set_operation, validate_indices) + else: + indices, values, shape = gen_set_ops.dense_to_dense_set_operation( + a, b, set_operation, validate_indices) + return sparse_tensor.SparseTensor(indices, values, shape) + + +def set_intersection(a, b, validate_indices=True): + """Compute set intersection of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Example: + + ```python + a = [ + [ + [ + [1, 2], + [3], + ], + [ + [4], + [5, 6], + ], + ], + ] + b = [ + [ + [ + [1, 3], + [2], + ], + [ + [4, 5], + [5, 6, 7, 8], + ], + ], + ] + set_intersection(a, b) = [ + [ + [ + [1], + [], + ], + [ + [4], + [5, 6], + ], + ], + ] + ``` + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices + must be sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but + the last dimension the same. Elements along the last dimension contain the + intersections. + """ + a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b) + return _set_operation(a, b, "intersection", validate_indices) + + +def set_difference(a, b, aminusb=True, validate_indices=True): + """Compute set difference of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Example: + + ```python + a = [ + [ + [ + [1, 2], + [3], + ], + [ + [4], + [5, 6], + ], + ], + ] + b = [ + [ + [ + [1, 3], + [2], + ], + [ + [4, 5], + [5, 6, 7, 8], + ], + ], + ] + set_difference(a, b, aminusb=True) = [ + [ + [ + [2], + [3], + ], + [ + [], + [], + ], + ], + ] + ``` + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices + must be sorted in row-major order. + aminusb: Whether to subtract `b` from `a`, vs vice versa. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but + the last dimension the same. Elements along the last dimension contain the + differences. + """ + a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b) + if flipped: + aminusb = not aminusb + return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices) + + +def set_union(a, b, validate_indices=True): + """Compute set union of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Example: + + ```python + a = [ + [ + [ + [1, 2], + [3], + ], + [ + [4], + [5, 6], + ], + ], + ] + b = [ + [ + [ + [1, 3], + [2], + ], + [ + [4, 5], + [5, 6, 7, 8], + ], + ], + ] + set_union(a, b) = [ + [ + [ + [1, 2, 3], + [2, 3], + ], + [ + [4, 5], + [5, 6, 7, 8], + ], + ], + ] + ``` + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices + must be sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but + the last dimension the same. Elements along the last dimension contain the + unions. + """ + a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b) + return _set_operation(a, b, "union", validate_indices)