Fixed StructuredTensor.partition_outer_dimension. The shape

output wasn't always right.

PiperOrigin-RevId: 309127604
Change-Id: Ie4995f765e57a4a857ee4a902a5ae3042215aefb
This commit is contained in:
A. Unique TensorFlower 2020-04-29 17:05:41 -07:00 committed by TensorFlower Gardener
parent d8dea3ade3
commit af76ffe37f
2 changed files with 23 additions and 1 deletions

View File

@ -1089,7 +1089,7 @@ def _partition_outer_dimension(value, row_partition):
nrows = row_partition.static_nrows
ncols = row_partition.static_uniform_row_length
shape = tensor_shape.TensorShape([nrows, ncols]).concatenate(
value.shape[2:])
value.shape[1:])
fields = dict((k, _partition_outer_dimension(v, row_partition))
for (k, v) in value._fields.items())
return StructuredTensor(

View File

@ -523,6 +523,16 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
"x": tensor_spec.TensorSpec([2, 2], dtypes.int32),
"y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32)})))
def testPartitionOuterDimension3(self):
rt = ragged_tensor.RaggedTensor.from_value_rowids(
array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
struct = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2])
struct_2 = struct.partition_outer_dimension(
row_partition.RowPartition.from_row_splits([0, 1, 2]))
struct_3 = struct_2.partition_outer_dimension(
row_partition.RowPartition.from_row_splits([0, 1, 2]))
self.assertEqual(3, struct_3.rank)
def testPartitionOuterDimsErrors(self):
st = StructuredTensor.from_fields({})
partition = row_partition.RowPartition.from_row_splits([0])
@ -889,6 +899,18 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
result = st.merge_dims(outer_axis, inner_axis)
self.assertAllEqual(result, expected)
def testMergeDims_0_1(self):
rt = ragged_tensor.RaggedTensor.from_value_rowids(
array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
struct = StructuredTensor.from_fields({"r": rt}, [2])
struct_2 = struct.partition_outer_dimension(
row_partition.RowPartition.from_row_splits([0, 1, 2]))
struct_3 = struct_2.partition_outer_dimension(
row_partition.RowPartition.from_row_splits([0, 1, 2]))
self.assertLen(struct_3.row_partitions, 2)
merged = struct_3.merge_dims(0, 1)
self.assertLen(merged.row_partitions, 1)
def testMergeDimsError(self):
st = StructuredTensor.from_pyval([[[{"a": 5}]]])
with self.assertRaisesRegexp(