Update tf.sparse.reshape to support a wider variety of partially known shapes.

Specifically, if st.dense_shape == tf.constant([x, y, z])[2:], and the reshape
value is compatible, the output dense_shape can now be fully known
(in graph mode).

PiperOrigin-RevId: 263627533
This commit is contained in:
Eugene Brevdo 2019-08-15 13:17:39 -07:00 committed by TensorFlower Gardener
parent 162e9cc3e6
commit 9c6cf98672
2 changed files with 48 additions and 19 deletions

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@ -79,6 +80,18 @@ class SparseReshapeTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
sparse_ops.sparse_reshape(sp_input, shape=(-1, 7))
@test_util.run_deprecated_v1
def testPropagatesFullyKnownDenseShapeWhenShapePartiallyKnown(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_2x3x4())
self.assertAllEqual((2, 3, 4), sp_input.shape)
sp_output = sparse_ops.sparse_reshape(
sp_input, shape=array_ops.concat(
(constant_op.constant([2], dtype=dtypes.int64),
array_ops.placeholder(dtype=dtypes.int64, shape=[1])),
axis=0))
self.assertAllEqual((2, 3 * 4), sp_output.shape)
def testSameShape(self):
with self.session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6()

View File

@ -27,6 +27,7 @@ import numbers
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -785,29 +786,44 @@ def sparse_reshape(sp_input, shape, name=None):
reshaped_ind, reshaped_shape = gen_sparse_ops.sparse_reshape(
sp_input.indices, sp_input.dense_shape, shape, name=name)
reshaped_shape_const = tensor_util.constant_value(shape)
if (reshaped_shape_const is not None and
sp_input.get_shape().is_fully_defined()):
num_implied = sum((dim == -1) for dim in reshaped_shape_const)
if num_implied > 1:
raise ValueError("At most one dimension can be inferred (-1). Found: %s"
% reshaped_shape_const)
original_reshaped_shape = list(reshaped_shape_const) # Copy.
in_shape_size = np.prod(sp_input.get_shape().as_list())
if num_implied:
implied_idx = original_reshaped_shape.index(-1)
reshaped_shape_const = tensor_util.constant_value_as_shape(shape)
reshaped_shape_const = (
reshaped_shape_const.as_list() if reshaped_shape_const.ndims is not None
else None)
if (reshaped_shape_const is not None
and sp_input.shape.is_fully_defined()):
# constant_value_as_shape tends to get more information about the partial
# shape values, but here we specifically need to know if the *user* passed
# a shape with 2+ unknown dimensions; and for that constant_value
# provides either the user's direct value or None if only partial elements
# are known via the python shape inference code.
shape_const_by_user = tensor_util.constant_value(shape)
if shape_const_by_user is not None:
num_implied_by_user = sum(d == -1 for d in shape_const_by_user)
if num_implied_by_user > 1:
raise ValueError(
"At most one dimension can be inferred (-1). Found: %s"
% shape_const_by_user)
original_reshaped_shape = list(reshaped_shape_const) # A copy
in_shape_size = np.prod(sp_input.shape.as_list())
num_implied = sum(dim is None for dim in reshaped_shape_const)
if num_implied == 1:
implied_idx = original_reshaped_shape.index(None)
non_implied_idx = (
original_reshaped_shape[:implied_idx] +
original_reshaped_shape[implied_idx + 1:])
reshaped_shape_const[implied_idx] = (
reshaped_shape_const[implied_idx] = int(
in_shape_size // np.prod(non_implied_idx))
reshaped_size = np.prod(reshaped_shape_const)
if reshaped_size != in_shape_size:
raise ValueError("Cannot reshape a tensor with %d elements to shape %s "
"(%d elements)." %
(in_shape_size, original_reshaped_shape,
reshaped_size))
reshaped_shape = reshaped_shape_const
if num_implied <= 1:
reshaped_size = np.prod(reshaped_shape_const)
if reshaped_size != in_shape_size:
raise ValueError(
"Cannot reshape a tensor with %d elements to shape %s "
"(%d elements)." %
(in_shape_size, original_reshaped_shape, reshaped_size))
reshaped_shape = constant_op.constant(
reshaped_shape_const, dtype=dtypes.int64)
return sparse_tensor.SparseTensor(reshaped_ind,
array_ops.identity(sp_input.values),