Update RaggedTensor.merge_dims to return the ragged tensor as-is if outer_axis==inner_axis. I.e.: rt.merge_dims(x, x) == rt
. (Previously, it raised an exception for outer_axis==inner_axis.)
PiperOrigin-RevId: 341433226 Change-Id: I0f02a782a9feb000ab00dd25583e0de360cd4e4b
This commit is contained in:
parent
8266d57ad7
commit
fb31adb04f
@ -164,6 +164,20 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
|
||||
'inner_axis': 3,
|
||||
'expected': [[[1, 2], [3, 4], [5, 6], [7, 8]], [[9, 10], [11, 12]]],
|
||||
},
|
||||
{
|
||||
'testcase_name': 'OuterEqualsInner',
|
||||
'rt': [[1], [2], [3, 4]],
|
||||
'outer_axis': 0,
|
||||
'inner_axis': 0,
|
||||
'expected': [[1], [2], [3, 4]],
|
||||
},
|
||||
{
|
||||
'testcase_name': 'OuterEqualsInnerWithNegativeAxis',
|
||||
'rt': [[1], [2], [3, 4]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': -1,
|
||||
'expected': [[1], [2], [3, 4]],
|
||||
},
|
||||
]) # pyformat: disable
|
||||
def testRaggedMergeDims(self,
|
||||
rt,
|
||||
@ -227,33 +241,19 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
|
||||
'exception': ValueError,
|
||||
'message': 'inner_axis=-3 out of bounds: expected -2<=inner_axis<2',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 0,
|
||||
'inner_axis': 0,
|
||||
'exception': ValueError,
|
||||
'message': 'Expected outer_axis .* to be less than inner_axis .*',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': 0,
|
||||
'exception': ValueError,
|
||||
'message': 'Expected outer_axis .* to be less than inner_axis .*',
|
||||
'message': 'Expected outer_axis .* to be less than or equal to .*',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': -1,
|
||||
'inner_axis': -2,
|
||||
'exception': ValueError,
|
||||
'message': 'Expected outer_axis .* to be less than inner_axis .*',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': -1,
|
||||
'exception': ValueError,
|
||||
'message': 'Expected outer_axis .* to be less than inner_axis .*',
|
||||
'message': 'Expected outer_axis .* to be less than or equal to .*',
|
||||
},
|
||||
]) # pyformat: disable
|
||||
def testRaggedMergeDimsError(self,
|
||||
|
@ -1426,8 +1426,8 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
self.shape.rank,
|
||||
axis_name="inner_axis",
|
||||
ndims_name="rank(self)")
|
||||
if not outer_axis < inner_axis:
|
||||
raise ValueError("Expected outer_axis (%d) to be less than "
|
||||
if not outer_axis <= inner_axis:
|
||||
raise ValueError("Expected outer_axis (%d) to be less than or equal to "
|
||||
"inner_axis (%d)" % (outer_axis, inner_axis))
|
||||
return merge_dims(self, outer_axis, inner_axis)
|
||||
|
||||
|
@ -944,8 +944,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
self.shape.rank,
|
||||
axis_name='inner_axis',
|
||||
ndims_name='rank(self)')
|
||||
if not outer_axis < inner_axis:
|
||||
raise ValueError('Expected outer_axis (%d) to be less than '
|
||||
if not outer_axis <= inner_axis:
|
||||
raise ValueError('Expected outer_axis (%d) to be less than or equal to '
|
||||
'inner_axis (%d)' % (outer_axis, inner_axis))
|
||||
return _merge_dims(self, outer_axis, inner_axis)
|
||||
|
||||
|
@ -916,8 +916,8 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
|
||||
def testMergeDimsError(self):
|
||||
st = StructuredTensor.from_pyval([[[{"a": 5}]]])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Expected outer_axis \(2\) to be less than inner_axis \(1\)"):
|
||||
ValueError, r"Expected outer_axis \(2\) to be less than "
|
||||
r"or equal to inner_axis \(1\)"):
|
||||
st.merge_dims(2, 1)
|
||||
|
||||
def testTupleFieldValue(self):
|
||||
|
Loading…
Reference in New Issue
Block a user