Add support for negative indices to sparse_split.
PiperOrigin-RevId: 320082703 Change-Id: I71c13e4ad05eb4ce5f5f2c7e2bd3be028875929f
This commit is contained in:
parent
40a0ec5420
commit
1c4ee22e6f
@ -30,7 +30,7 @@ class SparseSplitOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const int64 split_dim = context->input(0).scalar<int64>()();
|
||||
const int64 axis_input = context->input(0).scalar<int64>()();
|
||||
const Tensor& input_indices = context->input(1);
|
||||
const Tensor& input_values = context->input(2);
|
||||
const Tensor& input_shape = context->input(3);
|
||||
@ -48,20 +48,20 @@ class SparseSplitOp : public OpKernel {
|
||||
"Input shape should be a vector but received shape ",
|
||||
input_shape.shape().DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
input_shape.dim_size(0) && split_dim < input_shape.vec<int64>().size(),
|
||||
errors::InvalidArgument(
|
||||
"Input split_dim should be between 0 and rank (",
|
||||
input_shape.vec<int64>().size(), "), got ", split_dim));
|
||||
const int64 input_rank = input_shape.vec<int64>().size();
|
||||
const int64 axis = (axis_input < 0) ? input_rank + axis_input : axis_input;
|
||||
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(split_dim),
|
||||
errors::InvalidArgument("Input num_split should be between 1 "
|
||||
"and the splitting dimension size (",
|
||||
input_shape.vec<int64>()(split_dim), "), got ",
|
||||
num_split_));
|
||||
context, axis >= 0 && axis < input_rank,
|
||||
errors::InvalidArgument("Input axis should be in range [", -input_rank,
|
||||
", ", input_rank, "), got ", axis_input));
|
||||
|
||||
OP_REQUIRES(context,
|
||||
num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(axis),
|
||||
errors::InvalidArgument("Input num_split should be between 1 "
|
||||
"and the splitting dimension size (",
|
||||
input_shape.vec<int64>()(axis),
|
||||
"), got ", num_split_));
|
||||
|
||||
sparse::SparseTensor sparse_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -70,9 +70,8 @@ class SparseSplitOp : public OpKernel {
|
||||
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
|
||||
|
||||
std::vector<sparse::SparseTensor> outputs;
|
||||
OP_REQUIRES_OK(context,
|
||||
sparse::SparseTensor::Split<T>(sparse_tensor, split_dim,
|
||||
num_split_, &outputs));
|
||||
OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
|
||||
sparse_tensor, axis, num_split_, &outputs));
|
||||
|
||||
for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
|
||||
context->set_output(slice_index, outputs[slice_index].indices());
|
||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -76,182 +76,176 @@ class SparseSplitOpTest(test.TestCase):
|
||||
return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2(
|
||||
))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitMatrixRows(self):
|
||||
with self.session(use_gpu=False):
|
||||
sp_tensors = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=2, axis=0)
|
||||
for axis in (0, -2):
|
||||
sp_tensors = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=2, axis=axis))
|
||||
self.assertAllEqual(len(sp_tensors), 2)
|
||||
self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4],
|
||||
[0, 5], [1, 1], [1, 3],
|
||||
[1, 4]])
|
||||
self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5, 11, 13, 14])
|
||||
self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [2, 6])
|
||||
self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 0], [0, 3], [0, 5],
|
||||
[1, 0], [1, 2], [1, 3],
|
||||
[1, 5]])
|
||||
self.assertAllEqual(sp_tensors[1].values.eval(),
|
||||
[20, 23, 25, 30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [2, 6])
|
||||
self.assertAllEqual(
|
||||
sp_tensors[0].indices,
|
||||
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]])
|
||||
self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5, 11, 13, 14])
|
||||
self.assertAllEqual(sp_tensors[0].dense_shape, [2, 6])
|
||||
self.assertAllEqual(
|
||||
sp_tensors[1].indices,
|
||||
[[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]])
|
||||
self.assertAllEqual(sp_tensors[1].values, [20, 23, 25, 30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors[1].dense_shape, [2, 6])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitMatrixUnevenCols(self):
|
||||
with self.session(use_gpu=False):
|
||||
sp_tensors_3 = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=3, axis=1)
|
||||
for axis in (1, -1):
|
||||
sp_tensors_3 = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis))
|
||||
self.assertAllEqual(len(sp_tensors_3), 3)
|
||||
self.assertAllEqual(sp_tensors_3[0].indices.eval(),
|
||||
[[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2],
|
||||
[4, 1]])
|
||||
self.assertAllEqual(sp_tensors_3[0].values.eval(),
|
||||
[0, 2, 11, 20, 30, 32, 41])
|
||||
self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [5, 3])
|
||||
self.assertAllEqual(sp_tensors_3[1].indices.eval(),
|
||||
self.assertAllEqual(
|
||||
sp_tensors_3[0].indices,
|
||||
[[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]])
|
||||
self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 11, 20, 30, 32, 41])
|
||||
self.assertAllEqual(sp_tensors_3[0].dense_shape, [5, 3])
|
||||
self.assertAllEqual(sp_tensors_3[1].indices,
|
||||
[[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
|
||||
self.assertAllEqual(sp_tensors_3[1].values.eval(),
|
||||
[4, 13, 14, 23, 33, 44])
|
||||
self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [5, 2])
|
||||
self.assertAllEqual(sp_tensors_3[2].indices.eval(),
|
||||
self.assertAllEqual(sp_tensors_3[1].values, [4, 13, 14, 23, 33, 44])
|
||||
self.assertAllEqual(sp_tensors_3[1].dense_shape, [5, 2])
|
||||
self.assertAllEqual(sp_tensors_3[2].indices,
|
||||
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
|
||||
self.assertAllEqual(sp_tensors_3[2].values.eval(), [5, 16, 25, 35, 46])
|
||||
self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [5, 2])
|
||||
self.assertAllEqual(sp_tensors_3[2].values, [5, 16, 25, 35, 46])
|
||||
self.assertAllEqual(sp_tensors_3[2].dense_shape, [5, 2])
|
||||
sp_tensors_4 = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=4, axis=1)
|
||||
sp_input=self._SparseTensor_5x7(), num_split=4, axis=axis)
|
||||
self.assertAllEqual(len(sp_tensors_4), 4)
|
||||
self.assertAllEqual(sp_tensors_4[0].indices.eval(),
|
||||
self.assertAllEqual(sp_tensors_4[0].indices,
|
||||
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
|
||||
self.assertAllEqual(sp_tensors_4[0].values.eval(), [0, 11, 20, 30, 41])
|
||||
self.assertAllEqual(sp_tensors_4[0].dense_shape.eval(), [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[1].indices.eval(),
|
||||
self.assertAllEqual(sp_tensors_4[0].values, [0, 11, 20, 30, 41])
|
||||
self.assertAllEqual(sp_tensors_4[0].dense_shape, [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[1].indices,
|
||||
[[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]])
|
||||
self.assertAllEqual(sp_tensors_4[1].values.eval(), [2, 13, 23, 32, 33])
|
||||
self.assertAllEqual(sp_tensors_4[1].dense_shape.eval(), [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[2].indices.eval(),
|
||||
self.assertAllEqual(sp_tensors_4[1].values, [2, 13, 23, 32, 33])
|
||||
self.assertAllEqual(sp_tensors_4[1].dense_shape, [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[2].indices,
|
||||
[[0, 0], [0, 1], [1, 0], [2, 1], [3, 1], [4, 0]])
|
||||
self.assertAllEqual(sp_tensors_4[2].values.eval(), [4, 5, 14, 25, 35, 44])
|
||||
self.assertAllEqual(sp_tensors_4[2].dense_shape.eval(), [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[3].indices.eval(), [[1, 0], [4, 0]])
|
||||
self.assertAllEqual(sp_tensors_4[3].values.eval(), [16, 46])
|
||||
self.assertAllEqual(sp_tensors_4[3].dense_shape.eval(), [5, 1])
|
||||
self.assertAllEqual(sp_tensors_4[2].values, [4, 5, 14, 25, 35, 44])
|
||||
self.assertAllEqual(sp_tensors_4[2].dense_shape, [5, 2])
|
||||
self.assertAllEqual(sp_tensors_4[3].indices, [[1, 0], [4, 0]])
|
||||
self.assertAllEqual(sp_tensors_4[3].values, [16, 46])
|
||||
self.assertAllEqual(sp_tensors_4[3].dense_shape, [5, 1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitMatrixUnevenRows(self):
|
||||
with self.session(use_gpu=False):
|
||||
sp_tensors_2 = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=2, axis=0)
|
||||
self.assertAllEqual(sp_tensors_2[0].indices.eval(),
|
||||
for axis in (0, -2):
|
||||
sp_tensors_2 = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=2, axis=axis))
|
||||
self.assertAllEqual(sp_tensors_2[0].indices,
|
||||
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
|
||||
[1, 4], [1, 6], [2, 0], [2, 3], [2, 5]])
|
||||
self.assertAllEqual(sp_tensors_2[0].values.eval(),
|
||||
self.assertAllEqual(sp_tensors_2[0].values,
|
||||
[0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25])
|
||||
self.assertAllEqual(sp_tensors_2[0].dense_shape.eval(), [3, 7])
|
||||
self.assertAllEqual(sp_tensors_2[1].indices.eval(),
|
||||
[[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4],
|
||||
[1, 6]])
|
||||
self.assertAllEqual(sp_tensors_2[1].values.eval(),
|
||||
[30, 32, 33, 35, 41, 44, 46])
|
||||
self.assertAllEqual(sp_tensors_2[1].dense_shape.eval(), [2, 7])
|
||||
self.assertAllEqual(sp_tensors_2[0].dense_shape, [3, 7])
|
||||
self.assertAllEqual(
|
||||
sp_tensors_2[1].indices,
|
||||
[[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]])
|
||||
self.assertAllEqual(sp_tensors_2[1].values, [30, 32, 33, 35, 41, 44, 46])
|
||||
self.assertAllEqual(sp_tensors_2[1].dense_shape, [2, 7])
|
||||
self.assertAllEqual(len(sp_tensors_2), 2)
|
||||
sp_tensors_3 = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_5x7(), num_split=3, axis=0)
|
||||
sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis)
|
||||
self.assertAllEqual(len(sp_tensors_3), 3)
|
||||
self.assertAllEqual(sp_tensors_3[0].indices.eval(),
|
||||
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
|
||||
[1, 4], [1, 6]])
|
||||
self.assertAllEqual(sp_tensors_3[0].values.eval(),
|
||||
[0, 2, 4, 5, 11, 13, 14, 16])
|
||||
self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [2, 7])
|
||||
self.assertAllEqual(
|
||||
sp_tensors_3[0].indices,
|
||||
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]])
|
||||
self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 4, 5, 11, 13, 14, 16])
|
||||
self.assertAllEqual(sp_tensors_3[0].dense_shape, [2, 7])
|
||||
|
||||
self.assertAllEqual(sp_tensors_3[1].values.eval(),
|
||||
[20, 23, 25, 30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [2, 7])
|
||||
self.assertAllEqual(sp_tensors_3[2].indices.eval(), [[0, 1], [0, 4],
|
||||
[0, 6]])
|
||||
self.assertAllEqual(sp_tensors_3[2].values.eval(), [41, 44, 46])
|
||||
self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [1, 7])
|
||||
return
|
||||
self.assertAllEqual(sp_tensors_3[1].values, [20, 23, 25, 30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors_3[1].dense_shape, [2, 7])
|
||||
self.assertAllEqual(sp_tensors_3[2].indices, [[0, 1], [0, 4], [0, 6]])
|
||||
self.assertAllEqual(sp_tensors_3[2].values, [41, 44, 46])
|
||||
self.assertAllEqual(sp_tensors_3[2].dense_shape, [1, 7])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitAllRows(self):
|
||||
with self.session(use_gpu=False):
|
||||
sp_tensors = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=4, axis=0)
|
||||
for axis in (0, -2):
|
||||
sp_tensors = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=4, axis=axis))
|
||||
self.assertAllEqual(len(sp_tensors), 4)
|
||||
self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4],
|
||||
[0, 5]])
|
||||
self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5])
|
||||
self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [1, 6])
|
||||
self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 1], [0, 3], [0,
|
||||
4]])
|
||||
self.assertAllEqual(sp_tensors[1].values.eval(), [11, 13, 14])
|
||||
self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [1, 6])
|
||||
self.assertAllEqual(sp_tensors[2].indices.eval(), [[0, 0], [0, 3], [0,
|
||||
5]])
|
||||
self.assertAllEqual(sp_tensors[2].values.eval(), [20, 23, 25])
|
||||
self.assertAllEqual(sp_tensors[2].dense_shape.eval(), [1, 6])
|
||||
self.assertAllEqual(sp_tensors[3].indices.eval(), [[0, 0], [0, 2], [0, 3],
|
||||
[0, 5]])
|
||||
self.assertAllEqual(sp_tensors[3].values.eval(), [30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors[3].dense_shape.eval(), [1, 6])
|
||||
self.assertAllEqual(sp_tensors[0].indices,
|
||||
[[0, 0], [0, 2], [0, 4], [0, 5]])
|
||||
self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5])
|
||||
self.assertAllEqual(sp_tensors[0].dense_shape, [1, 6])
|
||||
self.assertAllEqual(sp_tensors[1].indices, [[0, 1], [0, 3], [0, 4]])
|
||||
self.assertAllEqual(sp_tensors[1].values, [11, 13, 14])
|
||||
self.assertAllEqual(sp_tensors[1].dense_shape, [1, 6])
|
||||
self.assertAllEqual(sp_tensors[2].indices, [[0, 0], [0, 3], [0, 5]])
|
||||
self.assertAllEqual(sp_tensors[2].values, [20, 23, 25])
|
||||
self.assertAllEqual(sp_tensors[2].dense_shape, [1, 6])
|
||||
self.assertAllEqual(sp_tensors[3].indices,
|
||||
[[0, 0], [0, 2], [0, 3], [0, 5]])
|
||||
self.assertAllEqual(sp_tensors[3].values, [30, 32, 33, 35])
|
||||
self.assertAllEqual(sp_tensors[3].dense_shape, [1, 6])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitColumns(self):
|
||||
with self.session(use_gpu=False):
|
||||
sparse_tensors = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=3, axis=1)
|
||||
for axis in (1, -1):
|
||||
sparse_tensors = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
|
||||
self.assertAllEqual(len(sparse_tensors), 3)
|
||||
self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [1, 1],
|
||||
[2, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 11, 20, 30])
|
||||
self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 2])
|
||||
self.assertAllEqual(sparse_tensors[1].indices.eval(),
|
||||
self.assertAllEqual(sparse_tensors[0].indices,
|
||||
[[0, 0], [1, 1], [2, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[0].values, [0, 11, 20, 30])
|
||||
self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 2])
|
||||
self.assertAllEqual(sparse_tensors[1].indices,
|
||||
[[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]])
|
||||
self.assertAllEqual(sparse_tensors[1].values.eval(), [2, 13, 23, 32, 33])
|
||||
self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 2])
|
||||
self.assertAllEqual(sparse_tensors[2].indices.eval(),
|
||||
self.assertAllEqual(sparse_tensors[1].values, [2, 13, 23, 32, 33])
|
||||
self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 2])
|
||||
self.assertAllEqual(sparse_tensors[2].indices,
|
||||
[[0, 0], [0, 1], [1, 0], [2, 1], [3, 1]])
|
||||
self.assertAllEqual(sparse_tensors[2].values.eval(), [4, 5, 14, 25, 35])
|
||||
self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 2])
|
||||
self.assertAllEqual(sparse_tensors[2].values, [4, 5, 14, 25, 35])
|
||||
self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 2])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSplitAllColumns(self):
|
||||
with self.session(use_gpu=False):
|
||||
sparse_tensors = sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=6, axis=1)
|
||||
for axis in (1, -1):
|
||||
sparse_tensors = self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=6, axis=axis))
|
||||
self.assertAllEqual(len(sparse_tensors), 6)
|
||||
self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [2, 0],
|
||||
[3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 20, 30])
|
||||
self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[1].indices.eval(), [[1, 0]])
|
||||
self.assertAllEqual(sparse_tensors[1].values.eval(), [11])
|
||||
self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[2].indices.eval(), [[0, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[2].values.eval(), [2, 32])
|
||||
self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[3].indices.eval(), [[1, 0], [2, 0],
|
||||
[3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[3].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[3].values.eval(), [13, 23, 33])
|
||||
self.assertAllEqual(sparse_tensors[4].indices.eval(), [[0, 0], [1, 0]])
|
||||
self.assertAllEqual(sparse_tensors[4].values.eval(), [4, 14])
|
||||
self.assertAllEqual(sparse_tensors[4].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[5].indices.eval(), [[0, 0], [2, 0],
|
||||
[3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[5].values.eval(), [5, 25, 35])
|
||||
self.assertAllEqual(sparse_tensors[5].dense_shape.eval(), [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[0].indices, [[0, 0], [2, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[0].values, [0, 20, 30])
|
||||
self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[1].indices, [[1, 0]])
|
||||
self.assertAllEqual(sparse_tensors[1].values, [11])
|
||||
self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[2].indices, [[0, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[2].values, [2, 32])
|
||||
self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[3].indices, [[1, 0], [2, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[3].dense_shape, [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[3].values, [13, 23, 33])
|
||||
self.assertAllEqual(sparse_tensors[4].indices, [[0, 0], [1, 0]])
|
||||
self.assertAllEqual(sparse_tensors[4].values, [4, 14])
|
||||
self.assertAllEqual(sparse_tensors[4].dense_shape, [4, 1])
|
||||
self.assertAllEqual(sparse_tensors[5].indices, [[0, 0], [2, 0], [3, 0]])
|
||||
self.assertAllEqual(sparse_tensors[5].values, [5, 25, 35])
|
||||
self.assertAllEqual(sparse_tensors[5].dense_shape, [4, 1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSliceConcat(self):
|
||||
for sp_input in (self._SparseTensorValue_3x4x2(),
|
||||
self._SparseTensor_3x4x2()):
|
||||
with self.cached_session(use_gpu=False):
|
||||
for axis in (1, -2):
|
||||
sparse_tensors = sparse_ops.sparse_split(
|
||||
sp_input=sp_input, num_split=2, axis=1)
|
||||
concat_tensor = sparse_ops.sparse_concat(1, sparse_tensors)
|
||||
sp_input=sp_input, num_split=2, axis=axis)
|
||||
concat_tensor = self.evaluate(
|
||||
sparse_ops.sparse_concat(1, sparse_tensors))
|
||||
expected_output = self._SparseTensor_3x4x2()
|
||||
self.assertAllEqual(concat_tensor.indices.eval(),
|
||||
expected_output.indices.eval())
|
||||
self.assertAllEqual(concat_tensor.indices, expected_output.indices)
|
||||
|
||||
def testInvalidAxis(self):
|
||||
for axis in (-3, 2):
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
r'axis should be in range \[-2, 2\)'):
|
||||
self.evaluate(
|
||||
sparse_ops.sparse_split(
|
||||
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
|
||||
|
||||
def testArgumentErrors(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Keyword arguments are required'):
|
||||
|
@ -983,7 +983,9 @@ def sparse_split(keyword_required=KeywordRequired(),
|
||||
keyword_required: Python 2 standin for * (temporary for argument reorder)
|
||||
sp_input: The `SparseTensor` to split.
|
||||
num_split: A Python integer. The number of ways to split.
|
||||
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
|
||||
axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in
|
||||
range [-rank, rank), where rank is the number of dimensions in the input
|
||||
`SparseTensor`.
|
||||
name: A name for the operation (optional).
|
||||
split_dim: Deprecated old name for axis.
|
||||
|
||||
@ -1031,27 +1033,50 @@ def sparse_split_v2(sp_input=None,
|
||||
|
||||
If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split`
|
||||
each slice starting from 0:`shape[axis] % num_split` gets extra one
|
||||
dimension. For example, if `axis = 1` and `num_split = 2` and the
|
||||
input is:
|
||||
dimension. For example:
|
||||
|
||||
input_tensor = shape = [2, 7]
|
||||
[ a d e ]
|
||||
[b c ]
|
||||
>>> indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
|
||||
>>> values = [1, 2, 3, 4, 5]
|
||||
>>> t = tf.SparseTensor(indices=indices, values=values, dense_shape=[2, 7])
|
||||
>>> tf.sparse.to_dense(t)
|
||||
<tf.Tensor: shape=(2, 7), dtype=int32, numpy=
|
||||
array([[0, 0, 1, 0, 2, 3, 0],
|
||||
[4, 5, 0, 0, 0, 0, 0]], dtype=int32)>
|
||||
|
||||
Graphically the output tensors are:
|
||||
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=1)
|
||||
>>> tf.sparse.to_dense(output[0])
|
||||
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
|
||||
array([[0, 0, 1, 0],
|
||||
[4, 5, 0, 0]], dtype=int32)>
|
||||
>>> tf.sparse.to_dense(output[1])
|
||||
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
|
||||
array([[2, 3, 0],
|
||||
[0, 0, 0]], dtype=int32)>
|
||||
|
||||
output_tensor[0] =
|
||||
[ a ]
|
||||
[b c ]
|
||||
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=0)
|
||||
>>> tf.sparse.to_dense(output[0])
|
||||
<tf.Tensor: shape=(1, 7), dtype=int32, numpy=array([[0, 0, 1, 0, 2, 3, 0]],
|
||||
dtype=int32)>
|
||||
>>> tf.sparse.to_dense(output[1])
|
||||
<tf.Tensor: shape=(1, 7), dtype=int32, numpy=array([[4, 5, 0, 0, 0, 0, 0]],
|
||||
dtype=int32)>
|
||||
|
||||
output_tensor[1] =
|
||||
[ d e ]
|
||||
[ ]
|
||||
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=-1)
|
||||
>>> tf.sparse.to_dense(output[0])
|
||||
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
|
||||
array([[0, 0, 1, 0],
|
||||
[4, 5, 0, 0]], dtype=int32)>
|
||||
>>> tf.sparse.to_dense(output[1])
|
||||
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
|
||||
array([[2, 3, 0],
|
||||
[0, 0, 0]], dtype=int32)>
|
||||
|
||||
Args:
|
||||
sp_input: The `SparseTensor` to split.
|
||||
num_split: A Python integer. The number of ways to split.
|
||||
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
|
||||
axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in
|
||||
range [-rank, rank), where rank is the number of dimensions in the input
|
||||
`SparseTensor`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
|
Loading…
Reference in New Issue
Block a user