Update tf.sparse.reshape to support a wider variety of partially known shapes.
Specifically, if st.dense_shape == tf.constant([x, y, z])[2:], and the reshape value is compatible, the output dense_shape can now be fully known (in graph mode). PiperOrigin-RevId: 263627533
This commit is contained in:
parent
162e9cc3e6
commit
9c6cf98672
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -79,6 +80,18 @@ class SparseReshapeTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
|
||||
sparse_ops.sparse_reshape(sp_input, shape=(-1, 7))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testPropagatesFullyKnownDenseShapeWhenShapePartiallyKnown(self):
|
||||
sp_input = sparse_tensor.SparseTensor.from_value(
|
||||
self._SparseTensorValue_2x3x4())
|
||||
self.assertAllEqual((2, 3, 4), sp_input.shape)
|
||||
sp_output = sparse_ops.sparse_reshape(
|
||||
sp_input, shape=array_ops.concat(
|
||||
(constant_op.constant([2], dtype=dtypes.int64),
|
||||
array_ops.placeholder(dtype=dtypes.int64, shape=[1])),
|
||||
axis=0))
|
||||
self.assertAllEqual((2, 3 * 4), sp_output.shape)
|
||||
|
||||
def testSameShape(self):
|
||||
with self.session(use_gpu=False) as sess:
|
||||
input_val = self._SparseTensorValue_5x6()
|
||||
|
@ -27,6 +27,7 @@ import numbers
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -785,29 +786,44 @@ def sparse_reshape(sp_input, shape, name=None):
|
||||
reshaped_ind, reshaped_shape = gen_sparse_ops.sparse_reshape(
|
||||
sp_input.indices, sp_input.dense_shape, shape, name=name)
|
||||
|
||||
reshaped_shape_const = tensor_util.constant_value(shape)
|
||||
if (reshaped_shape_const is not None and
|
||||
sp_input.get_shape().is_fully_defined()):
|
||||
num_implied = sum((dim == -1) for dim in reshaped_shape_const)
|
||||
if num_implied > 1:
|
||||
raise ValueError("At most one dimension can be inferred (-1). Found: %s"
|
||||
% reshaped_shape_const)
|
||||
original_reshaped_shape = list(reshaped_shape_const) # Copy.
|
||||
in_shape_size = np.prod(sp_input.get_shape().as_list())
|
||||
if num_implied:
|
||||
implied_idx = original_reshaped_shape.index(-1)
|
||||
reshaped_shape_const = tensor_util.constant_value_as_shape(shape)
|
||||
reshaped_shape_const = (
|
||||
reshaped_shape_const.as_list() if reshaped_shape_const.ndims is not None
|
||||
else None)
|
||||
|
||||
if (reshaped_shape_const is not None
|
||||
and sp_input.shape.is_fully_defined()):
|
||||
# constant_value_as_shape tends to get more information about the partial
|
||||
# shape values, but here we specifically need to know if the *user* passed
|
||||
# a shape with 2+ unknown dimensions; and for that constant_value
|
||||
# provides either the user's direct value or None if only partial elements
|
||||
# are known via the python shape inference code.
|
||||
shape_const_by_user = tensor_util.constant_value(shape)
|
||||
if shape_const_by_user is not None:
|
||||
num_implied_by_user = sum(d == -1 for d in shape_const_by_user)
|
||||
if num_implied_by_user > 1:
|
||||
raise ValueError(
|
||||
"At most one dimension can be inferred (-1). Found: %s"
|
||||
% shape_const_by_user)
|
||||
original_reshaped_shape = list(reshaped_shape_const) # A copy
|
||||
in_shape_size = np.prod(sp_input.shape.as_list())
|
||||
num_implied = sum(dim is None for dim in reshaped_shape_const)
|
||||
if num_implied == 1:
|
||||
implied_idx = original_reshaped_shape.index(None)
|
||||
non_implied_idx = (
|
||||
original_reshaped_shape[:implied_idx] +
|
||||
original_reshaped_shape[implied_idx + 1:])
|
||||
reshaped_shape_const[implied_idx] = (
|
||||
reshaped_shape_const[implied_idx] = int(
|
||||
in_shape_size // np.prod(non_implied_idx))
|
||||
reshaped_size = np.prod(reshaped_shape_const)
|
||||
if reshaped_size != in_shape_size:
|
||||
raise ValueError("Cannot reshape a tensor with %d elements to shape %s "
|
||||
"(%d elements)." %
|
||||
(in_shape_size, original_reshaped_shape,
|
||||
reshaped_size))
|
||||
reshaped_shape = reshaped_shape_const
|
||||
if num_implied <= 1:
|
||||
reshaped_size = np.prod(reshaped_shape_const)
|
||||
if reshaped_size != in_shape_size:
|
||||
raise ValueError(
|
||||
"Cannot reshape a tensor with %d elements to shape %s "
|
||||
"(%d elements)." %
|
||||
(in_shape_size, original_reshaped_shape, reshaped_size))
|
||||
reshaped_shape = constant_op.constant(
|
||||
reshaped_shape_const, dtype=dtypes.int64)
|
||||
|
||||
return sparse_tensor.SparseTensor(reshaped_ind,
|
||||
array_ops.identity(sp_input.values),
|
||||
|
Loading…
x
Reference in New Issue
Block a user