Fixed StructuredTensor.partition_outer_dimension. The shape
output wasn't always right. PiperOrigin-RevId: 309127604 Change-Id: Ie4995f765e57a4a857ee4a902a5ae3042215aefb
This commit is contained in:
parent
d8dea3ade3
commit
af76ffe37f
@ -1089,7 +1089,7 @@ def _partition_outer_dimension(value, row_partition):
|
||||
nrows = row_partition.static_nrows
|
||||
ncols = row_partition.static_uniform_row_length
|
||||
shape = tensor_shape.TensorShape([nrows, ncols]).concatenate(
|
||||
value.shape[2:])
|
||||
value.shape[1:])
|
||||
fields = dict((k, _partition_outer_dimension(v, row_partition))
|
||||
for (k, v) in value._fields.items())
|
||||
return StructuredTensor(
|
||||
|
@ -523,6 +523,16 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
|
||||
"x": tensor_spec.TensorSpec([2, 2], dtypes.int32),
|
||||
"y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32)})))
|
||||
|
||||
def testPartitionOuterDimension3(self):
|
||||
rt = ragged_tensor.RaggedTensor.from_value_rowids(
|
||||
array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
|
||||
struct = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2])
|
||||
struct_2 = struct.partition_outer_dimension(
|
||||
row_partition.RowPartition.from_row_splits([0, 1, 2]))
|
||||
struct_3 = struct_2.partition_outer_dimension(
|
||||
row_partition.RowPartition.from_row_splits([0, 1, 2]))
|
||||
self.assertEqual(3, struct_3.rank)
|
||||
|
||||
def testPartitionOuterDimsErrors(self):
|
||||
st = StructuredTensor.from_fields({})
|
||||
partition = row_partition.RowPartition.from_row_splits([0])
|
||||
@ -889,6 +899,18 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
|
||||
result = st.merge_dims(outer_axis, inner_axis)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testMergeDims_0_1(self):
|
||||
rt = ragged_tensor.RaggedTensor.from_value_rowids(
|
||||
array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
|
||||
struct = StructuredTensor.from_fields({"r": rt}, [2])
|
||||
struct_2 = struct.partition_outer_dimension(
|
||||
row_partition.RowPartition.from_row_splits([0, 1, 2]))
|
||||
struct_3 = struct_2.partition_outer_dimension(
|
||||
row_partition.RowPartition.from_row_splits([0, 1, 2]))
|
||||
self.assertLen(struct_3.row_partitions, 2)
|
||||
merged = struct_3.merge_dims(0, 1)
|
||||
self.assertLen(merged.row_partitions, 1)
|
||||
|
||||
def testMergeDimsError(self):
|
||||
st = StructuredTensor.from_pyval([[[{"a": 5}]]])
|
||||
with self.assertRaisesRegexp(
|
||||
|
Loading…
Reference in New Issue
Block a user