sparse_reshape(): Skip shape inference in Python if implied and zero dimensions coexist

Also add unit tests for
- the consistency in reshaping non-zero-sized shapes between sparse and regular tensors.
- the proper errors thrown when implied and zero dimensions coexist.

PiperOrigin-RevId: 307532172
Change-Id: Ie3af0ec2199fe3632eef5887d7f10958211a9d89
parent 52d3040e6f
commit b5acc0bd16
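In short, `sparse_ops.sparse_reshape` now leaves the target shape untouched whenever an implied `-1` dimension appears together with a `0` dimension, deferring the size check to the C++ kernel. The rough sketch below exercises the two behaviors the new tests cover; it assumes TF 2.x eager execution and reuses the internal modules the test file itself imports (`sparse_tensor`, `sparse_ops`, `errors`), so treat it as an illustration rather than part of the commit.

    # Rough sketch (assumes TF 2.x eager execution); shapes taken from the new tests.
    import numpy as np

    from tensorflow.python.framework import errors
    from tensorflow.python.framework import sparse_tensor
    from tensorflow.python.ops import sparse_ops

    # An empty sparse tensor of dense shape [1, 0]: no indices, no values.
    empty = sparse_tensor.SparseTensorValue(
        indices=np.zeros([0, 2], dtype=np.int64),
        values=np.array([], dtype=np.float64),
        dense_shape=np.array([1, 0], dtype=np.int64))

    # Implied dim, no zero in the target: consistent with dense reshape -> [0, 2, 3].
    print(sparse_ops.sparse_reshape(empty, [-1, 2, 3]).dense_shape)

    # Implied dim together with a zero dim: ambiguous, rejected by the kernel.
    try:
      sparse_ops.sparse_reshape(empty, [-1, 0])
    except errors.InvalidArgumentError as e:
      print("rejected:", e)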
@@ -1014,6 +1014,7 @@ tf_py_test(
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:sparse_ops",
         "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
@@ -18,10 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -41,7 +43,6 @@ class SparseReshapeTest(test.TestCase):
     ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2],
                     [3, 3]]).astype(np.int64)
     val = np.array([0, 10, 13, 14, 32, 33]).astype(np.float64)
-
     shape = np.array([5, 6]).astype(np.int64)
     return sparse_tensor.SparseTensorValue(ind, val, shape)
 
@@ -329,5 +330,73 @@ class SparseReshapeTest(test.TestCase):
       self.assertAllEqual(output_val.dense_shape, new_shape)
 
 
+class EmptySparseTensorReshapeTest(test.TestCase, parameterized.TestCase):
+  """Tests for reshaping 0-sized SparseTensors, compared w/ dense tensors."""
+
+  def _MakeAndReshapeTensor(self, tensor_class, original_shape, target_shape):
+    if tensor_class == "sparse":
+      ind = np.zeros([0, len(original_shape)]).astype(np.int64)
+      val = np.array([]).astype(np.float64)
+      shape = np.array(original_shape).astype(np.int64)
+      sp_input = sparse_tensor.SparseTensorValue(ind, val, shape)
+      sp_output = self.evaluate(
+          sparse_ops.sparse_reshape(sp_input, target_shape))
+      return sp_output.dense_shape
+    else:
+      dense_input = array_ops.zeros(original_shape)
+      dense_output = self.evaluate(array_ops.reshape(dense_input, target_shape))
+      return dense_output.shape
+
+  @parameterized.named_parameters([
+      ("Dense", "dense"),
+      ("Sparse", "sparse"),
+  ])
+  def testImpliedReshapeEmpty1DTensor(self, tensor_class):
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1]), [0, 1])
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [0], [-1, 1, 2]), [0, 1, 2])
+
+  @parameterized.named_parameters([
+      ("Dense", "dense"),
+      ("Sparse", "sparse"),
+  ])
+  def testImpliedReshapeEmpty2DTensor(self, tensor_class):
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 1]), [0, 1])
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [1, 0], [-1, 2, 3]), [0, 2, 3])
+
+  @parameterized.named_parameters([
+      ("Dense", "dense"),
+      ("Sparse", "sparse"),
+  ])
+  def testImpliedReshapeEmpty3DTensor(self, tensor_class):
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [1, 0, 0], [-1, 2, 3]),
+        [0, 2, 3])
+
+  @parameterized.named_parameters([
+      ("Dense", "dense"),
+      ("Sparse", "sparse"),
+  ])
+  def testImpliedReshapeEmpty4DTensor(self, tensor_class):
+    self.assertAllEqual(
+        self._MakeAndReshapeTensor(tensor_class, [2, 4, 0, 6], [-1, 4, 6, 2]),
+        [0, 4, 6, 2])
+
+  def testImpliedDimTogetherWithZeroDimCausesError(self):
+    # NOTE: When implied dimensions and zero dimensions coexist in the target
+    # shape, the behavior currently differs between sparse and regular tensors.
+    with self.assertRaises(errors.InvalidArgumentError):
+      self._MakeAndReshapeTensor("sparse", [0], [-1, 0])
+    with self.assertRaises(errors.InvalidArgumentError):
+      self._MakeAndReshapeTensor("sparse", [1, 0], [-1, 0])
+    with self.assertRaises(errors.InvalidArgumentError):
+      self._MakeAndReshapeTensor("sparse", [1, 2, 0], [2, -1, 0])
+    with self.assertRaises(errors.InvalidArgumentError):
+      self._MakeAndReshapeTensor("sparse", [1, 2, 3, 0], [2, 0, -1, 3])
+
+
 if __name__ == "__main__":
   test.main()
@@ -860,14 +860,19 @@ def sparse_reshape(sp_input, shape, name=None):
       original_reshaped_shape = list(reshaped_shape_const)  # A copy
       in_shape_size = np.prod(sp_input.shape.as_list())
       num_implied = sum(dim is None for dim in reshaped_shape_const)
-      if num_implied == 1:
+
+      # If there is a 0 dim in the user-provided shape, we cannot infer the
+      # unknown dim reliably. This is why we skip the `if` branch below when
+      # a 0 is present in `reshaped_shape_const`. Same below.
+      if num_implied == 1 and 0 not in reshaped_shape_const:
         implied_idx = original_reshaped_shape.index(None)
         non_implied_idx = (
             original_reshaped_shape[:implied_idx] +
             original_reshaped_shape[implied_idx + 1:])
         reshaped_shape_const[implied_idx] = int(
             in_shape_size // np.prod(non_implied_idx))
-      if num_implied <= 1:
+      if num_implied == 0 or (num_implied == 1 and
+                              0 not in reshaped_shape_const):
         reshaped_size = np.prod(reshaped_shape_const)
         if reshaped_size != in_shape_size:
           raise ValueError(
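Why the Python-side inference has to be skipped can be seen with a small worked example; the sketch below is purely illustrative (its names are local to the sketch, not code from this commit): when the input has zero elements and the target shape contains both -1 and 0, every candidate value for the implied dimension satisfies the size equation, so the division above cannot pick a unique answer.

    # Illustrative only: the size constraint cannot pin down the implied dim.
    import numpy as np

    in_shape_size = np.prod([1, 0])   # an empty input has 0 elements
    target_shape = [-1, 0]            # implied dim next to a 0 dim
    non_implied = [d for d in target_shape if d != -1]

    # Any candidate for the implied dimension reproduces the 0-element total,
    # so Python cannot infer it; the op itself reports InvalidArgumentError.
    for candidate in (0, 1, 7, 42):
      assert candidate * np.prod(non_implied) == in_shape_size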