STT-tensorflow/tensorflow/python/kernel_tests/sets_test.py
Gaurav Jain 24f578cd66 Add @run_deprecated_v1 annotation to tests failing in v2
PiperOrigin-RevId: 223422907
2018-11-29 15:43:25 -08:00

1271 lines
39 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for set_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
# Element dtypes that every test in this file is repeated for.
_DTYPES = {
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
}
def _values(values, dtype):
  """Converts `values` to a NumPy array matching the TensorFlow `dtype`.

  Args:
    values: (Nested) Python list of values.
    dtype: TensorFlow `DType`. `dtypes.string` produces a NumPy string array;
      any other dtype uses `dtype.as_numpy_dtype`.

  Returns:
    An `np.ndarray` holding `values`.
  """
  # `np.unicode` was a deprecated alias (NumPy 1.20) that has since been
  # removed (NumPy 1.24), so looking it up raises AttributeError on current
  # NumPy. `np.str_` is the supported NumPy string scalar type.
  np_dtype = np.str_ if dtype == dtypes.string else dtype.as_numpy_dtype
  return np.array(values, dtype=np_dtype)
def _constant(values, dtype):
  """Builds a constant tensor of `dtype` from `values` (via `_values`)."""
  np_values = _values(values, dtype)
  return constant_op.constant(np_values, dtype=dtype)
def _dense_to_sparse(dense, dtype):
  """Builds a 2-D `SparseTensor` from a (possibly ragged) list of rows.

  Args:
    dense: List of rows; rows may have different lengths.
    dtype: `DType` of the values. For `dtypes.string` each cell is passed
      through `str()` first.

  Returns:
    A `SparseTensor` with one entry per cell and dense shape
    `[len(dense), length of the longest row]`.
  """
  as_text = dtype == dtypes.string
  cell_indices = []
  cell_values = []
  widest_row = 0
  for row_ix, row in enumerate(dense):
    widest_row = max(widest_row, len(row))
    for col_ix, cell in enumerate(row):
      cell_indices.append([row_ix, col_ix])
      cell_values.append(str(cell) if as_text else cell)
  dense_shape = [len(dense), widest_row]
  return sparse_tensor_lib.SparseTensor(
      constant_op.constant(cell_indices, dtypes.int64),
      constant_op.constant(cell_values, dtype),
      constant_op.constant(dense_shape, dtypes.int64))
class SetOpsTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def test_set_size_2d(self):
  """Runs the 2-D `set_size` checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_set_size_2d(element_dtype)
def _test_set_size_2d(self, dtype):
  """Checks `set_size` on ragged 2-D inputs, including empty rows."""
  # Each case is (expected per-row sizes, dense rows).
  cases = (
      ([1], [[1]]),
      ([2, 1], [[1, 9], [1]]),
      ([3, 0], [[1, 9, 2], []]),
      ([0, 3], [[], [1, 9, 2]]),
  )
  for expected_sizes, rows in cases:
    sparse_input = _dense_to_sparse(rows, dtype)
    self.assertAllEqual(expected_sizes, self._set_size(sparse_input))
@test_util.run_deprecated_v1
def test_set_size_duplicates_2d(self):
  """Runs the duplicate-value `set_size` checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_size_duplicates_2d(element_dtype)
def _test_set_size_duplicates_2d(self, dtype):
  """Checks that `set_size` counts repeated values in a row only once."""
  # Each case is (expected per-row sizes, dense rows).
  cases = (
      ([1], [[1, 1, 1, 1, 1, 1]]),
      ([2, 7, 3, 0, 1], [
          [1, 9],
          [6, 7, 8, 8, 6, 7, 5, 3, 3, 0, 6, 6, 9, 0, 0, 0],
          [999, 1, -1000],
          [],
          [-1],
      ]),
  )
  for expected_sizes, rows in cases:
    sparse_input = _dense_to_sparse(rows, dtype)
    self.assertAllEqual(expected_sizes, self._set_size(sparse_input))
@test_util.run_deprecated_v1
def test_set_size_3d(self):
  """Runs the 3-D `set_size` checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_set_size_3d(element_dtype)
def test_set_size_3d_invalid_indices(self):
  """Runs the 3-D `set_size` checks with out-of-order indices per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_size_3d(element_dtype, invalid_indices=True)
