From b5acc0bd16619ea3ce7be9ddce63348f04d840b9 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Mon, 20 Apr 2020 20:49:41 -0700 Subject: [PATCH] sparse_reshape(): Skip shape inference in Python if implied and zero dimensions coexist Also add unit tests for - the consistency in reshaping non-zero-sized shape between sparse and regular tensros. - the proper errors thrown when implied and zero dimension coexist. PiperOrigin-RevId: 307532172 Change-Id: Ie3af0ec2199fe3632eef5887d7f10958211a9d89 --- tensorflow/python/kernel_tests/BUILD | 1 + .../kernel_tests/sparse_reshape_op_test.py | 71 ++++++++++++++++++- tensorflow/python/ops/sparse_ops.py | 9 ++- 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 7d320853e8e..a8d9946e3c4 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1014,6 +1014,7 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:sparse_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py index 56aaf4cb557..6ec51bb9735 100644 --- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -41,7 +43,6 @@ class SparseReshapeTest(test.TestCase): ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]).astype(np.int64) val = np.array([0, 10, 13, 14, 32, 33]).astype(np.float64) - shape = np.array([5, 6]).astype(np.int64) return sparse_tensor.SparseTensorValue(ind, val, shape) @@ -329,5 +330,73 @@ class SparseReshapeTest(test.TestCase): self.assertAllEqual(output_val.dense_shape, new_shape) +class EmptySparseTensorReshapeTest(test.TestCase, parameterized.TestCase): + """Tests for reshaping 0-sized SparseTensors, compared w/ dense tensors.""" + + def _MakeAndReshapeTensor(self, tensor_class, original_shape, target_shape): + if tensor_class == "sparse": + ind = np.zeros([0, len(original_shape)]).astype(np.int64) + val = np.array([]).astype(np.float64) + shape = np.array(original_shape).astype(np.int64) + sp_input = sparse_tensor.SparseTensorValue(ind, val, shape) + sp_output = self.evaluate( + sparse_ops.sparse_reshape(sp_input, target_shape)) + return sp_output.dense_shape + else: + dense_input = array_ops.zeros(original_shape) + dense_output = self.evaluate(array_ops.reshape(dense_input, target_shape)) + return dense_output.shape + + @parameterized.named_parameters([ + ("Dense", "dense"), + ("Sparse", "sparse"), + ]) + def testImpliedReshapeEmpty1DTensor(self, tensor_class): + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1]), [0, 1]) + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1, 2]), [0, 1, 2]) + + @parameterized.named_parameters([ + ("Dense", "dense"), + ("Sparse", "sparse"), + ]) + def testImpliedReshapeEmpty2DTensor(self, tensor_class): + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 1]), [0, 1]) + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 2, 3]), [0, 2, 3]) + + @parameterized.named_parameters([ + ("Dense", "dense"), + ("Sparse", "sparse"), + ]) + def testImpliedReshapeEmpty3DTensor(self, tensor_class): + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [1, 0, 0], [-1, 2, 3]), + [0, 2, 3]) + + @parameterized.named_parameters([ + ("Dense", "dense"), + ("Sparse", "sparse"), + ]) + def testImpliedReshapeEmpty4DTensor(self, tensor_class): + self.assertAllEqual( + self._MakeAndReshapeTensor(tensor_class, [2, 4, 0, 6], [-1, 4, 6, 2]), + [0, 4, 6, 2]) + + def testImpliedDimTogetherWithZeroDimCausesError(self): + # NOTE: When implied dimensions and zero dimensions coexist in the target + # shape, the behavior currently differs between sparse and regular tensors. + with self.assertRaises(errors.InvalidArgumentError): + self._MakeAndReshapeTensor("sparse", [0], [-1, 0]) + with self.assertRaises(errors.InvalidArgumentError): + self._MakeAndReshapeTensor("sparse", [1, 0], [-1, 0]) + with self.assertRaises(errors.InvalidArgumentError): + self._MakeAndReshapeTensor("sparse", [1, 2, 0], [2, -1, 0]) + with self.assertRaises(errors.InvalidArgumentError): + self._MakeAndReshapeTensor("sparse", [1, 2, 3, 0], [2, 0, -1, 3]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 5096b332364..844aa3c744c 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -860,14 +860,19 @@ def sparse_reshape(sp_input, shape, name=None): original_reshaped_shape = list(reshaped_shape_const) # A copy in_shape_size = np.prod(sp_input.shape.as_list()) num_implied = sum(dim is None for dim in reshaped_shape_const) - if num_implied == 1: + + # If there is a 0 dim in the user-provided shape, we cannot infer the + # unknown dim reliably. This is why we skip the `if` branch below when + # a 0 is present in `reshaped_shape_const`. Same below. + if num_implied == 1 and 0 not in reshaped_shape_const: implied_idx = original_reshaped_shape.index(None) non_implied_idx = ( original_reshaped_shape[:implied_idx] + original_reshaped_shape[implied_idx + 1:]) reshaped_shape_const[implied_idx] = int( in_shape_size // np.prod(non_implied_idx)) - if num_implied <= 1: + if num_implied == 0 or (num_implied == 1 and + 0 not in reshaped_shape_const): reshaped_size = np.prod(reshaped_shape_const) if reshaped_size != in_shape_size: raise ValueError(