Merge pull request #24018 from yongtang:21964-sparse.concat-shape
PiperOrigin-RevId: 248764459
This commit is contained in:
commit
703e6c7e09
tensorflow/python
@ -342,6 +342,15 @@ class SparseConcatTest(test.TestCase):
|
|||||||
self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
|
self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
|
||||||
self.assertEqual(sp_concat.dense_shape.get_shape(), [3])
|
self.assertEqual(sp_concat.dense_shape.get_shape(), [3])
|
||||||
|
|
||||||
|
def testConcatShape(self):
|
||||||
|
# Test case for GitHub 21964.
|
||||||
|
x = sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
|
||||||
|
y = sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
|
||||||
|
z = sparse_ops.sparse_concat(-1, [x, y])
|
||||||
|
self.assertEqual(z.get_shape().as_list(), [2, 4])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -331,6 +331,12 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
|
|||||||
output_ind, output_val, output_shape = (
|
output_ind, output_val, output_shape = (
|
||||||
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
|
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
|
||||||
|
|
||||||
|
shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
|
||||||
|
if shapes_value and all(shape is not None for shape in shapes_value):
|
||||||
|
dim = sum(shape[axis] for shape in shapes_value)
|
||||||
|
output_shape = shapes_value[0]
|
||||||
|
output_shape[axis] = dim
|
||||||
|
output_shape = ops.convert_to_tensor(output_shape)
|
||||||
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user