def _test_set_size_3d(self, dtype, invalid_indices=False):
  """Checks `set_size` on a 3-D sparse input of dense shape [3, 2, 3].

  Args:
    dtype: Element `DType` of the sparse values.
    invalid_indices: If True, the index rows are supplied out of order and
      evaluating the size is expected to raise an `OpError` matching
      "out of order".
  """
  if invalid_indices:
    index_rows = [
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [0, 0, 0], [0, 0, 2],  # 0,0 (listed after later groups)
        [2, 1, 1],  # 2,1
    ]
  else:
    index_rows = [
        [0, 0, 0], [0, 0, 2],  # 0,0
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [2, 1, 1],  # 2,1
    ]
  sp = sparse_tensor_lib.SparseTensor(
      constant_op.constant(index_rows, dtypes.int64),
      _constant([1, 9, 3, 3, 1, 9, 7, 8, 5], dtype),
      constant_op.constant([3, 2, 3], dtypes.int64))

  if invalid_indices:
    with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
      self._set_size(sp)
    return

  # One size per (row, col) group; group (2, 0) has no entries.
  expected_sizes = [[2, 1], [1, 3], [0, 1]]
  self.assertAllEqual(expected_sizes, self._set_size(sp))
def _set_size(self, sparse_data):
  """Evaluates `sets.set_size` on `sparse_data` and returns the result.

  The op is built both with and without `validate_indices`; each op must
  have an unknown static shape and dtype int32, and both must evaluate to
  the same value.

  Args:
    sparse_data: Input passed to `sets.set_size`.

  Returns:
    The evaluated result of the `validate_indices=True` op.
  """
  ops = [
      sets.set_size(sparse_data, validate_indices=validate)
      for validate in (True, False)
  ]
  for op in ops:
    self.assertEqual(None, op.get_shape().dims)
    self.assertEqual(dtypes.int32, op.dtype)
  # The session object itself is not needed; `self.evaluate` does the run.
  with self.cached_session():
    results = self.evaluate(ops)
  self.assertAllEqual(results[0], results[1])
  return results[0]
@test_util.run_deprecated_v1
def test_set_intersection_multirow_2d(self):
  """Runs the multi-row 2-D intersection checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_intersection_multirow_2d(element_dtype)
def _test_set_intersection_multirow_2d(self, dtype):
  """Checks dense/sparse and sparse/sparse intersection of two 2-row inputs."""
  a_values = [[9, 1, 5], [2, 4, 3]]
  b_values = [[1, 9], [1]]
  expected_indices = [[0, 0], [0, 1]]
  expected_values = _values([1, 9], dtype)
  expected_shape = [2, 2]
  expected_counts = [2, 0]

  def check(lhs, rhs):
    # Verify both the intersection itself and its per-row size.
    result = self._set_intersection(lhs, rhs)
    self._assert_set_operation(
        expected_indices, expected_values, expected_shape, result,
        dtype=dtype)
    self.assertAllEqual(expected_counts,
                        self._set_intersection_count(lhs, rhs))

  # Dense to sparse.
  a = _constant(a_values, dtype=dtype)
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b)
@test_util.run_deprecated_v1
def test_dense_set_intersection_multirow_2d(self):
  """Runs the dense/dense multi-row intersection checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_dense_set_intersection_multirow_2d(element_dtype)
def _test_dense_set_intersection_multirow_2d(self, dtype):
  """Checks intersection of two dense 2-row inputs."""
  expected_indices = [[0, 0], [0, 1]]
  expected_values = _values([1, 9], dtype)
  expected_shape = [2, 2]
  expected_counts = [2, 0]

  # Dense to dense.
  a = _constant([[9, 1, 5], [2, 4, 3]], dtype)
  b = _constant([[1, 9], [1, 5]], dtype)
  result = self._set_intersection(a, b)
  self._assert_set_operation(
      expected_indices, expected_values, expected_shape, result, dtype=dtype)
  counts = self._set_intersection_count(a, b)
  self.assertAllEqual(expected_counts, counts)
@test_util.run_deprecated_v1
def test_set_intersection_duplicates_2d(self):
  """Runs the duplicate-value intersection checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_intersection_duplicates_2d(element_dtype)
def _test_set_intersection_duplicates_2d(self, dtype):
  """Checks intersection when one operand repeats a value within a row."""
  a_values = [[1, 1, 3]]
  b_values = [[1]]
  expected_indices = [[0, 0]]
  expected_values = _values([1], dtype)
  expected_shape = [1, 1]
  expected_counts = [1]

  def check(lhs, rhs):
    # Verify both the intersection itself and its per-row size.
    result = self._set_intersection(lhs, rhs)
    self._assert_set_operation(
        expected_indices, expected_values, expected_shape, result,
        dtype=dtype)
    self.assertAllEqual(expected_counts,
                        self._set_intersection_count(lhs, rhs))

  # Dense to dense.
  a = _constant(a_values, dtype=dtype)
  b = _constant(b_values, dtype=dtype)
  check(a, b)

  # Dense to sparse.
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b)
@test_util.run_deprecated_v1
def test_set_intersection_3d(self):
  """Runs the 3-D intersection checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_set_intersection_3d(dtype=element_dtype)
