From 1af94c269874440373c1d18d823110b1f5eabc19 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Thu, 1 Dec 2016 16:28:32 -0800 Subject: [PATCH] Moves metrics/sets and tensor_util.convert_to_tensor_or_sparse_tensor from contrib to core. Change: 140793359 --- tensorflow/BUILD | 1 - tensorflow/contrib/BUILD | 2 - .../contrib/cmake/tf_core_kernels.cmake | 2 - .../framework/python/framework/tensor_util.py | 34 +--- .../python/framework/tensor_util_test.py | 26 --- tensorflow/contrib/metrics/BUILD | 85 +------- tensorflow/contrib/metrics/kernels/BUILD | 31 --- .../contrib/metrics/python/ops/set_ops.py | 164 +--------------- tensorflow/core/BUILD | 4 + tensorflow/core/kernels/BUILD | 11 ++ .../metrics => core}/kernels/set_kernels.cc | 0 .../{contrib/metrics => core}/ops/set_ops.cc | 0 .../metrics => core}/ops/set_ops_test.cc | 0 tensorflow/python/BUILD | 16 ++ tensorflow/python/__init__.py | 7 +- tensorflow/python/framework/framework_lib.py | 2 + tensorflow/python/framework/sparse_tensor.py | 30 +++ .../python/framework/sparse_tensor_test.py | 26 +++ tensorflow/python/kernel_tests/BUILD | 7 + .../kernel_tests/sets_test.py} | 0 tensorflow/python/ops/sets.py | 184 ++++++++++++++++++ 21 files changed, 298 insertions(+), 334 deletions(-) delete mode 100644 tensorflow/contrib/metrics/kernels/BUILD rename tensorflow/{contrib/metrics => core}/kernels/set_kernels.cc (100%) rename tensorflow/{contrib/metrics => core}/ops/set_ops.cc (100%) rename tensorflow/{contrib/metrics => core}/ops/set_ops_test.cc (100%) rename tensorflow/{contrib/metrics/python/kernel_tests/set_ops_test.py => python/kernel_tests/sets_test.py} (100%) create mode 100644 tensorflow/python/ops/sets.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 7ba877d5f8f..73b5954337d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -108,7 +108,6 @@ filegroup( "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/metrics:all_files", - "//tensorflow/contrib/metrics/kernels:all_files", "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/rnn:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 9a6b116c619..27cb689d794 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -60,7 +60,6 @@ cc_library( "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/layers:bucketization_op_kernel", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", - "//tensorflow/contrib/metrics:set_ops_kernels", ], ) @@ -72,7 +71,6 @@ cc_library( "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/layers:bucketization_op_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", - "//tensorflow/contrib/metrics:set_ops_op_lib", ], ) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 7af5e3cc441..96554145f3c 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -36,8 +36,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/metrics/kernels/set_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc" diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index c149d148499..a326b78a5f2 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -39,6 +39,10 @@ __all__ = [ 'with_same_shape'] +convert_to_tensor_or_sparse_tensor = ( + sparse_tensor.convert_to_tensor_or_sparse_tensor) + + def _assert_same_base_type(items, expected_type=None): r"""Asserts all items are of the same base type. @@ -361,33 +365,3 @@ def with_shape(expected_shape, tensor): tensor.name, expected_shape, actual_shape)) return tensor - - -def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None): - """Converts value to a `SparseTensor` or `Tensor`. - - Args: - value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a - registered `Tensor` conversion function. - dtype: Optional element type for the returned tensor. If missing, the - type is inferred from the type of `value`. - name: Optional name to use if a new `Tensor` is created. - - Returns: - A `SparseTensor` or `Tensor` based on `value`. - - Raises: - RuntimeError: If result type is incompatible with `dtype`. - """ - if dtype is not None: - dtype = dtypes.as_dtype(dtype) - if isinstance(value, sparse_tensor.SparseTensorValue): - value = sparse_tensor.SparseTensor.from_value(value) - if isinstance(value, sparse_tensor.SparseTensor): - if dtype and not dtype.is_compatible_with(value.dtype): - raise RuntimeError( - 'Sparse dtype: requested = %s, actual = %s' % ( - dtype.name, value.dtype.name)) - return value - return ops.internal_convert_to_tensor( - value, dtype=dtype, name=name) diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index d5686251a68..dba302bab66 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -279,32 +279,6 @@ class WithShapeTest(tf.test.TestCase): ValueError, tensor_2x2.eval, {tensor_partial_shape: [42.0]}) -class ConvertToTensorOrSparseTensorTest(tf.test.TestCase): - - def test_convert_dense(self): - with self.test_session(): - value = [42, 43] - from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor( - value) - self.assertAllEqual(value, from_value.eval()) - - def test_convert_sparse(self): - with self.test_session(): - indices = [[0, 1], [1, 0]] - values = [42, 43] - shape = [2, 2] - sparse_tensor_value = tf.SparseTensorValue(indices, values, shape) - sparse_tensor = tf.SparseTensor.from_value(sparse_tensor_value) - from_value = tf.contrib.framework.convert_to_tensor_or_sparse_tensor( - sparse_tensor_value).eval() - from_tensor = tf.contrib.framework.convert_to_tensor_or_sparse_tensor( - sparse_tensor).eval() - for convertee in [from_value, from_tensor]: - self.assertAllEqual(sparse_tensor_value.indices, convertee.indices) - self.assertAllEqual(sparse_tensor_value.values, convertee.values) - self.assertAllEqual(sparse_tensor_value.shape, convertee.shape) - - class RemoveSqueezableDimensionsTest(tf.test.TestCase): def testRemoveSqueezableDimensions(self): diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index bee8b9567a3..89e84ca535f 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -8,93 +8,16 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "tf_cc_tests") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_kernel_tests_linkstatic", -) - -tf_custom_op_library( - # TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it. - name = "python/ops/_set_ops.so", - srcs = [ - "ops/set_ops.cc", - ], - deps = [ - "//tensorflow/contrib/metrics/kernels:set_kernels", - ], -) - -tf_gen_op_libs( - op_lib_names = ["set_ops"], -) - -tf_gen_op_wrapper_py( - name = "set_ops", - hidden = [ - "DenseToDenseSetOperation", - "DenseToSparseSetOperation", - "SparseToSparseSetOperation", - "SetSize", - ], - deps = [":set_ops_op_lib"], -) - -tf_kernel_library( - name = "set_ops_kernels", - deps = [ - "//tensorflow/contrib/metrics/kernels:set_kernels", - "//tensorflow/core:framework", - ], - alwayslink = 1, -) +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "metrics_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + glob(["python/metrics/*.py"]), - data = [":python/ops/_set_ops.so"], - srcs_version = "PY2AND3", - deps = [":set_ops"], -) - -py_test( - name = "set_ops_test", - size = "small", - srcs = ["python/kernel_tests/set_ops_test.py"], srcs_version = "PY2AND3", deps = [ - ":metrics_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -tf_cc_tests( - size = "small", - srcs = [ - "ops/set_ops_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":set_ops_op_lib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:ops", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//third_party/eigen3", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:sets", ], ) diff --git a/tensorflow/contrib/metrics/kernels/BUILD b/tensorflow/contrib/metrics/kernels/BUILD deleted file mode 100644 index 967c98ad60e..00000000000 --- a/tensorflow/contrib/metrics/kernels/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -# Description: -# Contains kernels for evaluation metrics and summary statistics. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -cc_library( - name = "set_kernels", - srcs = ["set_kernels.cc"], - copts = ["-Wno-sign-compare"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf//:protobuf", - ], - alwayslink = 1, -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/metrics/python/ops/set_ops.py b/tensorflow/contrib/metrics/python/ops/set_ops.py index dd737a14c29..9b80d08830f 100644 --- a/tensorflow/contrib/metrics/python/ops/set_ops.py +++ b/tensorflow/contrib/metrics/python/ops/set_ops.py @@ -17,167 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.framework import tensor_util - -from tensorflow.contrib.util import loader -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.platform import resource_loader +from tensorflow.python.ops import sets -_set_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_set_ops.so")) +set_size = sets.set_size -_VALID_DTYPES = set([ - dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16, dtypes.string]) +set_intersection = sets.set_intersection +set_difference = sets.set_difference -def set_size(a, validate_indices=True): - """Compute number of unique elements along last dimension of `a`. - - Args: - a: `SparseTensor`, with indices sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a`. - - Returns: - `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with - rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the - number of unique elements in the corresponding `[0...n-1]` dimension of `a`. - - Raises: - TypeError: If `a` is an invalid types. - """ - a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a") - if not isinstance(a, sparse_tensor.SparseTensor): - raise TypeError("Expected `SparseTensor`, got %s." % a) - if a.values.dtype.base_dtype not in _VALID_DTYPES: - raise TypeError("Invalid dtype %s." % a.values.dtype) - # pylint: disable=protected-access - return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices) - -ops.NotDifferentiable("SetSize") - - -ops.NotDifferentiable("DenseToDenseSetOperation") -ops.NotDifferentiable("DenseToSparseSetOperation") -ops.NotDifferentiable("SparseToSparseSetOperation") - - -def _set_operation(a, b, set_operation, validate_indices=True): - """Compute set operation of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. Must be - `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be - sorted in row-major order. - set_operation: String indicating set operaiton. See - SetOperationOp::SetOperationFromContext for valid values. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` with the same rank as `a` and `b`, and all but the last - dimension the same. Elements along the last dimension contain the results - of the set operation. - - Raises: - TypeError: If inputs are invalid types. - ValueError: If `a` is sparse and `b` is dense. - """ - a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a") - if a.dtype.base_dtype not in _VALID_DTYPES: - raise TypeError("'a' invalid dtype %s." % a.dtype) - b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b") - if b.dtype.base_dtype != a.dtype.base_dtype: - raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype)) - # pylint: disable=protected-access - if isinstance(a, sparse_tensor.SparseTensor): - if isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = _set_ops.sparse_to_sparse_set_operation( - a.indices, a.values, a.shape, b.indices, b.values, b.shape, - set_operation, validate_indices) - else: - raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. " - "Please flip the order of your inputs.") - elif isinstance(b, sparse_tensor.SparseTensor): - indices, values, shape = _set_ops.dense_to_sparse_set_operation( - a, b.indices, b.values, b.shape, set_operation, validate_indices) - else: - indices, values, shape = _set_ops.dense_to_dense_set_operation( - a, b, set_operation, validate_indices) - # pylint: enable=protected-access - return sparse_tensor.SparseTensor(indices, values, shape) - - -def set_intersection(a, b, validate_indices=True): - """Compute set intersection of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. Must be - `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be - sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` with the same rank as `a` and `b`, and all but the last - dimension the same. Elements along the last dimension contain the - intersections. - """ - return _set_operation(a, b, "intersection", validate_indices) - - -def set_difference(a, b, aminusb=True, validate_indices=True): - """Compute set difference of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. Must be - `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be - sorted in row-major order. - aminusb: Whether to subtract `b` from `a`, vs vice versa. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` with the same rank as `a` and `b`, and all but the last - dimension the same. Elements along the last dimension contain the - differences. - """ - return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices) - - -def set_union(a, b, validate_indices=True): - """Compute set union of elements in last dimension of `a` and `b`. - - All but the last dimension of `a` and `b` must match. - - Args: - a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices - must be sorted in row-major order. - b: `Tensor` or `SparseTensor` of the same type as `a`. Must be - `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be - sorted in row-major order. - validate_indices: Whether to validate the order and range of sparse indices - in `a` and `b`. - - Returns: - A `SparseTensor` with the same rank as `a` and `b`, and all but the last - dimension the same. Elements along the last dimension contain the - unions. - """ - return _set_operation(a, b, "union", validate_indices) +set_union = sets.set_union diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 487e00f14fe..60630859f8f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -422,6 +422,7 @@ tf_gen_op_libs( "random_ops", "resource_variable_ops", "sdca_ops", + "set_ops", "script_ops", "sendrecv_ops", "sparse_ops", @@ -465,6 +466,7 @@ cc_library( ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", + ":set_ops_op_lib", ":sparse_ops_op_lib", ":state_ops_op_lib", ":string_ops_op_lib", @@ -587,6 +589,7 @@ cc_library( "//tensorflow/core/kernels:required", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sdca_ops", + "//tensorflow/core/kernels:set_kernels", "//tensorflow/core/kernels:sparse", "//tensorflow/core/kernels:state", "//tensorflow/core/kernels:string", @@ -2179,6 +2182,7 @@ tf_cc_tests( "ops/nn_ops_test.cc", "ops/parsing_ops_test.cc", "ops/random_ops_test.cc", + "ops/set_ops_test.cc", "ops/sparse_ops_test.cc", "ops/state_ops_test.cc", "ops/string_ops_test.cc", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9a98d79a61a..d976af92534 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -431,6 +431,17 @@ tf_kernel_library( deps = ARRAY_DEPS, ) +tf_kernel_library( + name = "set_kernels", + prefix = "set_kernels", + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:set_ops_op_lib", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "debug_ops", prefix = "debug_ops", diff --git a/tensorflow/contrib/metrics/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc similarity index 100% rename from tensorflow/contrib/metrics/kernels/set_kernels.cc rename to tensorflow/core/kernels/set_kernels.cc diff --git a/tensorflow/contrib/metrics/ops/set_ops.cc b/tensorflow/core/ops/set_ops.cc similarity index 100% rename from tensorflow/contrib/metrics/ops/set_ops.cc rename to tensorflow/core/ops/set_ops.cc diff --git a/tensorflow/contrib/metrics/ops/set_ops_test.cc b/tensorflow/core/ops/set_ops_test.cc similarity index 100% rename from tensorflow/contrib/metrics/ops/set_ops_test.cc rename to tensorflow/core/ops/set_ops_test.cc diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4195221f8f3..180aa291ec4 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -746,6 +746,11 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, ) +tf_gen_op_wrapper_private_py( + name = "set_ops_gen", + require_shape_functions = True, +) + tf_gen_op_wrapper_private_py( name = "state_ops_gen", require_shape_functions = True, @@ -795,6 +800,16 @@ py_library( ], ) +py_library( + name = "sets", + srcs = ["ops/sets.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":set_ops_gen", + ], +) + py_library( name = "candidate_sampling_ops", srcs = ["ops/candidate_sampling_ops.py"], @@ -1516,6 +1531,7 @@ py_library( ":script_ops", ":seq2seq", ":session_ops", + ":sets", ":sparse_grad", ":sparse_ops", ":special_math_ops", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 2a7a76c3962..cd2f4fa328a 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -87,6 +87,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import resources from tensorflow.python.ops import sdca_ops as sdca from tensorflow.python.ops import image_ops as image +from tensorflow.python.ops import sets from tensorflow.python.user_ops import user_ops from tensorflow.python.util import compat from tensorflow.python.summary import summary @@ -223,6 +224,7 @@ _allowed_symbols.extend([ 'resources', 'resource_loader', 'sdca', + 'sets', 'summary', 'sysconfig', 'test', @@ -244,8 +246,9 @@ remove_undocumented(__name__, _allowed_symbols, [framework_lib, array_ops, client_lib, check_ops, compat, constant_op, control_flow_ops, functional_ops, histogram_ops, io_ops, math_ops, nn, resource_loader, - resources, script_ops, session_ops, sparse_ops, state_ops, - string_ops, summary, tensor_array_ops, train, layers]) + resources, sets, script_ops, session_ops, sparse_ops, + state_ops, string_ops, summary, tensor_array_ops, train, + layers]) # Special dunders that we choose to export: _exported_dunders = set([ diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index 281979384ed..fe935881c68 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -35,6 +35,7 @@ @@control_dependencies @@convert_to_tensor @@convert_to_tensor_or_indexed_slices +@@convert_to_tensor_or_sparse_tensor @@get_default_graph @@reset_default_graph @@import_graph_def @@ -93,6 +94,7 @@ from tensorflow.python.framework.ops import convert_to_tensor from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices from tensorflow.python.framework.random_seed import get_seed from tensorflow.python.framework.random_seed import set_random_seed +from tensorflow.python.framework.sparse_tensor import convert_to_tensor_or_sparse_tensor from tensorflow.python.framework.subscribe import subscribe from tensorflow.python.framework.importer import import_graph_def diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 4c5a03533da..4b21c5ac135 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -302,3 +302,33 @@ class SparseTensorValue(object): def __getitem__(self, i): return [self.indices, self.values, self.dense_shape][i] + + +def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None): + """Converts value to a `SparseTensor` or `Tensor`. + + Args: + value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a + registered `Tensor` conversion function. + dtype: Optional element type for the returned tensor. If missing, the + type is inferred from the type of `value`. + name: Optional name to use if a new `Tensor` is created. + + Returns: + A `SparseTensor` or `Tensor` based on `value`. + + Raises: + RuntimeError: If result type is incompatible with `dtype`. + """ + if dtype is not None: + dtype = dtypes.as_dtype(dtype) + if isinstance(value, SparseTensorValue): + value = SparseTensor.from_value(value) + if isinstance(value, SparseTensor): + if dtype and not dtype.is_compatible_with(value.dtype): + raise RuntimeError( + "Sparse dtype: requested = %s, actual = %s" % ( + dtype.name, value.dtype.name)) + return value + return ops.internal_convert_to_tensor( + value, dtype=dtype, name=name) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index b5f8142afc6..8138b186f42 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -52,5 +52,31 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertAllEqual(sess_run_value.shape, value.shape) +class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): + + def test_convert_dense(self): + with self.test_session(): + value = [42, 43] + from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor( + value) + self.assertAllEqual(value, from_value.eval()) + + def test_convert_sparse(self): + with self.test_session(): + indices = [[0, 1], [1, 0]] + values = [42, 43] + shape = [2, 2] + sparse_tensor_value = sparse_tensor.SparseTensorValue( + indices, values, shape) + st = sparse_tensor.SparseTensor.from_value(sparse_tensor_value) + from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor( + sparse_tensor_value).eval() + from_tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(st).eval() + for convertee in [from_value, from_tensor]: + self.assertAllEqual(sparse_tensor_value.indices, convertee.indices) + self.assertAllEqual(sparse_tensor_value.values, convertee.values) + self.assertAllEqual(sparse_tensor_value.shape, convertee.shape) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index e0975086826..6fdb1653157 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1380,6 +1380,13 @@ sycl_py_test( additional_deps = ["//tensorflow:tensorflow_py"], ) +tf_py_test( + name = "sets_test", + size = "small", + srcs = ["sets_test.py"], + additional_deps = ["//tensorflow:tensorflow_py"], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py b/tensorflow/python/kernel_tests/sets_test.py similarity index 100% rename from tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py rename to tensorflow/python/kernel_tests/sets_test.py diff --git a/tensorflow/python/ops/sets.py b/tensorflow/python/ops/sets.py new file mode 100644 index 00000000000..92e7f2ed53e --- /dev/null +++ b/tensorflow/python/ops/sets.py @@ -0,0 +1,184 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python layer for sets. + +@@set_size +@@set_intersection +@@set_union +@@set_difference +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gen_set_ops + + +_VALID_DTYPES = set([ + dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.uint16, dtypes.string]) + + +def set_size(a, validate_indices=True): + """Compute number of unique elements along last dimension of `a`. + + Args: + a: `SparseTensor`, with indices sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a`. + + Returns: + `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with + rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the + number of unique elements in the corresponding `[0...n-1]` dimension of `a`. + + Raises: + TypeError: If `a` is an invalid types. + """ + a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") + if not isinstance(a, sparse_tensor.SparseTensor): + raise TypeError("Expected `SparseTensor`, got %s." % a) + if a.values.dtype.base_dtype not in _VALID_DTYPES: + raise TypeError("Invalid dtype %s." % a.values.dtype) + # pylint: disable=protected-access + return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices) + +ops.NotDifferentiable("SetSize") + + +ops.NotDifferentiable("DenseToDenseSetOperation") +ops.NotDifferentiable("DenseToSparseSetOperation") +ops.NotDifferentiable("SparseToSparseSetOperation") + + +def _set_operation(a, b, set_operation, validate_indices=True): + """Compute set operation of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. Must be + `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be + sorted in row-major order. + set_operation: String indicating set operaiton. See + SetOperationOp::SetOperationFromContext for valid values. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` with the same rank as `a` and `b`, and all but the last + dimension the same. Elements along the last dimension contain the results + of the set operation. + + Raises: + TypeError: If inputs are invalid types. + ValueError: If `a` is sparse and `b` is dense. + """ + a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a") + if a.dtype.base_dtype not in _VALID_DTYPES: + raise TypeError("'a' invalid dtype %s." % a.dtype) + b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b") + if b.dtype.base_dtype != a.dtype.base_dtype: + raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype)) + # pylint: disable=protected-access + if isinstance(a, sparse_tensor.SparseTensor): + if isinstance(b, sparse_tensor.SparseTensor): + indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation( + a.indices, a.values, a.shape, b.indices, b.values, b.shape, + set_operation, validate_indices) + else: + raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. " + "Please flip the order of your inputs.") + elif isinstance(b, sparse_tensor.SparseTensor): + indices, values, shape = gen_set_ops.dense_to_sparse_set_operation( + a, b.indices, b.values, b.shape, set_operation, validate_indices) + else: + indices, values, shape = gen_set_ops.dense_to_dense_set_operation( + a, b, set_operation, validate_indices) + # pylint: enable=protected-access + return sparse_tensor.SparseTensor(indices, values, shape) + + +def set_intersection(a, b, validate_indices=True): + """Compute set intersection of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. Must be + `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be + sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` with the same rank as `a` and `b`, and all but the last + dimension the same. Elements along the last dimension contain the + intersections. + """ + return _set_operation(a, b, "intersection", validate_indices) + + +def set_difference(a, b, aminusb=True, validate_indices=True): + """Compute set difference of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. Must be + `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be + sorted in row-major order. + aminusb: Whether to subtract `b` from `a`, vs vice versa. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` with the same rank as `a` and `b`, and all but the last + dimension the same. Elements along the last dimension contain the + differences. + """ + return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices) + + +def set_union(a, b, validate_indices=True): + """Compute set union of elements in last dimension of `a` and `b`. + + All but the last dimension of `a` and `b` must match. + + Args: + a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices + must be sorted in row-major order. + b: `Tensor` or `SparseTensor` of the same type as `a`. Must be + `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be + sorted in row-major order. + validate_indices: Whether to validate the order and range of sparse indices + in `a` and `b`. + + Returns: + A `SparseTensor` with the same rank as `a` and `b`, and all but the last + dimension the same. Elements along the last dimension contain the + unions. + """ + return _set_operation(a, b, "union", validate_indices)