Merge pull request from yongtang:21964-sparse.concat-shape

PiperOrigin-RevId: 248764459
This commit is contained in:
TensorFlower Gardener 2019-05-17 12:51:38 -07:00
commit 703e6c7e09
2 changed files with 15 additions and 0 deletions
tensorflow/python

View File

@ -342,6 +342,15 @@ class SparseConcatTest(test.TestCase):
self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
self.assertEqual(sp_concat.dense_shape.get_shape(), [3])
def testConcatShape(self):
# Test case for GitHub 21964.
x = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
y = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
z = sparse_ops.sparse_concat(-1, [x, y])
self.assertEqual(z.get_shape().as_list(), [2, 4])
if __name__ == "__main__":
test.main()

View File

@ -331,6 +331,12 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
output_ind, output_val, output_shape = (
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
if shapes_value and all(shape is not None for shape in shapes_value):
dim = sum(shape[axis] for shape in shapes_value)
output_shape = shapes_value[0]
output_shape[axis] = dim
output_shape = ops.convert_to_tensor(output_shape)
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)