def test_set_intersection_3d_invalid_indices(self):
  """Runs the 3-D intersection checks with out-of-order indices per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_intersection_3d(dtype=element_dtype, invalid_indices=True)
def _test_set_intersection_3d(self, dtype, invalid_indices=False):
  """Checks `set_intersection` on 3-D operands ([4, 2, 3] and [4, 2, 4]).

  Args:
    dtype: Element `DType` of both operands.
    invalid_indices: If True, `sp_a` is built with out-of-order index rows
      and the intersection is expected to raise an `OpError` matching
      "out of order".
  """
  if invalid_indices:
    a_index_rows = [
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [0, 0, 0], [0, 0, 2],  # 0,0 (listed after later groups)
        [2, 1, 1],  # 2,1
    ]
  else:
    a_index_rows = [
        [0, 0, 0], [0, 0, 2],  # 0,0
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [2, 1, 1],  # 2,1
    ]
  # Groups 2,0 and 3,* of `sp_a` are empty.
  sp_a = sparse_tensor_lib.SparseTensor(
      constant_op.constant(a_index_rows, dtypes.int64),
      _constant([1, 9, 3, 3, 1, 9, 7, 8, 5], dtype),
      constant_op.constant([4, 2, 3], dtypes.int64))
  # Group 0,1 of `sp_b` is empty.
  sp_b = sparse_tensor_lib.SparseTensor(
      constant_op.constant(
          [
              [0, 0, 0], [0, 0, 3],  # 0,0
              [1, 0, 0],  # 1,0
              [1, 1, 0], [1, 1, 1],  # 1,1
              [2, 0, 1],  # 2,0
              [2, 1, 1],  # 2,1
              [3, 0, 0],  # 3,0
              [3, 1, 0],  # 3,1
          ],
          dtypes.int64),
      _constant([1, 3, 3, 7, 8, 2, 5, 4, 4], dtype),
      constant_op.constant([4, 2, 4], dtypes.int64))

  if invalid_indices:
    with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
      self._set_intersection(sp_a, sp_b)
    return

  expected_indices = [
      [0, 0, 0],  # 0,0
      [1, 1, 0], [1, 1, 1],  # 1,1
      [2, 1, 0],  # 2,1
  ]
  expected_values = _values([1, 7, 8, 5], dtype)
  expected_shape = [4, 2, 2]
  expected_counts = [[1, 0], [0, 2], [0, 1], [0, 0]]

  def check(lhs, rhs):
    # Verify both the intersection itself and its per-group size.
    result = self._set_intersection(lhs, rhs)
    self._assert_set_operation(
        expected_indices, expected_values, expected_shape, result,
        dtype=dtype)
    self.assertAllEqual(expected_counts,
                        self._set_intersection_count(lhs, rhs))

  # Sparse to sparse.
  check(sp_a, sp_b)

  # NOTE: sparse_to_dense doesn't support uint8 and uint16.
  if dtype in (dtypes.uint8, dtypes.uint16):
    return

  # Dense to sparse. Missing cells of `a` are filled with -1.
  a_fill = "-1" if dtype == dtypes.string else -1
  a = math_ops.cast(
      sparse_ops.sparse_to_dense(
          sp_a.indices, sp_a.dense_shape, sp_a.values, default_value=a_fill),
      dtype=dtype)
  check(a, sp_b)

  # Dense to dense. Missing cells of `b` are filled with -2.
  b_fill = "-2" if dtype == dtypes.string else -2
  b = math_ops.cast(
      sparse_ops.sparse_to_dense(
          sp_b.indices, sp_b.dense_shape, sp_b.values, default_value=b_fill),
      dtype=dtype)
  check(a, b)
def _assert_static_shapes(self, input_tensor, result_sparse_tensor):
  """Asserts the static shapes of a set-op result follow from its input.

  The expected rank is the static rank of a dense `input_tensor`, or the
  static length of `dense_shape` for a `SparseTensor` input (None when that
  is unknown). The result must have indices of static shape
  (None, rank), values of shape (None,) and dense_shape of shape (rank,).
  """
  if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
    expected_rank = input_tensor.get_shape().ndims
  else:
    shape_dims = input_tensor.dense_shape.get_shape().dims
    expected_rank = None if shape_dims is None else shape_dims[0].value

  result = result_sparse_tensor
  self.assertAllEqual((None, expected_rank),
                      result.indices.get_shape().as_list())
  self.assertAllEqual((None,), result.values.get_shape().as_list())
  self.assertAllEqual((expected_rank,),
                      result.dense_shape.get_shape().as_list())
def _run_equivalent_set_ops(self, ops):
  """Runs `ops`, asserts they all agree, and returns the first result.

  Args:
    ops: Sequence of set-op results (each exposing `indices` and `values`
      tensors) that are expected to be equivalent.

  Returns:
    The evaluated value of `ops[0]`.
  """
  indices_shape_ops = []
  values_shape_ops = []
  reference_indices_shape = None
  reference_values_shape = None
  with self.cached_session() as sess:
    for op in ops:
      # The first op provides the reference static shapes; every later op
      # must report identical ones.
      if reference_indices_shape is None:
        reference_indices_shape = op.indices.get_shape()
      else:
        self.assertAllEqual(reference_indices_shape.as_list(),
                            op.indices.get_shape().as_list())
      if reference_values_shape is None:
        reference_values_shape = op.values.get_shape()
      else:
        self.assertAllEqual(reference_values_shape.as_list(),
                            op.values.get_shape().as_list())
      indices_shape_ops.append(array_ops.shape(op.indices))
      values_shape_ops.append(array_ops.shape(op.values))
    fetched = sess.run(list(ops) + indices_shape_ops + values_shape_ops)

  # `fetched` is laid out as [op results..., indices shapes..., values
  # shapes...], each section `count` entries long.
  count = len(ops)
  op_results = fetched[:count]
  indices_shapes = fetched[count:2 * count]
  values_shapes = fetched[2 * count:3 * count]

  # Static shapes, dynamic shapes and the fetched arrays must be consistent.
  reference_indices_shape.assert_is_compatible_with(indices_shapes[0])
  reference_values_shape.assert_is_compatible_with(values_shapes[0])
  self.assertAllEqual(indices_shapes[0], op_results[0].indices.shape)
  self.assertAllEqual(values_shapes[0], op_results[0].values.shape)

  # Every other op must match the first in dynamic shapes and contents.
  others = zip(op_results[1:], indices_shapes[1:], values_shapes[1:])
  for other_result, other_indices_shape, other_values_shape in others:
    self.assertAllEqual(indices_shapes[0], other_indices_shape)
    self.assertAllEqual(values_shapes[0], other_values_shape)
    self.assertAllEqual(op_results[0].indices, other_result.indices)
    self.assertAllEqual(op_results[0].values, other_result.values)
    self.assertAllEqual(op_results[0].dense_shape, other_result.dense_shape)
  return op_results[0]
def _set_intersection(self, a, b):
  """Returns the evaluated intersection of `a` and `b`.

  Builds the op with and without `validate_indices`, and with the operands
  swapped, then asserts all four variants agree.
  """
  ops = tuple(
      sets.set_intersection(lhs, rhs, validate_indices=validate)
      for lhs, rhs in ((a, b), (b, a))
      for validate in (True, False))
  for op in ops:
    self._assert_static_shapes(a, op)
  return self._run_equivalent_set_ops(ops)
def _set_intersection_count(self, a, b):
  """Returns the evaluated `set_size` of the intersection of `a` and `b`."""
  op = sets.set_size(sets.set_intersection(a, b))
  # The session object itself is unused; `self.evaluate` performs the run.
  with self.cached_session():
    return self.evaluate(op)
@test_util.run_deprecated_v1
def test_set_difference_multirow_2d(self):
  """Runs the multi-row 2-D difference checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_difference_multirow_2d(element_dtype)
