Fix sparse_concat for how it handles output tensor shape.
It was using inputs.dense_shape and convert it if it is static value. It would miss the static shape information if the inputs are created via sparse_placeholder, whose static shape information is populated after creation. Update the logic to rely inputs.shape, which should be the source of truth. Also update the sparse_ops_test and keras merge_test as an e2e verification. Fix https://github.com/tensorflow/tensorflow/issues/45054 PiperOrigin-RevId: 350612688 Change-Id: Id69e600f3c80207f1c8626f531676fa4340286c1
This commit is contained in:
parent
3a6b834adb
commit
8438d59abb
@ -684,6 +684,7 @@ tf_py_test(
|
||||
size = "medium",
|
||||
srcs = ["merge_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
|
@ -373,6 +373,20 @@ class MergeLayersTestNoExecution(test.TestCase):
|
||||
mask = layer.output_mask
|
||||
self.assertListEqual(mask.shape.as_list(), [None, 4])
|
||||
|
||||
def test_merge_concatenate_sparse_shape(self):
|
||||
i1 = keras.layers.Input(shape=(1,), batch_size=2, sparse=True)
|
||||
i2 = keras.layers.Input(shape=(2,), batch_size=2, sparse=True)
|
||||
layer = keras.layers.Concatenate(axis=1)
|
||||
o = layer([i1, i2])
|
||||
self.assertListEqual(o.shape.as_list(), [2, 3])
|
||||
|
||||
# Make sure it also respect None as the batch size
|
||||
i1 = keras.layers.Input(shape=(1,), sparse=True)
|
||||
i2 = keras.layers.Input(shape=(2,), sparse=True)
|
||||
layer = keras.layers.Concatenate(axis=1)
|
||||
o = layer([i1, i2])
|
||||
self.assertListEqual(o.shape.as_list(), [None, 3])
|
||||
|
||||
def test_user_changes_to_input_structure(self):
|
||||
a = keras.layers.Input(shape=(4, 5))
|
||||
struct = [a, a]
|
||||
|
@ -419,13 +419,32 @@ def sparse_concat_v2(axis, sp_inputs, expand_nonconcat_dims=False, name=None):
|
||||
output_ind, output_val, output_shape = (
|
||||
gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name))
|
||||
|
||||
shapes_value = [tensor_util.constant_value(shape) for shape in shapes]
|
||||
if shapes_value and all(shape is not None for shape in shapes_value):
|
||||
dim = sum(shape[axis] for shape in shapes_value)
|
||||
output_shape = shapes_value[0]
|
||||
output_shape[axis] = dim
|
||||
output_shape = ops.convert_to_tensor(output_shape)
|
||||
input_shapes = [inp.shape for inp in sp_inputs]
|
||||
if all(shape.rank is not None for shape in input_shapes):
|
||||
if expand_nonconcat_dims:
|
||||
static_output_shape = []
|
||||
for dim in range(input_shapes[0].rank):
|
||||
static_output_shape.append(
|
||||
max(tensor_shape.dimension_at_index(shape, dim)
|
||||
for shape in input_shapes))
|
||||
else:
|
||||
static_output_shape = input_shapes[0].as_list()
|
||||
static_output_shape[axis] = sum(
|
||||
tensor_shape.dimension_at_index(shape, axis)
|
||||
for shape in input_shapes)
|
||||
else:
|
||||
static_output_shape = tensor_shape.unknown_shape()
|
||||
if all(shape.is_fully_defined() for shape in input_shapes):
|
||||
output_shape = ops.convert_to_tensor(static_output_shape,
|
||||
dtype=dtypes.int64)
|
||||
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||
else:
|
||||
# In case there are partially defined shape, we couldn't update the
|
||||
# output_shape tensor value. We update the output._dense_shape_default,
|
||||
# which populate output.shape as the best effort.
|
||||
output = sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||
output._dense_shape_default = tensor_shape.TensorShape(static_output_shape)
|
||||
return output
|
||||
|
||||
|
||||
sparse_concat_v2.__doc__ = sparse_concat.__doc__.replace(
|
||||
|
@ -287,6 +287,15 @@ class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||
|
||||
def testSparseConcatStaticShape(self):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest('sparse_spaceholder is only available in graph context.')
|
||||
input_a = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 1))
|
||||
input_b = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 2))
|
||||
|
||||
result = sparse_ops.sparse_concat_v2(axis=1, sp_inputs=[input_a, input_b])
|
||||
self.assertEqual(result.shape, [2, 3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user