Propagate static shape information through sparse_concat_v2.

* sparse_ops.py: derive the static output shape from the inputs' static
  shapes (inp.shape) instead of constant-folding their dense_shape tensors.
  When every input shape is fully defined the result's dense_shape is an
  int64 constant; otherwise the op's dynamic output_shape is kept and the
  partial static shape is recorded on output._dense_shape_default.
* merge_test.py / sparse_ops_test.py: add static-shape tests for sparse
  Concatenate and sparse_concat_v2.
* keras/layers/BUILD: split merge_test into 4 shards.

diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index 1caff66c651..de295f1466a 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -684,6 +684,7 @@ tf_py_test(
     size = "medium",
     srcs = ["merge_test.py"],
     python_version = "PY3",
+    shard_count = 4,
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/keras",
diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py
index 6778c595648..c0a5813cce3 100644
--- a/tensorflow/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/layers/merge_test.py
@@ -373,6 +373,20 @@ class MergeLayersTestNoExecution(test.TestCase):
     mask = layer.output_mask
     self.assertListEqual(mask.shape.as_list(), [None, 4])
 
+  def test_merge_concatenate_sparse_shape(self):
+    i1 = keras.layers.Input(shape=(1,), batch_size=2, sparse=True)
+    i2 = keras.layers.Input(shape=(2,), batch_size=2, sparse=True)
+    layer = keras.layers.Concatenate(axis=1)
+    o = layer([i1, i2])
+    self.assertListEqual(o.shape.as_list(), [2, 3])
+
+    # Make sure it also respects None as the batch size.
+    i1 = keras.layers.Input(shape=(1,), sparse=True)
+    i2 = keras.layers.Input(shape=(2,), sparse=True)
+    layer = keras.layers.Concatenate(axis=1)
+    o = layer([i1, i2])
+    self.assertListEqual(o.shape.as_list(), [None, 3])
+
   def test_user_changes_to_input_structure(self):
     a = keras.layers.Input(shape=(4, 5))
     struct = [a, a]
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 501b07ded0c..4b70dbeba63 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -419,13 +419,32 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
   output_ind, output_val, output_shape = (
       gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
 
-  shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
-  if shapes_value and all(shape is not None for shape in shapes_value):
-    dim = sum(shape[axis] for shape in shapes_value)
-    output_shape = shapes_value[0]
-    output_shape[axis] = dim
-    output_shape = ops.convert_to_tensor(output_shape)
-  return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+  input_shapes = [inp.shape for inp in sp_inputs]
+  if all(shape.rank is not None for shape in input_shapes):
+    if expand_nonconcat_dims:
+      static_output_shape = []
+      for dim in range(input_shapes[0].rank):
+        static_output_shape.append(
+            max(tensor_shape.dimension_at_index(shape, dim)
+                for shape in input_shapes))
+    else:
+      static_output_shape = input_shapes[0].as_list()
+    static_output_shape[axis] = sum(
+        tensor_shape.dimension_at_index(shape, axis)
+        for shape in input_shapes)
+  else:
+    static_output_shape = tensor_shape.unknown_shape()
+  if all(shape.is_fully_defined() for shape in input_shapes):
+    output_shape = ops.convert_to_tensor(static_output_shape,
+                                         dtype=dtypes.int64)
+    return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+  else:
+    # If any input shape is only partially defined, we cannot update the
+    # output_shape tensor value. Instead we set output._dense_shape_default,
+    # which populates output.shape on a best-effort basis.
+    output = sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+    output._dense_shape_default = tensor_shape.TensorShape(static_output_shape)
+    return output
 
 
 sparse_concat_v2.__doc__ = sparse_concat.__doc__.replace(
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index 74e150ad6c7..886ba2eb9cb 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -287,6 +287,15 @@ class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
             gen_sparse_ops.SparseFillEmptyRowsGrad(
                 reverse_index_map=reverse_index_map, grad_values=grad_values))
 
+  def testSparseConcatStaticShape(self):
+    if context.executing_eagerly():
+      self.skipTest('sparse_placeholder is only available in graph context.')
+    input_a = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 1))
+    input_b = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 2))
+
+    result = sparse_ops.sparse_concat_v2(axis=1, sp_inputs=[input_a, input_b])
+    self.assertEqual(result.shape, [2, 3])
+
 
 if __name__ == '__main__':
   googletest.main()