def _test_set_difference_multirow_2d(self, dtype):
  """Checks a - b and b - a for dense/sparse and sparse/sparse operands."""

  def check(lhs, rhs, aminusb, indices, values, shape, counts):
    # Verify both the difference itself and its per-row size.
    difference = self._set_difference(lhs, rhs, aminusb)
    self._assert_set_operation(
        indices, values, shape, difference, dtype=dtype)
    self.assertAllEqual(counts,
                        self._set_difference_count(lhs, rhs, aminusb))

  a_values = [[1, 1, 1], [1, 5, 9], [4, 5, 3], [5, 5, 1]]
  b_values = [[], [1, 2], [1, 2, 2], []]

  # a - b: (indices, values, shape, counts).
  a_minus_b = (
      [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0], [3, 1]],
      _values([1, 5, 9, 3, 4, 5, 1, 5], dtype),
      [4, 3],
      [1, 2, 3, 2],
  )

  # Dense to sparse.
  a = _constant(a_values, dtype=dtype)
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b, True, *a_minus_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b, True, *a_minus_b)

  # b - a: (indices, values, shape, counts).
  b_minus_a = (
      [[1, 0], [2, 0], [2, 1]],
      _values([2, 1, 2], dtype),
      [4, 2],
      [0, 1, 2, 0],
  )

  # Dense to sparse.
  check(a, sp_b, False, *b_minus_a)

  # Sparse to sparse.
  check(sp_a, sp_b, False, *b_minus_a)
