Add support for negative indices to sparse_split.

PiperOrigin-RevId: 320082703
Change-Id: I71c13e4ad05eb4ce5f5f2c7e2bd3be028875929f
This commit is contained in:
Victor de Souza 2020-07-07 16:18:34 -07:00 committed by TensorFlower Gardener
parent 40a0ec5420
commit 1c4ee22e6f
3 changed files with 187 additions and 169 deletions

View File

@ -30,7 +30,7 @@ class SparseSplitOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
const int64 split_dim = context->input(0).scalar<int64>()();
const int64 axis_input = context->input(0).scalar<int64>()();
const Tensor& input_indices = context->input(1);
const Tensor& input_values = context->input(2);
const Tensor& input_shape = context->input(3);
@ -48,20 +48,20 @@ class SparseSplitOp : public OpKernel {
"Input shape should be a vector but received shape ",
input_shape.shape().DebugString()));
OP_REQUIRES(
context,
input_shape.dim_size(0) && split_dim < input_shape.vec<int64>().size(),
errors::InvalidArgument(
"Input split_dim should be between 0 and rank (",
input_shape.vec<int64>().size(), "), got ", split_dim));
const int64 input_rank = input_shape.vec<int64>().size();
const int64 axis = (axis_input < 0) ? input_rank + axis_input : axis_input;
OP_REQUIRES(
context,
num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(split_dim),
errors::InvalidArgument("Input num_split should be between 1 "
"and the splitting dimension size (",
input_shape.vec<int64>()(split_dim), "), got ",
num_split_));
context, axis >= 0 && axis < input_rank,
errors::InvalidArgument("Input axis should be in range [", -input_rank,
", ", input_rank, "), got ", axis_input));
OP_REQUIRES(context,
num_split_ >= 1 && num_split_ <= input_shape.vec<int64>()(axis),
errors::InvalidArgument("Input num_split should be between 1 "
"and the splitting dimension size (",
input_shape.vec<int64>()(axis),
"), got ", num_split_));
sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(context,
@ -70,9 +70,8 @@ class SparseSplitOp : public OpKernel {
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
std::vector<sparse::SparseTensor> outputs;
OP_REQUIRES_OK(context,
sparse::SparseTensor::Split<T>(sparse_tensor, split_dim,
num_split_, &outputs));
OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
sparse_tensor, axis, num_split_, &outputs));
for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
context->set_output(slice_index, outputs[slice_index].indices());

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@ -76,182 +76,176 @@ class SparseSplitOpTest(test.TestCase):
return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2(
))
@test_util.run_deprecated_v1
def testSplitMatrixRows(self):
with self.session(use_gpu=False):
sp_tensors = sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=2, axis=0)
for axis in (0, -2):
sp_tensors = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=2, axis=axis))
self.assertAllEqual(len(sp_tensors), 2)
self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4],
[0, 5], [1, 1], [1, 3],
[1, 4]])
self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5, 11, 13, 14])
self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [2, 6])
self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 0], [0, 3], [0, 5],
[1, 0], [1, 2], [1, 3],
[1, 5]])
self.assertAllEqual(sp_tensors[1].values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [2, 6])
self.assertAllEqual(
sp_tensors[0].indices,
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]])
self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5, 11, 13, 14])
self.assertAllEqual(sp_tensors[0].dense_shape, [2, 6])
self.assertAllEqual(
sp_tensors[1].indices,
[[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]])
self.assertAllEqual(sp_tensors[1].values, [20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensors[1].dense_shape, [2, 6])
@test_util.run_deprecated_v1
def testSplitMatrixUnevenCols(self):
with self.session(use_gpu=False):
sp_tensors_3 = sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=3, axis=1)
for axis in (1, -1):
sp_tensors_3 = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis))
self.assertAllEqual(len(sp_tensors_3), 3)
self.assertAllEqual(sp_tensors_3[0].indices.eval(),
[[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2],
[4, 1]])
self.assertAllEqual(sp_tensors_3[0].values.eval(),
[0, 2, 11, 20, 30, 32, 41])
self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [5, 3])
self.assertAllEqual(sp_tensors_3[1].indices.eval(),
self.assertAllEqual(
sp_tensors_3[0].indices,
[[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]])
self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 11, 20, 30, 32, 41])
self.assertAllEqual(sp_tensors_3[0].dense_shape, [5, 3])
self.assertAllEqual(sp_tensors_3[1].indices,
[[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
self.assertAllEqual(sp_tensors_3[1].values.eval(),
[4, 13, 14, 23, 33, 44])
self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensors_3[2].indices.eval(),
self.assertAllEqual(sp_tensors_3[1].values, [4, 13, 14, 23, 33, 44])
self.assertAllEqual(sp_tensors_3[1].dense_shape, [5, 2])
self.assertAllEqual(sp_tensors_3[2].indices,
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
self.assertAllEqual(sp_tensors_3[2].values.eval(), [5, 16, 25, 35, 46])
self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensors_3[2].values, [5, 16, 25, 35, 46])
self.assertAllEqual(sp_tensors_3[2].dense_shape, [5, 2])
sp_tensors_4 = sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=4, axis=1)
sp_input=self._SparseTensor_5x7(), num_split=4, axis=axis)
self.assertAllEqual(len(sp_tensors_4), 4)
self.assertAllEqual(sp_tensors_4[0].indices.eval(),
self.assertAllEqual(sp_tensors_4[0].indices,
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
self.assertAllEqual(sp_tensors_4[0].values.eval(), [0, 11, 20, 30, 41])
self.assertAllEqual(sp_tensors_4[0].dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensors_4[1].indices.eval(),
self.assertAllEqual(sp_tensors_4[0].values, [0, 11, 20, 30, 41])
self.assertAllEqual(sp_tensors_4[0].dense_shape, [5, 2])
self.assertAllEqual(sp_tensors_4[1].indices,
[[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]])
self.assertAllEqual(sp_tensors_4[1].values.eval(), [2, 13, 23, 32, 33])
self.assertAllEqual(sp_tensors_4[1].dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensors_4[2].indices.eval(),
self.assertAllEqual(sp_tensors_4[1].values, [2, 13, 23, 32, 33])
self.assertAllEqual(sp_tensors_4[1].dense_shape, [5, 2])
self.assertAllEqual(sp_tensors_4[2].indices,
[[0, 0], [0, 1], [1, 0], [2, 1], [3, 1], [4, 0]])
self.assertAllEqual(sp_tensors_4[2].values.eval(), [4, 5, 14, 25, 35, 44])
self.assertAllEqual(sp_tensors_4[2].dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensors_4[3].indices.eval(), [[1, 0], [4, 0]])
self.assertAllEqual(sp_tensors_4[3].values.eval(), [16, 46])
self.assertAllEqual(sp_tensors_4[3].dense_shape.eval(), [5, 1])
self.assertAllEqual(sp_tensors_4[2].values, [4, 5, 14, 25, 35, 44])
self.assertAllEqual(sp_tensors_4[2].dense_shape, [5, 2])
self.assertAllEqual(sp_tensors_4[3].indices, [[1, 0], [4, 0]])
self.assertAllEqual(sp_tensors_4[3].values, [16, 46])
self.assertAllEqual(sp_tensors_4[3].dense_shape, [5, 1])
@test_util.run_deprecated_v1
def testSplitMatrixUnevenRows(self):
with self.session(use_gpu=False):
sp_tensors_2 = sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=2, axis=0)
self.assertAllEqual(sp_tensors_2[0].indices.eval(),
for axis in (0, -2):
sp_tensors_2 = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=2, axis=axis))
self.assertAllEqual(sp_tensors_2[0].indices,
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
[1, 4], [1, 6], [2, 0], [2, 3], [2, 5]])
self.assertAllEqual(sp_tensors_2[0].values.eval(),
self.assertAllEqual(sp_tensors_2[0].values,
[0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25])
self.assertAllEqual(sp_tensors_2[0].dense_shape.eval(), [3, 7])
self.assertAllEqual(sp_tensors_2[1].indices.eval(),
[[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4],
[1, 6]])
self.assertAllEqual(sp_tensors_2[1].values.eval(),
[30, 32, 33, 35, 41, 44, 46])
self.assertAllEqual(sp_tensors_2[1].dense_shape.eval(), [2, 7])
self.assertAllEqual(sp_tensors_2[0].dense_shape, [3, 7])
self.assertAllEqual(
sp_tensors_2[1].indices,
[[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensors_2[1].values, [30, 32, 33, 35, 41, 44, 46])
self.assertAllEqual(sp_tensors_2[1].dense_shape, [2, 7])
self.assertAllEqual(len(sp_tensors_2), 2)
sp_tensors_3 = sparse_ops.sparse_split(
sp_input=self._SparseTensor_5x7(), num_split=3, axis=0)
sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis)
self.assertAllEqual(len(sp_tensors_3), 3)
self.assertAllEqual(sp_tensors_3[0].indices.eval(),
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
[1, 4], [1, 6]])
self.assertAllEqual(sp_tensors_3[0].values.eval(),
[0, 2, 4, 5, 11, 13, 14, 16])
self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [2, 7])
self.assertAllEqual(
sp_tensors_3[0].indices,
[[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 4, 5, 11, 13, 14, 16])
self.assertAllEqual(sp_tensors_3[0].dense_shape, [2, 7])
self.assertAllEqual(sp_tensors_3[1].values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [2, 7])
self.assertAllEqual(sp_tensors_3[2].indices.eval(), [[0, 1], [0, 4],
[0, 6]])
self.assertAllEqual(sp_tensors_3[2].values.eval(), [41, 44, 46])
self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [1, 7])
return
self.assertAllEqual(sp_tensors_3[1].values, [20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensors_3[1].dense_shape, [2, 7])
self.assertAllEqual(sp_tensors_3[2].indices, [[0, 1], [0, 4], [0, 6]])
self.assertAllEqual(sp_tensors_3[2].values, [41, 44, 46])
self.assertAllEqual(sp_tensors_3[2].dense_shape, [1, 7])
@test_util.run_deprecated_v1
def testSplitAllRows(self):
with self.session(use_gpu=False):
sp_tensors = sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=4, axis=0)
for axis in (0, -2):
sp_tensors = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=4, axis=axis))
self.assertAllEqual(len(sp_tensors), 4)
self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4],
[0, 5]])
self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5])
self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [1, 6])
self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 1], [0, 3], [0,
4]])
self.assertAllEqual(sp_tensors[1].values.eval(), [11, 13, 14])
self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [1, 6])
self.assertAllEqual(sp_tensors[2].indices.eval(), [[0, 0], [0, 3], [0,
5]])
self.assertAllEqual(sp_tensors[2].values.eval(), [20, 23, 25])
self.assertAllEqual(sp_tensors[2].dense_shape.eval(), [1, 6])
self.assertAllEqual(sp_tensors[3].indices.eval(), [[0, 0], [0, 2], [0, 3],
[0, 5]])
self.assertAllEqual(sp_tensors[3].values.eval(), [30, 32, 33, 35])
self.assertAllEqual(sp_tensors[3].dense_shape.eval(), [1, 6])
self.assertAllEqual(sp_tensors[0].indices,
[[0, 0], [0, 2], [0, 4], [0, 5]])
self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5])
self.assertAllEqual(sp_tensors[0].dense_shape, [1, 6])
self.assertAllEqual(sp_tensors[1].indices, [[0, 1], [0, 3], [0, 4]])
self.assertAllEqual(sp_tensors[1].values, [11, 13, 14])
self.assertAllEqual(sp_tensors[1].dense_shape, [1, 6])
self.assertAllEqual(sp_tensors[2].indices, [[0, 0], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensors[2].values, [20, 23, 25])
self.assertAllEqual(sp_tensors[2].dense_shape, [1, 6])
self.assertAllEqual(sp_tensors[3].indices,
[[0, 0], [0, 2], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensors[3].values, [30, 32, 33, 35])
self.assertAllEqual(sp_tensors[3].dense_shape, [1, 6])
@test_util.run_deprecated_v1
def testSplitColumns(self):
with self.session(use_gpu=False):
sparse_tensors = sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=3, axis=1)
for axis in (1, -1):
sparse_tensors = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
self.assertAllEqual(len(sparse_tensors), 3)
self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [1, 1],
[2, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 11, 20, 30])
self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 2])
self.assertAllEqual(sparse_tensors[1].indices.eval(),
self.assertAllEqual(sparse_tensors[0].indices,
[[0, 0], [1, 1], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[0].values, [0, 11, 20, 30])
self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 2])
self.assertAllEqual(sparse_tensors[1].indices,
[[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]])
self.assertAllEqual(sparse_tensors[1].values.eval(), [2, 13, 23, 32, 33])
self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 2])
self.assertAllEqual(sparse_tensors[2].indices.eval(),
self.assertAllEqual(sparse_tensors[1].values, [2, 13, 23, 32, 33])
self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 2])
self.assertAllEqual(sparse_tensors[2].indices,
[[0, 0], [0, 1], [1, 0], [2, 1], [3, 1]])
self.assertAllEqual(sparse_tensors[2].values.eval(), [4, 5, 14, 25, 35])
self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 2])
self.assertAllEqual(sparse_tensors[2].values, [4, 5, 14, 25, 35])
self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 2])
@test_util.run_deprecated_v1
def testSplitAllColumns(self):
with self.session(use_gpu=False):
sparse_tensors = sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=6, axis=1)
for axis in (1, -1):
sparse_tensors = self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=6, axis=axis))
self.assertAllEqual(len(sparse_tensors), 6)
self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [2, 0],
[3, 0]])
self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 20, 30])
self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[1].indices.eval(), [[1, 0]])
self.assertAllEqual(sparse_tensors[1].values.eval(), [11])
self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[2].indices.eval(), [[0, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[2].values.eval(), [2, 32])
self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[3].indices.eval(), [[1, 0], [2, 0],
[3, 0]])
self.assertAllEqual(sparse_tensors[3].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[3].values.eval(), [13, 23, 33])
self.assertAllEqual(sparse_tensors[4].indices.eval(), [[0, 0], [1, 0]])
self.assertAllEqual(sparse_tensors[4].values.eval(), [4, 14])
self.assertAllEqual(sparse_tensors[4].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[5].indices.eval(), [[0, 0], [2, 0],
[3, 0]])
self.assertAllEqual(sparse_tensors[5].values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensors[5].dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensors[0].indices, [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[0].values, [0, 20, 30])
self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 1])
self.assertAllEqual(sparse_tensors[1].indices, [[1, 0]])
self.assertAllEqual(sparse_tensors[1].values, [11])
self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 1])
self.assertAllEqual(sparse_tensors[2].indices, [[0, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[2].values, [2, 32])
self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 1])
self.assertAllEqual(sparse_tensors[3].indices, [[1, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[3].dense_shape, [4, 1])
self.assertAllEqual(sparse_tensors[3].values, [13, 23, 33])
self.assertAllEqual(sparse_tensors[4].indices, [[0, 0], [1, 0]])
self.assertAllEqual(sparse_tensors[4].values, [4, 14])
self.assertAllEqual(sparse_tensors[4].dense_shape, [4, 1])
self.assertAllEqual(sparse_tensors[5].indices, [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensors[5].values, [5, 25, 35])
self.assertAllEqual(sparse_tensors[5].dense_shape, [4, 1])
@test_util.run_deprecated_v1
def testSliceConcat(self):
for sp_input in (self._SparseTensorValue_3x4x2(),
self._SparseTensor_3x4x2()):
with self.cached_session(use_gpu=False):
for axis in (1, -2):
sparse_tensors = sparse_ops.sparse_split(
sp_input=sp_input, num_split=2, axis=1)
concat_tensor = sparse_ops.sparse_concat(1, sparse_tensors)
sp_input=sp_input, num_split=2, axis=axis)
concat_tensor = self.evaluate(
sparse_ops.sparse_concat(1, sparse_tensors))
expected_output = self._SparseTensor_3x4x2()
self.assertAllEqual(concat_tensor.indices.eval(),
expected_output.indices.eval())
self.assertAllEqual(concat_tensor.indices, expected_output.indices)
def testInvalidAxis(self):
for axis in (-3, 2):
with self.assertRaisesRegexp(errors.InvalidArgumentError,
r'axis should be in range \[-2, 2\)'):
self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
def testArgumentErrors(self):
with self.assertRaisesRegex(ValueError, 'Keyword arguments are required'):

View File

@ -983,7 +983,9 @@ def sparse_split(keyword_required=KeywordRequired(),
keyword_required: Python 2 standin for * (temporary for argument reorder)
sp_input: The `SparseTensor` to split.
num_split: A Python integer. The number of ways to split.
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in
range [-rank, rank), where rank is the number of dimensions in the input
`SparseTensor`.
name: A name for the operation (optional).
split_dim: Deprecated old name for axis.
@ -1031,27 +1033,50 @@ def sparse_split_v2(sp_input=None,
If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split`
each slice starting from 0:`shape[axis] % num_split` gets extra one
dimension. For example, if `axis = 1` and `num_split = 2` and the
input is:
dimension. For example:
input_tensor = shape = [2, 7]
[ a d e ]
[b c ]
>>> indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
>>> values = [1, 2, 3, 4, 5]
>>> t = tf.SparseTensor(indices=indices, values=values, dense_shape=[2, 7])
>>> tf.sparse.to_dense(t)
<tf.Tensor: shape=(2, 7), dtype=int32, numpy=
array([[0, 0, 1, 0, 2, 3, 0],
[4, 5, 0, 0, 0, 0, 0]], dtype=int32)>
Graphically the output tensors are:
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=1)
>>> tf.sparse.to_dense(output[0])
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[0, 0, 1, 0],
[4, 5, 0, 0]], dtype=int32)>
>>> tf.sparse.to_dense(output[1])
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[2, 3, 0],
[0, 0, 0]], dtype=int32)>
output_tensor[0] =
[ a ]
[b c ]
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=0)
>>> tf.sparse.to_dense(output[0])
<tf.Tensor: shape=(1, 7), dtype=int32, numpy=array([[0, 0, 1, 0, 2, 3, 0]],
dtype=int32)>
>>> tf.sparse.to_dense(output[1])
<tf.Tensor: shape=(1, 7), dtype=int32, numpy=array([[4, 5, 0, 0, 0, 0, 0]],
dtype=int32)>
output_tensor[1] =
[ d e ]
[ ]
>>> output = tf.sparse.split(sp_input=t, num_split=2, axis=-1)
>>> tf.sparse.to_dense(output[0])
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[0, 0, 1, 0],
[4, 5, 0, 0]], dtype=int32)>
>>> tf.sparse.to_dense(output[1])
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[2, 3, 0],
[0, 0, 0]], dtype=int32)>
Args:
sp_input: The `SparseTensor` to split.
num_split: A Python integer. The number of ways to split.
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in
range [-rank, rank), where rank is the number of dimensions in the input
`SparseTensor`.
name: A name for the operation (optional).
Returns: