sparse_reshape(): Skip shape inference in Python if implied and zero dimensions coexist
Also add unit tests for - the consistency in reshaping non-zero-sized shape between sparse and regular tensros. - the proper errors thrown when implied and zero dimension coexist. PiperOrigin-RevId: 307532172 Change-Id: Ie3af0ec2199fe3632eef5887d7f10958211a9d89
This commit is contained in:
parent
52d3040e6f
commit
b5acc0bd16
@ -1014,6 +1014,7 @@ tf_py_test(
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,10 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -41,7 +43,6 @@ class SparseReshapeTest(test.TestCase):
|
||||
ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2],
|
||||
[3, 3]]).astype(np.int64)
|
||||
val = np.array([0, 10, 13, 14, 32, 33]).astype(np.float64)
|
||||
|
||||
shape = np.array([5, 6]).astype(np.int64)
|
||||
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||
|
||||
@ -329,5 +330,73 @@ class SparseReshapeTest(test.TestCase):
|
||||
self.assertAllEqual(output_val.dense_shape, new_shape)
|
||||
|
||||
|
||||
class EmptySparseTensorReshapeTest(test.TestCase, parameterized.TestCase):
|
||||
"""Tests for reshaping 0-sized SparseTensors, compared w/ dense tensors."""
|
||||
|
||||
def _MakeAndReshapeTensor(self, tensor_class, original_shape, target_shape):
|
||||
if tensor_class == "sparse":
|
||||
ind = np.zeros([0, len(original_shape)]).astype(np.int64)
|
||||
val = np.array([]).astype(np.float64)
|
||||
shape = np.array(original_shape).astype(np.int64)
|
||||
sp_input = sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||
sp_output = self.evaluate(
|
||||
sparse_ops.sparse_reshape(sp_input, target_shape))
|
||||
return sp_output.dense_shape
|
||||
else:
|
||||
dense_input = array_ops.zeros(original_shape)
|
||||
dense_output = self.evaluate(array_ops.reshape(dense_input, target_shape))
|
||||
return dense_output.shape
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Dense", "dense"),
|
||||
("Sparse", "sparse"),
|
||||
])
|
||||
def testImpliedReshapeEmpty1DTensor(self, tensor_class):
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1]), [0, 1])
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1, 2]), [0, 1, 2])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Dense", "dense"),
|
||||
("Sparse", "sparse"),
|
||||
])
|
||||
def testImpliedReshapeEmpty2DTensor(self, tensor_class):
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 1]), [0, 1])
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 2, 3]), [0, 2, 3])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Dense", "dense"),
|
||||
("Sparse", "sparse"),
|
||||
])
|
||||
def testImpliedReshapeEmpty3DTensor(self, tensor_class):
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [1, 0, 0], [-1, 2, 3]),
|
||||
[0, 2, 3])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("Dense", "dense"),
|
||||
("Sparse", "sparse"),
|
||||
])
|
||||
def testImpliedReshapeEmpty4DTensor(self, tensor_class):
|
||||
self.assertAllEqual(
|
||||
self._MakeAndReshapeTensor(tensor_class, [2, 4, 0, 6], [-1, 4, 6, 2]),
|
||||
[0, 4, 6, 2])
|
||||
|
||||
def testImpliedDimTogetherWithZeroDimCausesError(self):
|
||||
# NOTE: When implied dimensions and zero dimensions coexist in the target
|
||||
# shape, the behavior currently differs between sparse and regular tensors.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self._MakeAndReshapeTensor("sparse", [0], [-1, 0])
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self._MakeAndReshapeTensor("sparse", [1, 0], [-1, 0])
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self._MakeAndReshapeTensor("sparse", [1, 2, 0], [2, -1, 0])
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self._MakeAndReshapeTensor("sparse", [1, 2, 3, 0], [2, 0, -1, 3])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -860,14 +860,19 @@ def sparse_reshape(sp_input, shape, name=None):
|
||||
original_reshaped_shape = list(reshaped_shape_const) # A copy
|
||||
in_shape_size = np.prod(sp_input.shape.as_list())
|
||||
num_implied = sum(dim is None for dim in reshaped_shape_const)
|
||||
if num_implied == 1:
|
||||
|
||||
# If there is a 0 dim in the user-provided shape, we cannot infer the
|
||||
# unknown dim reliably. This is why we skip the `if` branch below when
|
||||
# a 0 is present in `reshaped_shape_const`. Same below.
|
||||
if num_implied == 1 and 0 not in reshaped_shape_const:
|
||||
implied_idx = original_reshaped_shape.index(None)
|
||||
non_implied_idx = (
|
||||
original_reshaped_shape[:implied_idx] +
|
||||
original_reshaped_shape[implied_idx + 1:])
|
||||
reshaped_shape_const[implied_idx] = int(
|
||||
in_shape_size // np.prod(non_implied_idx))
|
||||
if num_implied <= 1:
|
||||
if num_implied == 0 or (num_implied == 1 and
|
||||
0 not in reshaped_shape_const):
|
||||
reshaped_size = np.prod(reshaped_shape_const)
|
||||
if reshaped_size != in_shape_size:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user