Update tf.sparse_reset_shape so that when shrinking the shape of an empty

sparse tensor, the result has a shape of all zeros.

PiperOrigin-RevId: 168419639
This commit is contained in:
A. Unique TensorFlower 2017-09-12 12:25:01 -07:00 committed by TensorFlower Gardener
parent fcacb40d4c
commit 7c19b82af4
2 changed files with 24 additions and 4 deletions

View File

@ -325,6 +325,13 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
constant_op.constant(self._VAL_2_5_6, dtypes.int32),
constant_op.constant(self._SHP_2_5_6, dtypes.int64))
def _SparseTensor_2x5x6_Empty(self):
return sparse_tensor.SparseTensor(
constant_op.constant(
np.empty(shape=[0, 3], dtype=np.int64), dtypes.int64),
constant_op.constant(np.empty(shape=[0], dtype=np.int32), dtypes.int32),
constant_op.constant(self._SHP_2_5_6, dtypes.int64))
def _SparseTensorValue_2x5x6(self):
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
self._SHP_2_5_6)
@ -387,6 +394,17 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
self.assertAllEqual(output.dense_shape, [2, 4, 5])
def testTightBoundingBoxEmpty(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6_Empty()
sp_output = sparse_ops.sparse_reset_shape(sp_input)
output = sess.run(sp_output)
self.assertAllEqual(output.indices.shape, [0, 3])
self.assertAllEqual(output.values.shape, [0])
self.assertAllEqual(output.dense_shape, [0, 0, 0])
def testInvalidRank(self):
with self.test_session(use_gpu=False):
sp_input = self._SparseTensor_2x5x6()

View File

@ -1225,7 +1225,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
If `new_shape` is None, returns a copy of `sp_input` with its shape reset
to the tight bounding box of `sp_input`.
to the tight bounding box of `sp_input`. This will be a shape consisting of
all zeros if sp_input has no values.
If `new_shape` is provided, then it must be larger or equal in all dimensions
compared to the shape of `sp_input`. When this condition is met, the returned
@ -1284,9 +1285,10 @@ def sparse_reset_shape(sp_input, new_shape=None):
in_shape = array_ops.identity(sp_input.dense_shape)
if new_shape is None:
dim_low_bound = math_ops.reduce_max(in_indices, 0)
output_shape_tensor = math_ops.add(dim_low_bound,
array_ops.ones_like(in_shape))
dim_low_bound = math_ops.reduce_max(in_indices, axis=0)
output_shape_tensor = math_ops.maximum(
array_ops.constant(0, dtype=dtypes.int64),
math_ops.add(dim_low_bound, array_ops.ones_like(in_shape)))
else:
output_shape_tensor = ops.convert_to_tensor(new_shape)
output_shape_tensor.get_shape().assert_has_rank(1)