From af76ffe37f0417f886342553d8ff2f3126185f3e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Apr 2020 17:05:41 -0700 Subject: [PATCH] Fixed StructuredTensor.partition_outer_dimension. The shape output wasn't always right. PiperOrigin-RevId: 309127604 Change-Id: Ie4995f765e57a4a857ee4a902a5ae3042215aefb --- .../ops/structured/structured_tensor.py | 2 +- .../ops/structured/structured_tensor_test.py | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py index a75364df659..6234e21d8fc 100644 --- a/tensorflow/python/ops/structured/structured_tensor.py +++ b/tensorflow/python/ops/structured/structured_tensor.py @@ -1089,7 +1089,7 @@ def _partition_outer_dimension(value, row_partition): nrows = row_partition.static_nrows ncols = row_partition.static_uniform_row_length shape = tensor_shape.TensorShape([nrows, ncols]).concatenate( - value.shape[2:]) + value.shape[1:]) fields = dict((k, _partition_outer_dimension(v, row_partition)) for (k, v) in value._fields.items()) return StructuredTensor( diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index e2d6a161641..0f2ac2c83e1 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -523,6 +523,16 @@ class StructuredTensorTest(test_util.TensorFlowTestCase, "x": tensor_spec.TensorSpec([2, 2], dtypes.int32), "y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32)}))) + def testPartitionOuterDimension3(self): + rt = ragged_tensor.RaggedTensor.from_value_rowids( + array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) + struct = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2]) + struct_2 = struct.partition_outer_dimension( + row_partition.RowPartition.from_row_splits([0, 1, 2])) + struct_3 = struct_2.partition_outer_dimension( + row_partition.RowPartition.from_row_splits([0, 1, 2])) + self.assertEqual(3, struct_3.rank) + def testPartitionOuterDimsErrors(self): st = StructuredTensor.from_fields({}) partition = row_partition.RowPartition.from_row_splits([0]) @@ -889,6 +899,18 @@ class StructuredTensorTest(test_util.TensorFlowTestCase, result = st.merge_dims(outer_axis, inner_axis) self.assertAllEqual(result, expected) + def testMergeDims_0_1(self): + rt = ragged_tensor.RaggedTensor.from_value_rowids( + array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) + struct = StructuredTensor.from_fields({"r": rt}, [2]) + struct_2 = struct.partition_outer_dimension( + row_partition.RowPartition.from_row_splits([0, 1, 2])) + struct_3 = struct_2.partition_outer_dimension( + row_partition.RowPartition.from_row_splits([0, 1, 2])) + self.assertLen(struct_3.row_partitions, 2) + merged = struct_3.merge_dims(0, 1) + self.assertLen(merged.row_partitions, 1) + def testMergeDimsError(self): st = StructuredTensor.from_pyval([[[{"a": 5}]]]) with self.assertRaisesRegexp(