Moves metrics/sets and tensor_util.convert_to_tensor_or_sparse_tensor from contrib to core.
Change: 140793359
parent 8b8e0aa63c
commit 1af94c2698
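For orientation, a hedged sketch of what this move means for user code: the symbols keep working from their contrib locations (as aliases) and gain core locations. The public names below are inferred from the export hunks further down, so treat them as assumptions rather than confirmed API.

import tensorflow as tf

# Old contrib entry points; this change keeps them working as thin aliases.
_old_convert = tf.contrib.framework.convert_to_tensor_or_sparse_tensor
_old_intersection = tf.contrib.metrics.set_intersection

# New core entry points added by this commit (names inferred from the
# framework_lib and python/__init__.py hunks below).
_new_convert = tf.convert_to_tensor_or_sparse_tensor
_new_intersection = tf.sets.set_intersection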
@@ -108,7 +108,6 @@ filegroup(
        "//tensorflow/contrib/lookup:all_files",
        "//tensorflow/contrib/losses:all_files",
        "//tensorflow/contrib/metrics:all_files",
        "//tensorflow/contrib/metrics/kernels:all_files",
        "//tensorflow/contrib/ndlstm:all_files",
        "//tensorflow/contrib/opt:all_files",
        "//tensorflow/contrib/rnn:all_files",
@@ -60,7 +60,6 @@ cc_library(
        "//tensorflow/contrib/factorization/kernels:all_kernels",
        "//tensorflow/contrib/layers:bucketization_op_kernel",
        "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
        "//tensorflow/contrib/metrics:set_ops_kernels",
    ],
)
@@ -72,7 +71,6 @@ cc_library(
        "//tensorflow/contrib/framework:all_ops",
        "//tensorflow/contrib/layers:bucketization_op_op_lib",
        "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
        "//tensorflow/contrib/metrics:set_ops_op_lib",
    ],
)
@@ -36,8 +36,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
    "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/metrics/kernels/set_kernels.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
    "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
@@ -39,6 +39,10 @@ __all__ = [
    'with_same_shape']


convert_to_tensor_or_sparse_tensor = (
    sparse_tensor.convert_to_tensor_or_sparse_tensor)


def _assert_same_base_type(items, expected_type=None):
  r"""Asserts all items are of the same base type.
@@ -361,33 +365,3 @@ def with_shape(expected_shape, tensor):
            tensor.name, expected_shape, actual_shape))

  return tensor


def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, sparse_tensor.SparseTensorValue):
    value = sparse_tensor.SparseTensor.from_value(value)
  if isinstance(value, sparse_tensor.SparseTensor):
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise RuntimeError(
          'Sparse dtype: requested = %s, actual = %s' % (
              dtype.name, value.dtype.name))
    return value
  return ops.internal_convert_to_tensor(
      value, dtype=dtype, name=name)
@@ -279,32 +279,6 @@ class WithShapeTest(tf.test.TestCase):
          ValueError, tensor_2x2.eval, {tensor_partial_shape: [42.0]})


class ConvertToTensorOrSparseTensorTest(tf.test.TestCase):

  def test_convert_dense(self):
    with self.test_session():
      value = [42, 43]
      from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
          value)
      self.assertAllEqual(value, from_value.eval())

  def test_convert_sparse(self):
    with self.test_session():
      indices = [[0, 1], [1, 0]]
      values = [42, 43]
      shape = [2, 2]
      sparse_tensor_value = tf.SparseTensorValue(indices, values, shape)
      sparse_tensor = tf.SparseTensor.from_value(sparse_tensor_value)
      from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
          sparse_tensor_value).eval()
      from_tensor = tf.contrib.framework.convert_to_tensor_or_sparse_tensor(
          sparse_tensor).eval()
      for convertee in [from_value, from_tensor]:
        self.assertAllEqual(sparse_tensor_value.indices, convertee.indices)
        self.assertAllEqual(sparse_tensor_value.values, convertee.values)
        self.assertAllEqual(sparse_tensor_value.shape, convertee.shape)


class RemoveSqueezableDimensionsTest(tf.test.TestCase):

  def testRemoveSqueezableDimensions(self):
@@ -8,93 +8,16 @@ exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_kernel_tests_linkstatic",
)

tf_custom_op_library(
    # TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it.
    name = "python/ops/_set_ops.so",
    srcs = [
        "ops/set_ops.cc",
    ],
    deps = [
        "//tensorflow/contrib/metrics/kernels:set_kernels",
    ],
)

tf_gen_op_libs(
    op_lib_names = ["set_ops"],
)

tf_gen_op_wrapper_py(
    name = "set_ops",
    hidden = [
        "DenseToDenseSetOperation",
        "DenseToSparseSetOperation",
        "SparseToSparseSetOperation",
        "SetSize",
    ],
    deps = [":set_ops_op_lib"],
)

tf_kernel_library(
    name = "set_ops_kernels",
    deps = [
        "//tensorflow/contrib/metrics/kernels:set_kernels",
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)
load("//tensorflow:tensorflow.bzl", "tf_py_test")

py_library(
    name = "metrics_py",
    srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + glob(["python/metrics/*.py"]),
    data = [":python/ops/_set_ops.so"],
    srcs_version = "PY2AND3",
    deps = [":set_ops"],
)

py_test(
    name = "set_ops_test",
    size = "small",
    srcs = ["python/kernel_tests/set_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":metrics_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

tf_cc_tests(
    size = "small",
    srcs = [
        "ops/set_ops_test.cc",
    ],
    linkstatic = tf_kernel_tests_linkstatic(),
    deps = [
        ":set_ops_op_lib",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:ops",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//third_party/eigen3",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:sets",
    ],
)
@@ -1,31 +0,0 @@
# Description:
#   Contains kernels for evaluation metrics and summary statistics.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

cc_library(
    name = "set_kernels",
    srcs = ["set_kernels.cc"],
    copts = ["-Wno-sign-compare"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//third_party/eigen3",
        "@protobuf//:protobuf",
    ],
    alwayslink = 1,
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
)
@@ -17,167 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.framework import tensor_util

from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import resource_loader
from tensorflow.python.ops import sets


_set_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_set_ops.so"))
set_size = sets.set_size

_VALID_DTYPES = set([
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.string])
set_intersection = sets.set_intersection

set_difference = sets.set_difference

def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is an invalid type.
  """
  a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  # pylint: disable=protected-access
  return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)

ops.NotDifferentiable("SetSize")


ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")


def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    set_operation: String indicating set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b")
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  # pylint: disable=protected-access
  if isinstance(a, sparse_tensor.SparseTensor):
    if isinstance(b, sparse_tensor.SparseTensor):
      indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
          a.indices, a.values, a.shape, b.indices, b.values, b.shape,
          set_operation, validate_indices)
    else:
      raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                       "Please flip the order of your inputs.")
  elif isinstance(b, sparse_tensor.SparseTensor):
    indices, values, shape = _set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.shape, set_operation, validate_indices)
  else:
    indices, values, shape = _set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  # pylint: enable=protected-access
  return sparse_tensor.SparseTensor(indices, values, shape)


def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    intersections.
  """
  return _set_operation(a, b, "intersection", validate_indices)


def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    differences.
  """
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)


def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    unions.
  """
  return _set_operation(a, b, "union", validate_indices)
set_union = sets.set_union
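After this rewrite, the contrib module keeps its public names but re-binds them to the core implementation. A minimal usage sketch of what that delegation means for callers, assuming tf.sets is exported as shown in the python/__init__.py hunk below and that tf.contrib.metrics continues to re-export set_intersection:

import tensorflow as tf

a = tf.SparseTensor([[0, 0], [0, 1]], [1, 2], [1, 3])
b = tf.SparseTensor([[0, 0], [0, 1]], [2, 3], [1, 3])

with tf.Session() as sess:
  # Old contrib entry point, now a thin alias of the core op.
  old = sess.run(tf.contrib.metrics.set_intersection(a, b))
  # New core entry point added by this commit (exported name is an assumption).
  new = sess.run(tf.sets.set_intersection(a, b))
  print(old.values, new.values)  # both should contain [2]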
@@ -422,6 +422,7 @@ tf_gen_op_libs(
        "random_ops",
        "resource_variable_ops",
        "sdca_ops",
        "set_ops",
        "script_ops",
        "sendrecv_ops",
        "sparse_ops",
@@ -465,6 +466,7 @@ cc_library(
        ":script_ops_op_lib",
        ":sdca_ops_op_lib",
        ":sendrecv_ops_op_lib",
        ":set_ops_op_lib",
        ":sparse_ops_op_lib",
        ":state_ops_op_lib",
        ":string_ops_op_lib",
@@ -587,6 +589,7 @@ cc_library(
        "//tensorflow/core/kernels:required",
        "//tensorflow/core/kernels:resource_variable_ops",
        "//tensorflow/core/kernels:sdca_ops",
        "//tensorflow/core/kernels:set_kernels",
        "//tensorflow/core/kernels:sparse",
        "//tensorflow/core/kernels:state",
        "//tensorflow/core/kernels:string",
@@ -2179,6 +2182,7 @@ tf_cc_tests(
        "ops/nn_ops_test.cc",
        "ops/parsing_ops_test.cc",
        "ops/random_ops_test.cc",
        "ops/set_ops_test.cc",
        "ops/sparse_ops_test.cc",
        "ops/state_ops_test.cc",
        "ops/string_ops_test.cc",
@@ -431,6 +431,17 @@ tf_kernel_library(
    deps = ARRAY_DEPS,
)

tf_kernel_library(
    name = "set_kernels",
    prefix = "set_kernels",
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:set_ops_op_lib",
        "//third_party/eigen3",
    ],
)

tf_kernel_library(
    name = "debug_ops",
    prefix = "debug_ops",
@@ -746,6 +746,11 @@ tf_gen_op_wrapper_private_py(
    require_shape_functions = True,
)

tf_gen_op_wrapper_private_py(
    name = "set_ops_gen",
    require_shape_functions = True,
)

tf_gen_op_wrapper_private_py(
    name = "state_ops_gen",
    require_shape_functions = True,
@@ -795,6 +800,16 @@ py_library(
    ],
)

py_library(
    name = "sets",
    srcs = ["ops/sets.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":set_ops_gen",
    ],
)

py_library(
    name = "candidate_sampling_ops",
    srcs = ["ops/candidate_sampling_ops.py"],
@@ -1516,6 +1531,7 @@ py_library(
        ":script_ops",
        ":seq2seq",
        ":session_ops",
        ":sets",
        ":sparse_grad",
        ":sparse_ops",
        ":special_math_ops",
@@ -87,6 +87,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import resources
from tensorflow.python.ops import sdca_ops as sdca
from tensorflow.python.ops import image_ops as image
from tensorflow.python.ops import sets
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
from tensorflow.python.summary import summary
@@ -223,6 +224,7 @@ _allowed_symbols.extend([
    'resources',
    'resource_loader',
    'sdca',
    'sets',
    'summary',
    'sysconfig',
    'test',
@@ -244,8 +246,9 @@ remove_undocumented(__name__, _allowed_symbols,
                    [framework_lib, array_ops, client_lib, check_ops,
                     compat, constant_op, control_flow_ops, functional_ops,
                     histogram_ops, io_ops, math_ops, nn, resource_loader,
                     resources, script_ops, session_ops, sparse_ops, state_ops,
                     string_ops, summary, tensor_array_ops, train, layers])
                     resources, sets, script_ops, session_ops, sparse_ops,
                     state_ops, string_ops, summary, tensor_array_ops, train,
                     layers])

# Special dunders that we choose to export:
_exported_dunders = set([
@@ -35,6 +35,7 @@
@@control_dependencies
@@convert_to_tensor
@@convert_to_tensor_or_indexed_slices
@@convert_to_tensor_or_sparse_tensor
@@get_default_graph
@@reset_default_graph
@@import_graph_def
@@ -93,6 +94,7 @@ from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
from tensorflow.python.framework.random_seed import get_seed
from tensorflow.python.framework.random_seed import set_random_seed
from tensorflow.python.framework.sparse_tensor import convert_to_tensor_or_sparse_tensor
from tensorflow.python.framework.subscribe import subscribe
from tensorflow.python.framework.importer import import_graph_def
@@ -302,3 +302,33 @@ class SparseTensorValue(object):

  def __getitem__(self, i):
    return [self.indices, self.values, self.dense_shape][i]


def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)
  if isinstance(value, SparseTensor):
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise RuntimeError(
          "Sparse dtype: requested = %s, actual = %s" % (
              dtype.name, value.dtype.name))
    return value
  return ops.internal_convert_to_tensor(
      value, dtype=dtype, name=name)
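A short usage sketch of the helper in its new core location; the tf.convert_to_tensor_or_sparse_tensor name is assumed from the framework_lib export hunk above:

import tensorflow as tf

dense = tf.convert_to_tensor_or_sparse_tensor([42, 43])  # returns a plain Tensor
sparse = tf.convert_to_tensor_or_sparse_tensor(
    tf.SparseTensorValue([[0, 1], [1, 0]], [42, 43], [2, 2]))  # returns a SparseTensor

with tf.Session() as sess:
  print(sess.run(dense))          # [42 43]
  print(sess.run(sparse).values)  # [42 43]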
@@ -52,5 +52,31 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
      self.assertAllEqual(sess_run_value.shape, value.shape)


class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):

  def test_convert_dense(self):
    with self.test_session():
      value = [42, 43]
      from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          value)
      self.assertAllEqual(value, from_value.eval())

  def test_convert_sparse(self):
    with self.test_session():
      indices = [[0, 1], [1, 0]]
      values = [42, 43]
      shape = [2, 2]
      sparse_tensor_value = sparse_tensor.SparseTensorValue(
          indices, values, shape)
      st = sparse_tensor.SparseTensor.from_value(sparse_tensor_value)
      from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          sparse_tensor_value).eval()
      from_tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(st).eval()
      for convertee in [from_value, from_tensor]:
        self.assertAllEqual(sparse_tensor_value.indices, convertee.indices)
        self.assertAllEqual(sparse_tensor_value.values, convertee.values)
        self.assertAllEqual(sparse_tensor_value.shape, convertee.shape)


if __name__ == "__main__":
  googletest.main()
@@ -1380,6 +1380,13 @@ sycl_py_test(
    additional_deps = ["//tensorflow:tensorflow_py"],
)

tf_py_test(
    name = "sets_test",
    size = "small",
    srcs = ["sets_test.py"],
    additional_deps = ["//tensorflow:tensorflow_py"],
)

filegroup(
    name = "all_files",
    srcs = glob(
tensorflow/python/ops/sets.py (new file, 184 lines)
@@ -0,0 +1,184 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python layer for sets.

@@set_size
@@set_intersection
@@set_union
@@set_difference
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops


_VALID_DTYPES = set([
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.string])


def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is an invalid type.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  # pylint: disable=protected-access
  return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices)

ops.NotDifferentiable("SetSize")


ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")


def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    set_operation: String indicating set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  # pylint: disable=protected-access
  if isinstance(a, sparse_tensor.SparseTensor):
    if isinstance(b, sparse_tensor.SparseTensor):
      indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
          a.indices, a.values, a.shape, b.indices, b.values, b.shape,
          set_operation, validate_indices)
    else:
      raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                       "Please flip the order of your inputs.")
  elif isinstance(b, sparse_tensor.SparseTensor):
    indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.shape, set_operation, validate_indices)
  else:
    indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  # pylint: enable=protected-access
  return sparse_tensor.SparseTensor(indices, values, shape)


def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    intersections.
  """
  return _set_operation(a, b, "intersection", validate_indices)


def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    differences.
  """
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)


def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the
    unions.
  """
  return _set_operation(a, b, "union", validate_indices)
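To illustrate the dense/sparse mixing rule the new module enforces (dense `a` with sparse `b` is supported, the reverse raises ValueError), here is a hedged end-to-end sketch that calls the module directly; it assumes a build that includes this file:

import tensorflow as tf
from tensorflow.python.ops import sets

a = tf.constant([[1, 2, 3]], dtype=tf.int64)  # dense, shape [1, 3]
b = tf.SparseTensor([[0, 0], [0, 1]],
                    tf.constant([2, 4], dtype=tf.int64), [1, 3])

diff = sets.set_difference(a, b)  # per-row elements of `a` not in `b`
size = sets.set_size(diff)        # number of unique elements per row

with tf.Session() as sess:
  print(sess.run(diff).values)  # [1 3]
  print(sess.run(size))         # [2]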