Moves metrics/sets and tensor_util.convert_to_tensor_or_sparse_tensor from contrib to core.

Change: 140793359
This commit is contained in:
A. Unique TensorFlower 2016-12-01 16:28:32 -08:00 committed by TensorFlower Gardener
parent 8b8e0aa63c
commit 1af94c2698
21 changed files with 298 additions and 334 deletions

View File

@ -108,7 +108,6 @@ filegroup(
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/metrics/kernels:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/rnn:all_files",

View File

@ -60,7 +60,6 @@ cc_library(
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/layers:bucketization_op_kernel",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/metrics:set_ops_kernels",
],
)
@ -72,7 +71,6 @@ cc_library(
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/layers:bucketization_op_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/metrics:set_ops_op_lib",
],
)

View File

@ -36,8 +36,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/metrics/kernels/set_kernels.cc"
"${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"

View File

@ -39,6 +39,10 @@ __all__ = [
'with_same_shape']
convert_to_tensor_or_sparse_tensor = (
sparse_tensor.convert_to_tensor_or_sparse_tensor)
def _assert_same_base_type(items, expected_type=None):
r"""Asserts all items are of the same base type.
@ -361,33 +365,3 @@ def with_shape(expected_shape, tensor):
tensor.name, expected_shape, actual_shape))
return tensor
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
"""Converts value to a `SparseTensor` or `Tensor`.
Args:
value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
registered `Tensor` conversion function.
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
name: Optional name to use if a new `Tensor` is created.
Returns:
A `SparseTensor` or `Tensor` based on `value`.
Raises:
RuntimeError: If result type is incompatible with `dtype`.
"""
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
if isinstance(value, sparse_tensor.SparseTensorValue):
value = sparse_tensor.SparseTensor.from_value(value)
if isinstance(value, sparse_tensor.SparseTensor):
if dtype and not dtype.is_compatible_with(value.dtype):
raise RuntimeError(
'Sparse dtype: requested = %s, actual = %s' % (
dtype.name, value.dtype.name))
return value
return ops.internal_convert_to_tensor(
value, dtype=dtype, name=name)

View File

@ -279,32 +279,6 @@ class WithShapeTest(tf.test.TestCase):
ValueError, tensor_2x2.eval, {tensor_partial_shape: [42.0]})
class ConvertToTensorOrSparseTensorTest(tf.test.TestCase):
def test_convert_dense(self):
with self.test_session():
value = [42, 43]
from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
with self.test_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
sparse_tensor_value = tf.SparseTensorValue(indices, values, shape)
sparse_tensor = tf.SparseTensor.from_value(sparse_tensor_value)
from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
sparse_tensor_value).eval()
from_tensor = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
sparse_tensor).eval()
for convertee in [from_value, from_tensor]:
self.assertAllEqual(sparse_tensor_value.indices, convertee.indices)
self.assertAllEqual(sparse_tensor_value.values, convertee.values)
self.assertAllEqual(sparse_tensor_value.shape, convertee.shape)
class RemoveSqueezableDimensionsTest(tf.test.TestCase):
def testRemoveSqueezableDimensions(self):

View File

@ -8,93 +8,16 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)
tf_custom_op_library(
# TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it.
name = "python/ops/_set_ops.so",
srcs = [
"ops/set_ops.cc",
],
deps = [
"//tensorflow/contrib/metrics/kernels:set_kernels",
],
)
tf_gen_op_libs(
op_lib_names = ["set_ops"],
)
tf_gen_op_wrapper_py(
name = "set_ops",
hidden = [
"DenseToDenseSetOperation",
"DenseToSparseSetOperation",
"SparseToSparseSetOperation",
"SetSize",
],
deps = [":set_ops_op_lib"],
)
tf_kernel_library(
name = "set_ops_kernels",
deps = [
"//tensorflow/contrib/metrics/kernels:set_kernels",
"//tensorflow/core:framework",
],
alwayslink = 1,
)
load("//tensorflow:tensorflow.bzl", "tf_py_test")
py_library(
name = "metrics_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + glob(["python/metrics/*.py"]),
data = [":python/ops/_set_ops.so"],
srcs_version = "PY2AND3",
deps = [":set_ops"],
)
py_test(
name = "set_ops_test",
size = "small",
srcs = ["python/kernel_tests/set_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":metrics_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
tf_cc_tests(
size = "small",
srcs = [
"ops/set_ops_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":set_ops_op_lib",
"//tensorflow/cc:cc_ops",
"//tensorflow/core",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//third_party/eigen3",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:sets",
],
)

View File

@ -1,31 +0,0 @@
# Description:
# Contains kernels for evaluation metrics and summary statistics.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
cc_library(
name = "set_kernels",
srcs = ["set_kernels.cc"],
copts = ["-Wno-sign-compare"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@protobuf//:protobuf",
],
alwayslink = 1,
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)

View File

@ -17,167 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import resource_loader
from tensorflow.python.ops import sets
_set_ops = loader.load_op_library(
resource_loader.get_path_to_datafile("_set_ops.so"))
set_size = sets.set_size
_VALID_DTYPES = set([
dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16, dtypes.string])
set_intersection = sets.set_intersection
set_difference = sets.set_difference
def set_size(a, validate_indices=True):
"""Compute number of unique elements along last dimension of `a`.
Args:
a: `SparseTensor`, with indices sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a`.
Returns:
`int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
number of unique elements in the corresponding `[0...n-1]` dimension of `a`.
Raises:
TypeError: If `a` is an invalid types.
"""
a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
if not isinstance(a, sparse_tensor.SparseTensor):
raise TypeError("Expected `SparseTensor`, got %s." % a)
if a.values.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("Invalid dtype %s." % a.values.dtype)
# pylint: disable=protected-access
return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
ops.NotDifferentiable("SetSize")
ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")
def _set_operation(a, b, set_operation, validate_indices=True):
"""Compute set operation of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
set_operation: String indicating set operaiton. See
SetOperationOp::SetOperationFromContext for valid values.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the results
of the set operation.
Raises:
TypeError: If inputs are invalid types.
ValueError: If `a` is sparse and `b` is dense.
"""
a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
if a.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("'a' invalid dtype %s." % a.dtype)
b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b")
if b.dtype.base_dtype != a.dtype.base_dtype:
raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
# pylint: disable=protected-access
if isinstance(a, sparse_tensor.SparseTensor):
if isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
a.indices, a.values, a.shape, b.indices, b.values, b.shape,
set_operation, validate_indices)
else:
raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
"Please flip the order of your inputs.")
elif isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = _set_ops.dense_to_sparse_set_operation(
a, b.indices, b.values, b.shape, set_operation, validate_indices)
else:
indices, values, shape = _set_ops.dense_to_dense_set_operation(
a, b, set_operation, validate_indices)
# pylint: enable=protected-access
return sparse_tensor.SparseTensor(indices, values, shape)
def set_intersection(a, b, validate_indices=True):
"""Compute set intersection of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
intersections.
"""
return _set_operation(a, b, "intersection", validate_indices)
def set_difference(a, b, aminusb=True, validate_indices=True):
"""Compute set difference of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
aminusb: Whether to subtract `b` from `a`, vs vice versa.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
differences.
"""
return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)
def set_union(a, b, validate_indices=True):
"""Compute set union of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
unions.
"""
return _set_operation(a, b, "union", validate_indices)
set_union = sets.set_union

View File

@ -422,6 +422,7 @@ tf_gen_op_libs(
"random_ops",
"resource_variable_ops",
"sdca_ops",
"set_ops",
"script_ops",
"sendrecv_ops",
"sparse_ops",
@ -465,6 +466,7 @@ cc_library(
":script_ops_op_lib",
":sdca_ops_op_lib",
":sendrecv_ops_op_lib",
":set_ops_op_lib",
":sparse_ops_op_lib",
":state_ops_op_lib",
":string_ops_op_lib",
@ -587,6 +589,7 @@ cc_library(
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
"//tensorflow/core/kernels:string",
@ -2179,6 +2182,7 @@ tf_cc_tests(
"ops/nn_ops_test.cc",
"ops/parsing_ops_test.cc",
"ops/random_ops_test.cc",
"ops/set_ops_test.cc",
"ops/sparse_ops_test.cc",
"ops/state_ops_test.cc",
"ops/string_ops_test.cc",

View File

@ -431,6 +431,17 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
tf_kernel_library(
name = "set_kernels",
prefix = "set_kernels",
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:set_ops_op_lib",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "debug_ops",
prefix = "debug_ops",

View File

@ -746,6 +746,11 @@ tf_gen_op_wrapper_private_py(
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "set_ops_gen",
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "state_ops_gen",
require_shape_functions = True,
@ -795,6 +800,16 @@ py_library(
],
)
py_library(
name = "sets",
srcs = ["ops/sets.py"],
srcs_version = "PY2AND3",
deps = [
":framework",
":set_ops_gen",
],
)
py_library(
name = "candidate_sampling_ops",
srcs = ["ops/candidate_sampling_ops.py"],
@ -1516,6 +1531,7 @@ py_library(
":script_ops",
":seq2seq",
":session_ops",
":sets",
":sparse_grad",
":sparse_ops",
":special_math_ops",

View File

@ -87,6 +87,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import resources
from tensorflow.python.ops import sdca_ops as sdca
from tensorflow.python.ops import image_ops as image
from tensorflow.python.ops import sets
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
from tensorflow.python.summary import summary
@ -223,6 +224,7 @@ _allowed_symbols.extend([
'resources',
'resource_loader',
'sdca',
'sets',
'summary',
'sysconfig',
'test',
@ -244,8 +246,9 @@ remove_undocumented(__name__, _allowed_symbols,
[framework_lib, array_ops, client_lib, check_ops,
compat, constant_op, control_flow_ops, functional_ops,
histogram_ops, io_ops, math_ops, nn, resource_loader,
resources, script_ops, session_ops, sparse_ops, state_ops,
string_ops, summary, tensor_array_ops, train, layers])
resources, sets, script_ops, session_ops, sparse_ops,
state_ops, string_ops, summary, tensor_array_ops, train,
layers])
# Special dunders that we choose to export:
_exported_dunders = set([

View File

@ -35,6 +35,7 @@
@@control_dependencies
@@convert_to_tensor
@@convert_to_tensor_or_indexed_slices
@@convert_to_tensor_or_sparse_tensor
@@get_default_graph
@@reset_default_graph
@@import_graph_def
@ -93,6 +94,7 @@ from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
from tensorflow.python.framework.random_seed import get_seed
from tensorflow.python.framework.random_seed import set_random_seed
from tensorflow.python.framework.sparse_tensor import convert_to_tensor_or_sparse_tensor
from tensorflow.python.framework.subscribe import subscribe
from tensorflow.python.framework.importer import import_graph_def

View File

@ -302,3 +302,33 @@ class SparseTensorValue(object):
def __getitem__(self, i):
return [self.indices, self.values, self.dense_shape][i]
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
"""Converts value to a `SparseTensor` or `Tensor`.
Args:
value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
registered `Tensor` conversion function.
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
name: Optional name to use if a new `Tensor` is created.
Returns:
A `SparseTensor` or `Tensor` based on `value`.
Raises:
RuntimeError: If result type is incompatible with `dtype`.
"""
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
if isinstance(value, SparseTensorValue):
value = SparseTensor.from_value(value)
if isinstance(value, SparseTensor):
if dtype and not dtype.is_compatible_with(value.dtype):
raise RuntimeError(
"Sparse dtype: requested = %s, actual = %s" % (
dtype.name, value.dtype.name))
return value
return ops.internal_convert_to_tensor(
value, dtype=dtype, name=name)

View File

@ -52,5 +52,31 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sess_run_value.shape, value.shape)
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
def test_convert_dense(self):
with self.test_session():
value = [42, 43]
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
with self.test_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
sparse_tensor_value = sparse_tensor.SparseTensorValue(
indices, values, shape)
st = sparse_tensor.SparseTensor.from_value(sparse_tensor_value)
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
sparse_tensor_value).eval()
from_tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(st).eval()
for convertee in [from_value, from_tensor]:
self.assertAllEqual(sparse_tensor_value.indices, convertee.indices)
self.assertAllEqual(sparse_tensor_value.values, convertee.values)
self.assertAllEqual(sparse_tensor_value.shape, convertee.shape)
if __name__ == "__main__":
googletest.main()

View File

@ -1380,6 +1380,13 @@ sycl_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)
tf_py_test(
name = "sets_test",
size = "small",
srcs = ["sets_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -0,0 +1,184 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python layer for sets.
@@set_size
@@set_intersection
@@set_union
@@set_difference
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
_VALID_DTYPES = set([
dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16, dtypes.string])
def set_size(a, validate_indices=True):
"""Compute number of unique elements along last dimension of `a`.
Args:
a: `SparseTensor`, with indices sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a`.
Returns:
`int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
number of unique elements in the corresponding `[0...n-1]` dimension of `a`.
Raises:
TypeError: If `a` is an invalid types.
"""
a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
if not isinstance(a, sparse_tensor.SparseTensor):
raise TypeError("Expected `SparseTensor`, got %s." % a)
if a.values.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("Invalid dtype %s." % a.values.dtype)
# pylint: disable=protected-access
return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
ops.NotDifferentiable("SetSize")
ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")
def _set_operation(a, b, set_operation, validate_indices=True):
"""Compute set operation of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
set_operation: String indicating set operaiton. See
SetOperationOp::SetOperationFromContext for valid values.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the results
of the set operation.
Raises:
TypeError: If inputs are invalid types.
ValueError: If `a` is sparse and `b` is dense.
"""
a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
if a.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("'a' invalid dtype %s." % a.dtype)
b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
if b.dtype.base_dtype != a.dtype.base_dtype:
raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
# pylint: disable=protected-access
if isinstance(a, sparse_tensor.SparseTensor):
if isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
a.indices, a.values, a.shape, b.indices, b.values, b.shape,
set_operation, validate_indices)
else:
raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
"Please flip the order of your inputs.")
elif isinstance(b, sparse_tensor.SparseTensor):
indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
a, b.indices, b.values, b.shape, set_operation, validate_indices)
else:
indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
a, b, set_operation, validate_indices)
# pylint: enable=protected-access
return sparse_tensor.SparseTensor(indices, values, shape)
def set_intersection(a, b, validate_indices=True):
"""Compute set intersection of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
intersections.
"""
return _set_operation(a, b, "intersection", validate_indices)
def set_difference(a, b, aminusb=True, validate_indices=True):
"""Compute set difference of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
aminusb: Whether to subtract `b` from `a`, vs vice versa.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
differences.
"""
return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)
def set_union(a, b, validate_indices=True):
"""Compute set union of elements in last dimension of `a` and `b`.
All but the last dimension of `a` and `b` must match.
Args:
a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
must be sorted in row-major order.
b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
`SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
sorted in row-major order.
validate_indices: Whether to validate the order and range of sparse indices
in `a` and `b`.
Returns:
A `SparseTensor` with the same rank as `a` and `b`, and all but the last
dimension the same. Elements along the last dimension contain the
unions.
"""
return _set_operation(a, b, "union", validate_indices)