From 1c4ee22e6fd7da8769d4d16136a4de42d2dbc7cb Mon Sep 17 00:00:00 2001 From: Victor de Souza Date: Tue, 7 Jul 2020 16:18:34 -0700 Subject: [PATCH] Add support for negative indices to sparse_split. PiperOrigin-RevId: 320082703 Change-Id: I71c13e4ad05eb4ce5f5f2c7e2bd3be028875929f --- tensorflow/core/kernels/sparse_split_op.cc | 31 +- .../kernel_tests/sparse_split_op_test.py | 272 +++++++++--------- tensorflow/python/ops/sparse_ops.py | 53 +++- 3 files changed, 187 insertions(+), 169 deletions(-) diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc index 3d02be47cbb..5c0457aa956 100644 --- a/tensorflow/core/kernels/sparse_split_op.cc +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -30,7 +30,7 @@ class SparseSplitOp : public OpKernel { } void Compute(OpKernelContext* context) override { - const int64 split_dim = context->input(0).scalar()(); + const int64 axis_input = context->input(0).scalar()(); const Tensor& input_indices = context->input(1); const Tensor& input_values = context->input(2); const Tensor& input_shape = context->input(3); @@ -48,20 +48,20 @@ class SparseSplitOp : public OpKernel { "Input shape should be a vector but received shape ", input_shape.shape().DebugString())); - OP_REQUIRES( - context, - input_shape.dim_size(0) && split_dim < input_shape.vec().size(), - errors::InvalidArgument( - "Input split_dim should be between 0 and rank (", - input_shape.vec().size(), "), got ", split_dim)); + const int64 input_rank = input_shape.vec().size(); + const int64 axis = (axis_input < 0) ? input_rank + axis_input : axis_input; OP_REQUIRES( - context, - num_split_ >= 1 && num_split_ <= input_shape.vec()(split_dim), - errors::InvalidArgument("Input num_split should be between 1 " - "and the splitting dimension size (", - input_shape.vec()(split_dim), "), got ", - num_split_)); + context, axis >= 0 && axis < input_rank, + errors::InvalidArgument("Input axis should be in range [", -input_rank, + ", ", input_rank, "), got ", axis_input)); + + OP_REQUIRES(context, + num_split_ >= 1 && num_split_ <= input_shape.vec()(axis), + errors::InvalidArgument("Input num_split should be between 1 " + "and the splitting dimension size (", + input_shape.vec()(axis), + "), got ", num_split_)); sparse::SparseTensor sparse_tensor; OP_REQUIRES_OK(context, @@ -70,9 +70,8 @@ class SparseSplitOp : public OpKernel { TensorShape(input_shape.vec()), &sparse_tensor)); std::vector outputs; - OP_REQUIRES_OK(context, - sparse::SparseTensor::Split(sparse_tensor, split_dim, - num_split_, &outputs)); + OP_REQUIRES_OK(context, sparse::SparseTensor::Split( + sparse_tensor, axis, num_split_, &outputs)); for (int slice_index = 0; slice_index < num_split_; ++slice_index) { context->set_output(slice_index, outputs[slice_index].indices()); diff --git a/tensorflow/python/kernel_tests/sparse_split_op_test.py b/tensorflow/python/kernel_tests/sparse_split_op_test.py index bdd4b8e7634..31ef1129f13 100644 --- a/tensorflow/python/kernel_tests/sparse_split_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_split_op_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test @@ -76,182 +76,176 @@ class SparseSplitOpTest(test.TestCase): return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2( )) - @test_util.run_deprecated_v1 def testSplitMatrixRows(self): - with self.session(use_gpu=False): - sp_tensors = sparse_ops.sparse_split( - sp_input=self._SparseTensor_4x6(), num_split=2, axis=0) + for axis in (0, -2): + sp_tensors = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=2, axis=axis)) self.assertAllEqual(len(sp_tensors), 2) - self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4], - [0, 5], [1, 1], [1, 3], - [1, 4]]) - self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5, 11, 13, 14]) - self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [2, 6]) - self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 0], [0, 3], [0, 5], - [1, 0], [1, 2], [1, 3], - [1, 5]]) - self.assertAllEqual(sp_tensors[1].values.eval(), - [20, 23, 25, 30, 32, 33, 35]) - self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [2, 6]) + self.assertAllEqual( + sp_tensors[0].indices, + [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]]) + self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5, 11, 13, 14]) + self.assertAllEqual(sp_tensors[0].dense_shape, [2, 6]) + self.assertAllEqual( + sp_tensors[1].indices, + [[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]]) + self.assertAllEqual(sp_tensors[1].values, [20, 23, 25, 30, 32, 33, 35]) + self.assertAllEqual(sp_tensors[1].dense_shape, [2, 6]) - @test_util.run_deprecated_v1 def testSplitMatrixUnevenCols(self): - with self.session(use_gpu=False): - sp_tensors_3 = sparse_ops.sparse_split( - sp_input=self._SparseTensor_5x7(), num_split=3, axis=1) + for axis in (1, -1): + sp_tensors_3 = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis)) self.assertAllEqual(len(sp_tensors_3), 3) - self.assertAllEqual(sp_tensors_3[0].indices.eval(), - [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], - [4, 1]]) - self.assertAllEqual(sp_tensors_3[0].values.eval(), - [0, 2, 11, 20, 30, 32, 41]) - self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [5, 3]) - self.assertAllEqual(sp_tensors_3[1].indices.eval(), + self.assertAllEqual( + sp_tensors_3[0].indices, + [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]]) + self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 11, 20, 30, 32, 41]) + self.assertAllEqual(sp_tensors_3[0].dense_shape, [5, 3]) + self.assertAllEqual(sp_tensors_3[1].indices, [[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]]) - self.assertAllEqual(sp_tensors_3[1].values.eval(), - [4, 13, 14, 23, 33, 44]) - self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [5, 2]) - self.assertAllEqual(sp_tensors_3[2].indices.eval(), + self.assertAllEqual(sp_tensors_3[1].values, [4, 13, 14, 23, 33, 44]) + self.assertAllEqual(sp_tensors_3[1].dense_shape, [5, 2]) + self.assertAllEqual(sp_tensors_3[2].indices, [[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]]) - self.assertAllEqual(sp_tensors_3[2].values.eval(), [5, 16, 25, 35, 46]) - self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [5, 2]) + self.assertAllEqual(sp_tensors_3[2].values, [5, 16, 25, 35, 46]) + self.assertAllEqual(sp_tensors_3[2].dense_shape, [5, 2]) sp_tensors_4 = sparse_ops.sparse_split( - sp_input=self._SparseTensor_5x7(), num_split=4, axis=1) + sp_input=self._SparseTensor_5x7(), num_split=4, axis=axis) self.assertAllEqual(len(sp_tensors_4), 4) - self.assertAllEqual(sp_tensors_4[0].indices.eval(), + self.assertAllEqual(sp_tensors_4[0].indices, [[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]]) - self.assertAllEqual(sp_tensors_4[0].values.eval(), [0, 11, 20, 30, 41]) - self.assertAllEqual(sp_tensors_4[0].dense_shape.eval(), [5, 2]) - self.assertAllEqual(sp_tensors_4[1].indices.eval(), + self.assertAllEqual(sp_tensors_4[0].values, [0, 11, 20, 30, 41]) + self.assertAllEqual(sp_tensors_4[0].dense_shape, [5, 2]) + self.assertAllEqual(sp_tensors_4[1].indices, [[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]]) - self.assertAllEqual(sp_tensors_4[1].values.eval(), [2, 13, 23, 32, 33]) - self.assertAllEqual(sp_tensors_4[1].dense_shape.eval(), [5, 2]) - self.assertAllEqual(sp_tensors_4[2].indices.eval(), + self.assertAllEqual(sp_tensors_4[1].values, [2, 13, 23, 32, 33]) + self.assertAllEqual(sp_tensors_4[1].dense_shape, [5, 2]) + self.assertAllEqual(sp_tensors_4[2].indices, [[0, 0], [0, 1], [1, 0], [2, 1], [3, 1], [4, 0]]) - self.assertAllEqual(sp_tensors_4[2].values.eval(), [4, 5, 14, 25, 35, 44]) - self.assertAllEqual(sp_tensors_4[2].dense_shape.eval(), [5, 2]) - self.assertAllEqual(sp_tensors_4[3].indices.eval(), [[1, 0], [4, 0]]) - self.assertAllEqual(sp_tensors_4[3].values.eval(), [16, 46]) - self.assertAllEqual(sp_tensors_4[3].dense_shape.eval(), [5, 1]) + self.assertAllEqual(sp_tensors_4[2].values, [4, 5, 14, 25, 35, 44]) + self.assertAllEqual(sp_tensors_4[2].dense_shape, [5, 2]) + self.assertAllEqual(sp_tensors_4[3].indices, [[1, 0], [4, 0]]) + self.assertAllEqual(sp_tensors_4[3].values, [16, 46]) + self.assertAllEqual(sp_tensors_4[3].dense_shape, [5, 1]) - @test_util.run_deprecated_v1 def testSplitMatrixUnevenRows(self): - with self.session(use_gpu=False): - sp_tensors_2 = sparse_ops.sparse_split( - sp_input=self._SparseTensor_5x7(), num_split=2, axis=0) - self.assertAllEqual(sp_tensors_2[0].indices.eval(), + for axis in (0, -2): + sp_tensors_2 = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_5x7(), num_split=2, axis=axis)) + self.assertAllEqual(sp_tensors_2[0].indices, [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6], [2, 0], [2, 3], [2, 5]]) - self.assertAllEqual(sp_tensors_2[0].values.eval(), + self.assertAllEqual(sp_tensors_2[0].values, [0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25]) - self.assertAllEqual(sp_tensors_2[0].dense_shape.eval(), [3, 7]) - self.assertAllEqual(sp_tensors_2[1].indices.eval(), - [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], - [1, 6]]) - self.assertAllEqual(sp_tensors_2[1].values.eval(), - [30, 32, 33, 35, 41, 44, 46]) - self.assertAllEqual(sp_tensors_2[1].dense_shape.eval(), [2, 7]) + self.assertAllEqual(sp_tensors_2[0].dense_shape, [3, 7]) + self.assertAllEqual( + sp_tensors_2[1].indices, + [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]]) + self.assertAllEqual(sp_tensors_2[1].values, [30, 32, 33, 35, 41, 44, 46]) + self.assertAllEqual(sp_tensors_2[1].dense_shape, [2, 7]) self.assertAllEqual(len(sp_tensors_2), 2) sp_tensors_3 = sparse_ops.sparse_split( - sp_input=self._SparseTensor_5x7(), num_split=3, axis=0) + sp_input=self._SparseTensor_5x7(), num_split=3, axis=axis) self.assertAllEqual(len(sp_tensors_3), 3) - self.assertAllEqual(sp_tensors_3[0].indices.eval(), - [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], - [1, 4], [1, 6]]) - self.assertAllEqual(sp_tensors_3[0].values.eval(), - [0, 2, 4, 5, 11, 13, 14, 16]) - self.assertAllEqual(sp_tensors_3[0].dense_shape.eval(), [2, 7]) + self.assertAllEqual( + sp_tensors_3[0].indices, + [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]]) + self.assertAllEqual(sp_tensors_3[0].values, [0, 2, 4, 5, 11, 13, 14, 16]) + self.assertAllEqual(sp_tensors_3[0].dense_shape, [2, 7]) - self.assertAllEqual(sp_tensors_3[1].values.eval(), - [20, 23, 25, 30, 32, 33, 35]) - self.assertAllEqual(sp_tensors_3[1].dense_shape.eval(), [2, 7]) - self.assertAllEqual(sp_tensors_3[2].indices.eval(), [[0, 1], [0, 4], - [0, 6]]) - self.assertAllEqual(sp_tensors_3[2].values.eval(), [41, 44, 46]) - self.assertAllEqual(sp_tensors_3[2].dense_shape.eval(), [1, 7]) - return + self.assertAllEqual(sp_tensors_3[1].values, [20, 23, 25, 30, 32, 33, 35]) + self.assertAllEqual(sp_tensors_3[1].dense_shape, [2, 7]) + self.assertAllEqual(sp_tensors_3[2].indices, [[0, 1], [0, 4], [0, 6]]) + self.assertAllEqual(sp_tensors_3[2].values, [41, 44, 46]) + self.assertAllEqual(sp_tensors_3[2].dense_shape, [1, 7]) - @test_util.run_deprecated_v1 def testSplitAllRows(self): - with self.session(use_gpu=False): - sp_tensors = sparse_ops.sparse_split( - sp_input=self._SparseTensor_4x6(), num_split=4, axis=0) + for axis in (0, -2): + sp_tensors = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=4, axis=axis)) self.assertAllEqual(len(sp_tensors), 4) - self.assertAllEqual(sp_tensors[0].indices.eval(), [[0, 0], [0, 2], [0, 4], - [0, 5]]) - self.assertAllEqual(sp_tensors[0].values.eval(), [0, 2, 4, 5]) - self.assertAllEqual(sp_tensors[0].dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensors[1].indices.eval(), [[0, 1], [0, 3], [0, - 4]]) - self.assertAllEqual(sp_tensors[1].values.eval(), [11, 13, 14]) - self.assertAllEqual(sp_tensors[1].dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensors[2].indices.eval(), [[0, 0], [0, 3], [0, - 5]]) - self.assertAllEqual(sp_tensors[2].values.eval(), [20, 23, 25]) - self.assertAllEqual(sp_tensors[2].dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensors[3].indices.eval(), [[0, 0], [0, 2], [0, 3], - [0, 5]]) - self.assertAllEqual(sp_tensors[3].values.eval(), [30, 32, 33, 35]) - self.assertAllEqual(sp_tensors[3].dense_shape.eval(), [1, 6]) + self.assertAllEqual(sp_tensors[0].indices, + [[0, 0], [0, 2], [0, 4], [0, 5]]) + self.assertAllEqual(sp_tensors[0].values, [0, 2, 4, 5]) + self.assertAllEqual(sp_tensors[0].dense_shape, [1, 6]) + self.assertAllEqual(sp_tensors[1].indices, [[0, 1], [0, 3], [0, 4]]) + self.assertAllEqual(sp_tensors[1].values, [11, 13, 14]) + self.assertAllEqual(sp_tensors[1].dense_shape, [1, 6]) + self.assertAllEqual(sp_tensors[2].indices, [[0, 0], [0, 3], [0, 5]]) + self.assertAllEqual(sp_tensors[2].values, [20, 23, 25]) + self.assertAllEqual(sp_tensors[2].dense_shape, [1, 6]) + self.assertAllEqual(sp_tensors[3].indices, + [[0, 0], [0, 2], [0, 3], [0, 5]]) + self.assertAllEqual(sp_tensors[3].values, [30, 32, 33, 35]) + self.assertAllEqual(sp_tensors[3].dense_shape, [1, 6]) - @test_util.run_deprecated_v1 def testSplitColumns(self): - with self.session(use_gpu=False): - sparse_tensors = sparse_ops.sparse_split( - sp_input=self._SparseTensor_4x6(), num_split=3, axis=1) + for axis in (1, -1): + sparse_tensors = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis)) self.assertAllEqual(len(sparse_tensors), 3) - self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [1, 1], - [2, 0], [3, 0]]) - self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 11, 20, 30]) - self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 2]) - self.assertAllEqual(sparse_tensors[1].indices.eval(), + self.assertAllEqual(sparse_tensors[0].indices, + [[0, 0], [1, 1], [2, 0], [3, 0]]) + self.assertAllEqual(sparse_tensors[0].values, [0, 11, 20, 30]) + self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 2]) + self.assertAllEqual(sparse_tensors[1].indices, [[0, 0], [1, 1], [2, 1], [3, 0], [3, 1]]) - self.assertAllEqual(sparse_tensors[1].values.eval(), [2, 13, 23, 32, 33]) - self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 2]) - self.assertAllEqual(sparse_tensors[2].indices.eval(), + self.assertAllEqual(sparse_tensors[1].values, [2, 13, 23, 32, 33]) + self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 2]) + self.assertAllEqual(sparse_tensors[2].indices, [[0, 0], [0, 1], [1, 0], [2, 1], [3, 1]]) - self.assertAllEqual(sparse_tensors[2].values.eval(), [4, 5, 14, 25, 35]) - self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 2]) + self.assertAllEqual(sparse_tensors[2].values, [4, 5, 14, 25, 35]) + self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 2]) - @test_util.run_deprecated_v1 def testSplitAllColumns(self): - with self.session(use_gpu=False): - sparse_tensors = sparse_ops.sparse_split( - sp_input=self._SparseTensor_4x6(), num_split=6, axis=1) + for axis in (1, -1): + sparse_tensors = self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=6, axis=axis)) self.assertAllEqual(len(sparse_tensors), 6) - self.assertAllEqual(sparse_tensors[0].indices.eval(), [[0, 0], [2, 0], - [3, 0]]) - self.assertAllEqual(sparse_tensors[0].values.eval(), [0, 20, 30]) - self.assertAllEqual(sparse_tensors[0].dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensors[1].indices.eval(), [[1, 0]]) - self.assertAllEqual(sparse_tensors[1].values.eval(), [11]) - self.assertAllEqual(sparse_tensors[1].dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensors[2].indices.eval(), [[0, 0], [3, 0]]) - self.assertAllEqual(sparse_tensors[2].values.eval(), [2, 32]) - self.assertAllEqual(sparse_tensors[2].dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensors[3].indices.eval(), [[1, 0], [2, 0], - [3, 0]]) - self.assertAllEqual(sparse_tensors[3].dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensors[3].values.eval(), [13, 23, 33]) - self.assertAllEqual(sparse_tensors[4].indices.eval(), [[0, 0], [1, 0]]) - self.assertAllEqual(sparse_tensors[4].values.eval(), [4, 14]) - self.assertAllEqual(sparse_tensors[4].dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensors[5].indices.eval(), [[0, 0], [2, 0], - [3, 0]]) - self.assertAllEqual(sparse_tensors[5].values.eval(), [5, 25, 35]) - self.assertAllEqual(sparse_tensors[5].dense_shape.eval(), [4, 1]) + self.assertAllEqual(sparse_tensors[0].indices, [[0, 0], [2, 0], [3, 0]]) + self.assertAllEqual(sparse_tensors[0].values, [0, 20, 30]) + self.assertAllEqual(sparse_tensors[0].dense_shape, [4, 1]) + self.assertAllEqual(sparse_tensors[1].indices, [[1, 0]]) + self.assertAllEqual(sparse_tensors[1].values, [11]) + self.assertAllEqual(sparse_tensors[1].dense_shape, [4, 1]) + self.assertAllEqual(sparse_tensors[2].indices, [[0, 0], [3, 0]]) + self.assertAllEqual(sparse_tensors[2].values, [2, 32]) + self.assertAllEqual(sparse_tensors[2].dense_shape, [4, 1]) + self.assertAllEqual(sparse_tensors[3].indices, [[1, 0], [2, 0], [3, 0]]) + self.assertAllEqual(sparse_tensors[3].dense_shape, [4, 1]) + self.assertAllEqual(sparse_tensors[3].values, [13, 23, 33]) + self.assertAllEqual(sparse_tensors[4].indices, [[0, 0], [1, 0]]) + self.assertAllEqual(sparse_tensors[4].values, [4, 14]) + self.assertAllEqual(sparse_tensors[4].dense_shape, [4, 1]) + self.assertAllEqual(sparse_tensors[5].indices, [[0, 0], [2, 0], [3, 0]]) + self.assertAllEqual(sparse_tensors[5].values, [5, 25, 35]) + self.assertAllEqual(sparse_tensors[5].dense_shape, [4, 1]) - @test_util.run_deprecated_v1 def testSliceConcat(self): for sp_input in (self._SparseTensorValue_3x4x2(), self._SparseTensor_3x4x2()): - with self.cached_session(use_gpu=False): + for axis in (1, -2): sparse_tensors = sparse_ops.sparse_split( - sp_input=sp_input, num_split=2, axis=1) - concat_tensor = sparse_ops.sparse_concat(1, sparse_tensors) + sp_input=sp_input, num_split=2, axis=axis) + concat_tensor = self.evaluate( + sparse_ops.sparse_concat(1, sparse_tensors)) expected_output = self._SparseTensor_3x4x2() - self.assertAllEqual(concat_tensor.indices.eval(), - expected_output.indices.eval()) + self.assertAllEqual(concat_tensor.indices, expected_output.indices) + + def testInvalidAxis(self): + for axis in (-3, 2): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r'axis should be in range \[-2, 2\)'): + self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis)) def testArgumentErrors(self): with self.assertRaisesRegex(ValueError, 'Keyword arguments are required'): diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 5e956434342..3a145e96f19 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -983,7 +983,9 @@ def sparse_split(keyword_required=KeywordRequired(), keyword_required: Python 2 standin for * (temporary for argument reorder) sp_input: The `SparseTensor` to split. num_split: A Python integer. The number of ways to split. - axis: A 0-D `int32` `Tensor`. The dimension along which to split. + axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in + range [-rank, rank), where rank is the number of dimensions in the input + `SparseTensor`. name: A name for the operation (optional). split_dim: Deprecated old name for axis. @@ -1031,27 +1033,50 @@ def sparse_split_v2(sp_input=None, If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split` each slice starting from 0:`shape[axis] % num_split` gets extra one - dimension. For example, if `axis = 1` and `num_split = 2` and the - input is: + dimension. For example: - input_tensor = shape = [2, 7] - [ a d e ] - [b c ] + >>> indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]] + >>> values = [1, 2, 3, 4, 5] + >>> t = tf.SparseTensor(indices=indices, values=values, dense_shape=[2, 7]) + >>> tf.sparse.to_dense(t) + - Graphically the output tensors are: + >>> output = tf.sparse.split(sp_input=t, num_split=2, axis=1) + >>> tf.sparse.to_dense(output[0]) + + >>> tf.sparse.to_dense(output[1]) + - output_tensor[0] = - [ a ] - [b c ] + >>> output = tf.sparse.split(sp_input=t, num_split=2, axis=0) + >>> tf.sparse.to_dense(output[0]) + + >>> tf.sparse.to_dense(output[1]) + - output_tensor[1] = - [ d e ] - [ ] + >>> output = tf.sparse.split(sp_input=t, num_split=2, axis=-1) + >>> tf.sparse.to_dense(output[0]) + + >>> tf.sparse.to_dense(output[1]) + Args: sp_input: The `SparseTensor` to split. num_split: A Python integer. The number of ways to split. - axis: A 0-D `int32` `Tensor`. The dimension along which to split. + axis: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in + range [-rank, rank), where rank is the number of dimensions in the input + `SparseTensor`. name: A name for the operation (optional). Returns: