Fix sparse_concat for how it handles output tensor shape.

It was using inputs.dense_shape and convert it if it is static value. It would miss the static shape information if the inputs are created via sparse_placeholder, whose static shape information is populated after creation.

Update the logic to rely inputs.shape, which should be the source of truth.

Also update the sparse_ops_test and keras merge_test as an e2e verification.

Fix https://github.com/tensorflow/tensorflow/issues/45054

PiperOrigin-RevId: 350612688
Change-Id: Id69e600f3c80207f1c8626f531676fa4340286c1
This commit is contained in:
Scott Zhu 2021-01-07 12:11:14 -08:00 committed by TensorFlower Gardener
parent 3a6b834adb
commit 8438d59abb
4 changed files with 50 additions and 7 deletions

View File

@ -684,6 +684,7 @@ tf_py_test(
size = "medium",
srcs = ["merge_test.py"],
python_version = "PY3",
shard_count = 4,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",

View File

@ -373,6 +373,20 @@ class MergeLayersTestNoExecution(test.TestCase):
mask = layer.output_mask
self.assertListEqual(mask.shape.as_list(), [None, 4])
def test_merge_concatenate_sparse_shape(self):
i1 = keras.layers.Input(shape=(1,), batch_size=2, sparse=True)
i2 = keras.layers.Input(shape=(2,), batch_size=2, sparse=True)
layer = keras.layers.Concatenate(axis=1)
o = layer([i1, i2])
self.assertListEqual(o.shape.as_list(), [2, 3])
# Make sure it also respect None as the batch size
i1 = keras.layers.Input(shape=(1,), sparse=True)
i2 = keras.layers.Input(shape=(2,), sparse=True)
layer = keras.layers.Concatenate(axis=1)
o = layer([i1, i2])
self.assertListEqual(o.shape.as_list(), [None, 3])
def test_user_changes_to_input_structure(self):
a = keras.layers.Input(shape=(4, 5))
struct = [a, a]

View File

@ -419,13 +419,32 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
output_ind, output_val, output_shape = (
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
if shapes_value and all(shape is not None for shape in shapes_value):
dim = sum(shape[axis] for shape in shapes_value)
output_shape = shapes_value[0]
output_shape[axis] = dim
output_shape = ops.convert_to_tensor(output_shape)
input_shapes = [inp.shape for inp in sp_inputs]
if all(shape.rank is not None for shape in input_shapes):
if expand_nonconcat_dims:
static_output_shape = []
for dim in range(input_shapes[0].rank):
static_output_shape.append(
max(tensor_shape.dimension_at_index(shape, dim)
for shape in input_shapes))
else:
static_output_shape = input_shapes[0].as_list()
static_output_shape[axis] = sum(
tensor_shape.dimension_at_index(shape, axis)
for shape in input_shapes)
else:
static_output_shape = tensor_shape.unknown_shape()
if all(shape.is_fully_defined() for shape in input_shapes):
output_shape = ops.convert_to_tensor(static_output_shape,
dtype=dtypes.int64)
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
else:
# In case there are partially defined shape, we couldn't update the
# output_shape tensor value. We update the output._dense_shape_default,
# which populate output.shape as the best effort.
output = sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
output._dense_shape_default = tensor_shape.TensorShape(static_output_shape)
return output
sparse_concat_v2.__doc__ = sparse_concat.__doc__.replace(

View File

@ -287,6 +287,15 @@ class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
def testSparseConcatStaticShape(self):
if context.executing_eagerly():
self.skipTest('sparse_spaceholder is only available in graph context.')
input_a = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 1))
input_b = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 2))
result = sparse_ops.sparse_concat_v2(axis=1, sp_inputs=[input_a, input_b])
self.assertEqual(result.shape, [2, 3])
if __name__ == '__main__':
googletest.main()