Update tf.sparse_reset_shape so that when shrinking the shape of an empty
sparse tensor, the result has a shape of all zeros. PiperOrigin-RevId: 168419639
This commit is contained in:
parent
fcacb40d4c
commit
7c19b82af4
@ -325,6 +325,13 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
|||||||
constant_op.constant(self._VAL_2_5_6, dtypes.int32),
|
constant_op.constant(self._VAL_2_5_6, dtypes.int32),
|
||||||
constant_op.constant(self._SHP_2_5_6, dtypes.int64))
|
constant_op.constant(self._SHP_2_5_6, dtypes.int64))
|
||||||
|
|
||||||
|
def _SparseTensor_2x5x6_Empty(self):
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
constant_op.constant(
|
||||||
|
np.empty(shape=[0, 3], dtype=np.int64), dtypes.int64),
|
||||||
|
constant_op.constant(np.empty(shape=[0], dtype=np.int32), dtypes.int32),
|
||||||
|
constant_op.constant(self._SHP_2_5_6, dtypes.int64))
|
||||||
|
|
||||||
def _SparseTensorValue_2x5x6(self):
|
def _SparseTensorValue_2x5x6(self):
|
||||||
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
|
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
|
||||||
self._SHP_2_5_6)
|
self._SHP_2_5_6)
|
||||||
@ -387,6 +394,17 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
|
self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
|
||||||
self.assertAllEqual(output.dense_shape, [2, 4, 5])
|
self.assertAllEqual(output.dense_shape, [2, 4, 5])
|
||||||
|
|
||||||
|
def testTightBoundingBoxEmpty(self):
|
||||||
|
with self.test_session(use_gpu=False) as sess:
|
||||||
|
sp_input = self._SparseTensor_2x5x6_Empty()
|
||||||
|
sp_output = sparse_ops.sparse_reset_shape(sp_input)
|
||||||
|
|
||||||
|
output = sess.run(sp_output)
|
||||||
|
|
||||||
|
self.assertAllEqual(output.indices.shape, [0, 3])
|
||||||
|
self.assertAllEqual(output.values.shape, [0])
|
||||||
|
self.assertAllEqual(output.dense_shape, [0, 0, 0])
|
||||||
|
|
||||||
def testInvalidRank(self):
|
def testInvalidRank(self):
|
||||||
with self.test_session(use_gpu=False):
|
with self.test_session(use_gpu=False):
|
||||||
sp_input = self._SparseTensor_2x5x6()
|
sp_input = self._SparseTensor_2x5x6()
|
||||||
|
@ -1225,7 +1225,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
|||||||
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
|
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
|
||||||
|
|
||||||
If `new_shape` is None, returns a copy of `sp_input` with its shape reset
|
If `new_shape` is None, returns a copy of `sp_input` with its shape reset
|
||||||
to the tight bounding box of `sp_input`.
|
to the tight bounding box of `sp_input`. This will be a shape consisting of
|
||||||
|
all zeros if sp_input has no values.
|
||||||
|
|
||||||
If `new_shape` is provided, then it must be larger or equal in all dimensions
|
If `new_shape` is provided, then it must be larger or equal in all dimensions
|
||||||
compared to the shape of `sp_input`. When this condition is met, the returned
|
compared to the shape of `sp_input`. When this condition is met, the returned
|
||||||
@ -1284,9 +1285,10 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
|||||||
in_shape = array_ops.identity(sp_input.dense_shape)
|
in_shape = array_ops.identity(sp_input.dense_shape)
|
||||||
|
|
||||||
if new_shape is None:
|
if new_shape is None:
|
||||||
dim_low_bound = math_ops.reduce_max(in_indices, 0)
|
dim_low_bound = math_ops.reduce_max(in_indices, axis=0)
|
||||||
output_shape_tensor = math_ops.add(dim_low_bound,
|
output_shape_tensor = math_ops.maximum(
|
||||||
array_ops.ones_like(in_shape))
|
array_ops.constant(0, dtype=dtypes.int64),
|
||||||
|
math_ops.add(dim_low_bound, array_ops.ones_like(in_shape)))
|
||||||
else:
|
else:
|
||||||
output_shape_tensor = ops.convert_to_tensor(new_shape)
|
output_shape_tensor = ops.convert_to_tensor(new_shape)
|
||||||
output_shape_tensor.get_shape().assert_has_rank(1)
|
output_shape_tensor.get_shape().assert_has_rank(1)
|
||||||
|
Loading…
Reference in New Issue
Block a user