From 7c19b82af4b31adbdea87b743d042b51fa0eeb34 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Sep 2017 12:25:01 -0700 Subject: [PATCH] Update tf.sparse_reset_shape so that when shrinking the shape of an empty sparse tensor, the result has a shape of all zeros. PiperOrigin-RevId: 168419639 --- .../python/kernel_tests/sparse_ops_test.py | 18 ++++++++++++++++++ tensorflow/python/ops/sparse_ops.py | 10 ++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index 9161b8c5d1c..1ab78a07784 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -325,6 +325,13 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): constant_op.constant(self._VAL_2_5_6, dtypes.int32), constant_op.constant(self._SHP_2_5_6, dtypes.int64)) + def _SparseTensor_2x5x6_Empty(self): + return sparse_tensor.SparseTensor( + constant_op.constant( + np.empty(shape=[0, 3], dtype=np.int64), dtypes.int64), + constant_op.constant(np.empty(shape=[0], dtype=np.int32), dtypes.int32), + constant_op.constant(self._SHP_2_5_6, dtypes.int64)) + def _SparseTensorValue_2x5x6(self): return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6, self._SHP_2_5_6) @@ -387,6 +394,17 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33]) self.assertAllEqual(output.dense_shape, [2, 4, 5]) + def testTightBoundingBoxEmpty(self): + with self.test_session(use_gpu=False) as sess: + sp_input = self._SparseTensor_2x5x6_Empty() + sp_output = sparse_ops.sparse_reset_shape(sp_input) + + output = sess.run(sp_output) + + self.assertAllEqual(output.indices.shape, [0, 3]) + self.assertAllEqual(output.values.shape, [0]) + self.assertAllEqual(output.dense_shape, [0, 0, 0]) + def testInvalidRank(self): with self.test_session(use_gpu=False): sp_input = self._SparseTensor_2x5x6() diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index e3990791c62..404041dfe14 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1225,7 +1225,8 @@ def sparse_reset_shape(sp_input, new_shape=None): """Resets the shape of a `SparseTensor` with indices and values unchanged. If `new_shape` is None, returns a copy of `sp_input` with its shape reset - to the tight bounding box of `sp_input`. + to the tight bounding box of `sp_input`. This will be a shape consisting of + all zeros if sp_input has no values. If `new_shape` is provided, then it must be larger or equal in all dimensions compared to the shape of `sp_input`. When this condition is met, the returned @@ -1284,9 +1285,10 @@ def sparse_reset_shape(sp_input, new_shape=None): in_shape = array_ops.identity(sp_input.dense_shape) if new_shape is None: - dim_low_bound = math_ops.reduce_max(in_indices, 0) - output_shape_tensor = math_ops.add(dim_low_bound, - array_ops.ones_like(in_shape)) + dim_low_bound = math_ops.reduce_max(in_indices, axis=0) + output_shape_tensor = math_ops.maximum( + array_ops.constant(0, dtype=dtypes.int64), + math_ops.add(dim_low_bound, array_ops.ones_like(in_shape))) else: output_shape_tensor = ops.convert_to_tensor(new_shape) output_shape_tensor.get_shape().assert_has_rank(1)