@test_util.run_deprecated_v1
def test_dense_set_difference_multirow_2d(self):
  """Runs the dense/dense multi-row difference checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_dense_set_difference_multirow_2d(element_dtype)
def _test_dense_set_difference_multirow_2d(self, dtype):
  """Checks a - b and b - a for two dense 2-row operands."""

  def check(lhs, rhs, aminusb, indices, values, shape, counts):
    # Verify both the difference itself and its per-row size.
    difference = self._set_difference(lhs, rhs, aminusb)
    self._assert_set_operation(
        indices, values, shape, difference, dtype=dtype)
    self.assertAllEqual(counts,
                        self._set_difference_count(lhs, rhs, aminusb))

  a_values = [[1, 5, 9], [4, 5, 3]]
  b_values = [[1, 2, 6], [1, 2, 2]]

  # a - b: (indices, values, shape, counts).
  a_minus_b = (
      [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
      _values([5, 9, 3, 4, 5], dtype),
      [2, 3],
      [2, 3],
  )

  # Dense to dense.
  a = _constant(a_values, dtype=dtype)
  b = _constant(b_values, dtype=dtype)
  check(a, b, True, *a_minus_b)

  # b - a: (indices, values, shape, counts).
  b_minus_a = (
      [[0, 0], [0, 1], [1, 0], [1, 1]],
      _values([2, 6, 1, 2], dtype),
      [2, 2],
      [2, 2],
  )

  # Dense to dense.
  check(a, b, False, *b_minus_a)
@test_util.run_deprecated_v1
def test_sparse_set_difference_multirow_2d(self):
  """Runs the sparse/sparse multi-row difference checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_sparse_set_difference_multirow_2d(element_dtype)
def _test_sparse_set_difference_multirow_2d(self, dtype):
  """Checks a - b and b - a for two ragged sparse operands."""

  def check(aminusb, indices, values, shape, counts):
    # Verify both the difference itself and its per-row size.
    difference = self._set_difference(sp_a, sp_b, aminusb)
    self._assert_set_operation(
        indices, values, shape, difference, dtype=dtype)
    self.assertAllEqual(counts,
                        self._set_difference_count(sp_a, sp_b, aminusb))

  sp_a = _dense_to_sparse(
      [[], [1, 5, 9], [4, 5, 3, 3, 4, 5], [5, 1]], dtype=dtype)
  sp_b = _dense_to_sparse([[], [1, 2], [1, 2, 2], []], dtype=dtype)

  # a - b.
  check(
      True,
      [[1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0], [3, 1]],
      _values([5, 9, 3, 4, 5, 1, 5], dtype),
      [4, 3],
      [0, 2, 3, 2])

  # b - a.
  check(
      False,
      [[1, 0], [2, 0], [2, 1]],
      _values([2, 1, 2], dtype),
      [4, 2],
      [0, 1, 2, 0])
@test_util.run_deprecated_v1
def test_set_difference_duplicates_2d(self):
  """Runs the duplicate-value difference checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_set_difference_duplicates_2d(element_dtype)
def _test_set_difference_duplicates_2d(self, dtype):
  """Checks a - b and b - a when both operands repeat values within a row.

  Each direction is checked for dense/sparse and sparse/sparse operands,
  verifying both the difference and its per-row size.

  Args:
    dtype: Element `DType` of both operands.
  """

  def check(lhs, rhs, aminusb, indices, values, shape, counts):
    difference = self._set_difference(lhs, rhs, aminusb)
    self._assert_set_operation(
        indices, values, shape, difference, dtype=dtype)
    # The count must use the same operands as the difference above. The
    # sparse/sparse cases previously counted with the dense `a`, so the
    # sparse/sparse count was never exercised.
    self.assertAllEqual(counts,
                        self._set_difference_count(lhs, rhs, aminusb))

  a_values = [[1, 1, 3]]
  b_values = [[1, 2, 2]]

  # a - b: (indices, values, shape, counts).
  a_minus_b = ([[0, 0]], _values([3], dtype), [1, 1], [1])

  # Dense to sparse.
  a = _constant(a_values, dtype=dtype)
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b, True, *a_minus_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b, True, *a_minus_b)

  # b - a: (indices, values, shape, counts).
  b_minus_a = ([[0, 0]], _values([2], dtype), [1, 1], [1])

  # Dense to sparse.
  check(a, sp_b, False, *b_minus_a)

  # Sparse to sparse.
  check(sp_a, sp_b, False, *b_minus_a)
@test_util.run_deprecated_v1
def test_sparse_set_difference_3d(self):
  """Runs the 3-D sparse difference checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_sparse_set_difference_3d(element_dtype)
def test_sparse_set_difference_3d_invalid_indices(self):
  """Runs the 3-D difference checks with out-of-order indices per dtype."""
  for element_dtype in _DTYPES:
    self._test_sparse_set_difference_3d(element_dtype, invalid_indices=True)
def _test_sparse_set_difference_3d(self, dtype, invalid_indices=False):
  """Checks `set_difference` on 3-D sparse operands, in both directions.

  Args:
    dtype: Element `DType` of both operands.
    invalid_indices: If True, `sp_a` is built with out-of-order index rows
      and both difference directions are expected to raise an `OpError`
      matching "out of order".
  """
  if invalid_indices:
    a_index_rows = [
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [0, 0, 0], [0, 0, 2],  # 0,0 (listed after later groups)
        [2, 1, 1],  # 2,1
    ]
  else:
    a_index_rows = [
        [0, 0, 0], [0, 0, 2],  # 0,0
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [2, 1, 1],  # 2,1
    ]
  # Groups 2,0 and 3,* of `sp_a` are empty.
  sp_a = sparse_tensor_lib.SparseTensor(
      constant_op.constant(a_index_rows, dtypes.int64),
      _constant([1, 9, 3, 3, 1, 9, 7, 8, 5], dtype),
      constant_op.constant([4, 2, 3], dtypes.int64))
  # Group 0,1 of `sp_b` is empty.
  sp_b = sparse_tensor_lib.SparseTensor(
      constant_op.constant(
          [
              [0, 0, 0], [0, 0, 3],  # 0,0
              [1, 0, 0],  # 1,0
              [1, 1, 0], [1, 1, 1],  # 1,1
              [2, 0, 1],  # 2,0
              [2, 1, 1],  # 2,1
              [3, 0, 0],  # 3,0
              [3, 1, 0],  # 3,1
          ],
          dtypes.int64),
      _constant([1, 3, 3, 7, 8, 2, 5, 4, 4], dtype),
      constant_op.constant([4, 2, 4], dtypes.int64))

  if invalid_indices:
    for aminusb in (False, True):
      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
        self._set_difference(sp_a, sp_b, aminusb)
    return

  def check(aminusb, indices, values, shape, counts):
    # Verify both the difference itself and its per-group size.
    difference = self._set_difference(sp_a, sp_b, aminusb)
    self._assert_set_operation(
        indices, values, shape, difference, dtype=dtype)
    self.assertAllEqual(counts,
                        self._set_difference_count(sp_a, sp_b, aminusb))

  # a-b: one surviving value in each of groups 0,0 0,1 1,0 and 1,1.
  check(
      True,
      [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]],
      _values([9, 3, 1, 9], dtype),
      [4, 2, 1],
      [[1, 1], [1, 1], [0, 0], [0, 0]])

  # b-a: one surviving value in each of groups 0,0 1,0 2,0 3,0 and 3,1.
  check(
      False,
      [[0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [3, 1, 0]],
      _values([3, 3, 2, 4, 4], dtype),
      [4, 2, 1],
      [[1, 0], [1, 0], [1, 0], [1, 1]])
def _set_difference(self, a, b, aminusb=True):
  """Returns the evaluated difference of `a` and `b`.

  Builds the op with and without `validate_indices`, and with the operands
  swapped (flipping `aminusb` to compensate), then asserts all four variants
  agree.
  """
  ops = tuple(
      sets.set_difference(
          lhs, rhs, aminusb=lhs_minus_rhs, validate_indices=validate)
      for lhs, rhs, lhs_minus_rhs in ((a, b, aminusb), (b, a, not aminusb))
      for validate in (True, False))
  for op in ops:
    self._assert_static_shapes(a, op)
  return self._run_equivalent_set_ops(ops)
def _set_difference_count(self, a, b, aminusb=True):
  """Returns the evaluated `set_size` of the difference of `a` and `b`.

  `aminusb` is forwarded to `sets.set_difference` to pick the direction.
  """
  op = sets.set_size(sets.set_difference(a, b, aminusb))
  # The session object itself is unused; `self.evaluate` performs the run.
  with self.cached_session():
    return self.evaluate(op)
@test_util.run_deprecated_v1
def test_set_union_multirow_2d(self):
  """Runs the multi-row 2-D union checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_set_union_multirow_2d(element_dtype)
def _test_set_union_multirow_2d(self, dtype):
  """Checks dense/sparse and sparse/sparse union of two 2-row inputs."""
  a_values = [[9, 1, 5], [2, 4, 3]]
  b_values = [[1, 9], [1]]
  expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]]
  expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype)
  expected_shape = [2, 4]
  expected_counts = [3, 4]

  def check(lhs, rhs):
    # Verify both the union itself and its per-row size.
    result = self._set_union(lhs, rhs)
    self._assert_set_operation(
        expected_indices, expected_values, expected_shape, result,
        dtype=dtype)
    self.assertAllEqual(expected_counts, self._set_union_count(lhs, rhs))

  # Dense to sparse.
  a = _constant(a_values, dtype=dtype)
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b)
@test_util.run_deprecated_v1
def test_dense_set_union_multirow_2d(self):
  """Runs the dense/dense multi-row union checks once per dtype."""
  for element_dtype in _DTYPES:
    self._test_dense_set_union_multirow_2d(element_dtype)
def _test_dense_set_union_multirow_2d(self, dtype):
  """Checks union of two dense 2-row inputs."""
  expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]]
  expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype)
  expected_shape = [2, 4]
  expected_counts = [3, 4]

  # Dense to dense.
  a = _constant([[9, 1, 5], [2, 4, 3]], dtype=dtype)
  b = _constant([[1, 9], [1, 2]], dtype=dtype)
  result = self._set_union(a, b)
  self._assert_set_operation(
      expected_indices, expected_values, expected_shape, result, dtype=dtype)
  counts = self._set_union_count(a, b)
  self.assertAllEqual(expected_counts, counts)
@test_util.run_deprecated_v1
def test_set_union_duplicates_2d(self):
  """Runs the duplicate-value union checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_set_union_duplicates_2d(element_dtype)
def _test_set_union_duplicates_2d(self, dtype):
  """Checks union when one operand repeats a value within a row."""
  a_values = [[1, 1, 3]]
  b_values = [[1]]
  expected_indices = [[0, 0], [0, 1]]
  expected_values = _values([1, 3], dtype)
  expected_shape = [1, 2]

  def check(lhs, rhs):
    # Verify both the union itself and its per-row size.
    result = self._set_union(lhs, rhs)
    self._assert_set_operation(
        expected_indices, expected_values, expected_shape, result,
        dtype=dtype)
    self.assertAllEqual([2], self._set_union_count(lhs, rhs))

  # Dense to sparse.
  a = _constant(a_values, dtype=dtype)
  sp_b = _dense_to_sparse(b_values, dtype=dtype)
  check(a, sp_b)

  # Sparse to sparse.
  sp_a = _dense_to_sparse(a_values, dtype=dtype)
  check(sp_a, sp_b)
@test_util.run_deprecated_v1
def test_sparse_set_union_3d(self):
  """Runs the 3-D sparse union checks once per dtype in `_DTYPES`."""
  for element_dtype in _DTYPES:
    self._test_sparse_set_union_3d(element_dtype)
def test_sparse_set_union_3d_invalid_indices(self):
  """Runs the 3-D union checks with out-of-order indices per dtype."""
  for element_dtype in _DTYPES:
    self._test_sparse_set_union_3d(element_dtype, invalid_indices=True)
def _test_sparse_set_union_3d(self, dtype, invalid_indices=False):
  """Checks `set_union` on 3-D sparse operands ([4, 2, 3] and [4, 2, 4]).

  Args:
    dtype: Element `DType` of both operands.
    invalid_indices: If True, `sp_a` is built with out-of-order index rows
      and the union is expected to raise an `OpError` matching
      "out of order".
  """
  if invalid_indices:
    a_index_rows = [
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [0, 0, 0], [0, 0, 2],  # 0,0 (listed after later groups)
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [2, 1, 1],  # 2,1
    ]
  else:
    a_index_rows = [
        [0, 0, 0], [0, 0, 2],  # 0,0
        [0, 1, 0], [0, 1, 1],  # 0,1
        [1, 0, 0],  # 1,0
        [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
        [2, 1, 1],  # 2,1
    ]
  # Groups 2,0 and 3,* of `sp_a` are empty.
  sp_a = sparse_tensor_lib.SparseTensor(
      constant_op.constant(a_index_rows, dtypes.int64),
      _constant([1, 9, 3, 3, 1, 9, 7, 8, 5], dtype),
      constant_op.constant([4, 2, 3], dtypes.int64))
  # Group 0,1 of `sp_b` is empty.
  sp_b = sparse_tensor_lib.SparseTensor(
      constant_op.constant(
          [
              [0, 0, 0], [0, 0, 3],  # 0,0
              [1, 0, 0],  # 1,0
              [1, 1, 0], [1, 1, 1],  # 1,1
              [2, 0, 1],  # 2,0
              [2, 1, 1],  # 2,1
              [3, 0, 0],  # 3,0
              [3, 1, 0],  # 3,1
          ],
          dtypes.int64),
      _constant([1, 3, 3, 7, 8, 2, 5, 4, 4], dtype),
      constant_op.constant([4, 2, 4], dtypes.int64))

  if invalid_indices:
    with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
      self._set_union(sp_a, sp_b)
    return

  expected_indices = [
      [0, 0, 0], [0, 0, 1], [0, 0, 2],  # 0,0
      [0, 1, 0],  # 0,1
      [1, 0, 0], [1, 0, 1],  # 1,0
      [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
      [2, 0, 0],  # 2,0
      [2, 1, 0],  # 2,1
      [3, 0, 0],  # 3,0
      [3, 1, 0],  # 3,1
  ]
  expected_values = _values(
      [
          1, 3, 9,  # 0,0
          3,  # 0,1
          1, 3,  # 1,0
          7, 8, 9,  # 1,1
          2,  # 2,0
          5,  # 2,1
          4,  # 3,0
          4,  # 3,1
      ],
      dtype)
  expected_shape = [4, 2, 3]
  expected_counts = [[3, 1], [2, 3], [1, 1], [1, 1]]

  union = self._set_union(sp_a, sp_b)
  self._assert_set_operation(
      expected_indices, expected_values, expected_shape, union, dtype=dtype)
  self.assertAllEqual(expected_counts, self._set_union_count(sp_a, sp_b))
def _set_union(self, a, b):
  """Returns the evaluated union of `a` and `b`.

  Builds the op with and without `validate_indices`, and with the operands
  swapped, then asserts all four variants agree.
  """
  ops = tuple(
      sets.set_union(lhs, rhs, validate_indices=validate)
      for lhs, rhs in ((a, b), (b, a))
      for validate in (True, False))
  for op in ops:
    self._assert_static_shapes(a, op)
  return self._run_equivalent_set_ops(ops)
def _set_union_count(self, a, b):
  """Returns the evaluated `set_size` of the union of `a` and `b`."""
  op = sets.set_size(sets.set_union(a, b))
  # The session object itself is unused; `self.evaluate` performs the run.
  with self.cached_session():
    return self.evaluate(op)
def _assert_set_operation(self, expected_indices, expected_values,
                          expected_shape, sparse_tensor_value, dtype):
  """Asserts an evaluated set-op result matches the expected sparse contents.

  Indices and dense shape must match exactly. Values are compared as sets
  per group, where a group is a run of consecutive entries whose indices
  agree on every dimension except the last.

  Args:
    expected_indices: List of expected index rows.
    expected_values: Expected values, one per index row.
    expected_shape: Expected dense shape.
    sparse_tensor_value: Evaluated result exposing `indices`, `values` and
      `dense_shape`.
    dtype: Element `DType`; for `dtypes.string` the actual values are decoded
      from UTF-8 before comparison.
  """
  self.assertAllEqual(expected_indices, sparse_tensor_value.indices)
  self.assertAllEqual(len(expected_indices), len(expected_values))
  self.assertAllEqual(len(expected_values), len(sparse_tensor_value.values))
  expected_set = set()
  actual_set = set()
  last_indices = None
  for indices, expected_value, actual_value in zip(
      expected_indices, expected_values, sparse_tensor_value.values):
    if dtype == dtypes.string:
      actual_value = actual_value.decode("utf-8")
    if last_indices is not None and last_indices[:-1] != indices[:-1]:
      # A new group starts here. The accumulated sets belong to the group
      # that just ended, so report `last_indices` rather than the first
      # index of the next group.
      self.assertEqual(expected_set, actual_set,
                       "Expected %s, got %s, at %s." % (expected_set,
                                                        actual_set,
                                                        last_indices))
      expected_set.clear()
      actual_set.clear()
    expected_set.add(expected_value)
    actual_set.add(actual_value)
    last_indices = indices
  # Compare the final group, which the loop never flushes.
  self.assertEqual(expected_set, actual_set,
                   "Expected %s, got %s, at %s." % (expected_set, actual_set,
                                                    last_indices))
  self.assertAllEqual(expected_shape, sparse_tensor_value.dense_shape)
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()