sparse_reshape(): Skip shape inference in Python if implied and zero dimensions coexist

Also add unit tests for
- the consistency in reshaping non-zero-sized shape between sparse and regular
  tensros.
- the proper errors thrown when implied and zero dimension coexist.

PiperOrigin-RevId: 307532172
Change-Id: Ie3af0ec2199fe3632eef5887d7f10958211a9d89
This commit is contained in:
Shanqing Cai 2020-04-20 20:49:41 -07:00 committed by TensorFlower Gardener
parent 52d3040e6f
commit b5acc0bd16
3 changed files with 78 additions and 3 deletions

View File

@ -1014,6 +1014,7 @@ tf_py_test(
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_ops",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
], ],
) )

View File

@ -18,10 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -41,7 +43,6 @@ class SparseReshapeTest(test.TestCase):
ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2],
[3, 3]]).astype(np.int64) [3, 3]]).astype(np.int64)
val = np.array([0, 10, 13, 14, 32, 33]).astype(np.float64) val = np.array([0, 10, 13, 14, 32, 33]).astype(np.float64)
shape = np.array([5, 6]).astype(np.int64) shape = np.array([5, 6]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape) return sparse_tensor.SparseTensorValue(ind, val, shape)
@ -329,5 +330,73 @@ class SparseReshapeTest(test.TestCase):
self.assertAllEqual(output_val.dense_shape, new_shape) self.assertAllEqual(output_val.dense_shape, new_shape)
class EmptySparseTensorReshapeTest(test.TestCase, parameterized.TestCase):
"""Tests for reshaping 0-sized SparseTensors, compared w/ dense tensors."""
def _MakeAndReshapeTensor(self, tensor_class, original_shape, target_shape):
if tensor_class == "sparse":
ind = np.zeros([0, len(original_shape)]).astype(np.int64)
val = np.array([]).astype(np.float64)
shape = np.array(original_shape).astype(np.int64)
sp_input = sparse_tensor.SparseTensorValue(ind, val, shape)
sp_output = self.evaluate(
sparse_ops.sparse_reshape(sp_input, target_shape))
return sp_output.dense_shape
else:
dense_input = array_ops.zeros(original_shape)
dense_output = self.evaluate(array_ops.reshape(dense_input, target_shape))
return dense_output.shape
@parameterized.named_parameters([
("Dense", "dense"),
("Sparse", "sparse"),
])
def testImpliedReshapeEmpty1DTensor(self, tensor_class):
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1]), [0, 1])
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1, 2]), [0, 1, 2])
@parameterized.named_parameters([
("Dense", "dense"),
("Sparse", "sparse"),
])
def testImpliedReshapeEmpty2DTensor(self, tensor_class):
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 1]), [0, 1])
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 2, 3]), [0, 2, 3])
@parameterized.named_parameters([
("Dense", "dense"),
("Sparse", "sparse"),
])
def testImpliedReshapeEmpty3DTensor(self, tensor_class):
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [1, 0, 0], [-1, 2, 3]),
[0, 2, 3])
@parameterized.named_parameters([
("Dense", "dense"),
("Sparse", "sparse"),
])
def testImpliedReshapeEmpty4DTensor(self, tensor_class):
self.assertAllEqual(
self._MakeAndReshapeTensor(tensor_class, [2, 4, 0, 6], [-1, 4, 6, 2]),
[0, 4, 6, 2])
def testImpliedDimTogetherWithZeroDimCausesError(self):
# NOTE: When implied dimensions and zero dimensions coexist in the target
# shape, the behavior currently differs between sparse and regular tensors.
with self.assertRaises(errors.InvalidArgumentError):
self._MakeAndReshapeTensor("sparse", [0], [-1, 0])
with self.assertRaises(errors.InvalidArgumentError):
self._MakeAndReshapeTensor("sparse", [1, 0], [-1, 0])
with self.assertRaises(errors.InvalidArgumentError):
self._MakeAndReshapeTensor("sparse", [1, 2, 0], [2, -1, 0])
with self.assertRaises(errors.InvalidArgumentError):
self._MakeAndReshapeTensor("sparse", [1, 2, 3, 0], [2, 0, -1, 3])
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -860,14 +860,19 @@ def sparse_reshape(sp_input, shape, name=None):
original_reshaped_shape = list(reshaped_shape_const) # A copy original_reshaped_shape = list(reshaped_shape_const) # A copy
in_shape_size = np.prod(sp_input.shape.as_list()) in_shape_size = np.prod(sp_input.shape.as_list())
num_implied = sum(dim is None for dim in reshaped_shape_const) num_implied = sum(dim is None for dim in reshaped_shape_const)
if num_implied == 1:
# If there is a 0 dim in the user-provided shape, we cannot infer the
# unknown dim reliably. This is why we skip the `if` branch below when
# a 0 is present in `reshaped_shape_const`. Same below.
if num_implied == 1 and 0 not in reshaped_shape_const:
implied_idx = original_reshaped_shape.index(None) implied_idx = original_reshaped_shape.index(None)
non_implied_idx = ( non_implied_idx = (
original_reshaped_shape[:implied_idx] + original_reshaped_shape[:implied_idx] +
original_reshaped_shape[implied_idx + 1:]) original_reshaped_shape[implied_idx + 1:])
reshaped_shape_const[implied_idx] = int( reshaped_shape_const[implied_idx] = int(
in_shape_size // np.prod(non_implied_idx)) in_shape_size // np.prod(non_implied_idx))
if num_implied <= 1: if num_implied == 0 or (num_implied == 1 and
0 not in reshaped_shape_const):
reshaped_size = np.prod(reshaped_shape_const) reshaped_size = np.prod(reshaped_shape_const)
if reshaped_size != in_shape_size: if reshaped_size != in_shape_size:
raise ValueError( raise ValueError(