Merge pull request #24018 from yongtang:21964-sparse.concat-shape
PiperOrigin-RevId: 248764459
This commit is contained in:
commit
703e6c7e09
tensorflow/python
@ -342,6 +342,15 @@ class SparseConcatTest(test.TestCase):
|
||||
self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
|
||||
self.assertEqual(sp_concat.dense_shape.get_shape(), [3])
|
||||
|
||||
def testConcatShape(self):
|
||||
# Test case for GitHub 21964.
|
||||
x = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
|
||||
y = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
|
||||
z = sparse_ops.sparse_concat(-1, [x, y])
|
||||
self.assertEqual(z.get_shape().as_list(), [2, 4])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -331,6 +331,12 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
|
||||
output_ind, output_val, output_shape = (
|
||||
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
|
||||
|
||||
shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
|
||||
if shapes_value and all(shape is not None for shape in shapes_value):
|
||||
dim = sum(shape[axis] for shape in shapes_value)
|
||||
output_shape = shapes_value[0]
|
||||
output_shape[axis] = dim
|
||||
output_shape = ops.convert_to_tensor(output_shape)
|
||||
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user