Update RaggedTensor.merge_dims to return the ragged tensor as-is if outer_axis==inner_axis. I.e.: rt.merge_dims(x, x) == rt. (Previously, it raised an exception for outer_axis==inner_axis.)

PiperOrigin-RevId: 341433226
Change-Id: I0f02a782a9feb000ab00dd25583e0de360cd4e4b
This commit is contained in:
Edward Loper 2020-11-09 10:20:17 -08:00 committed by TensorFlower Gardener
parent 8266d57ad7
commit fb31adb04f
4 changed files with 22 additions and 22 deletions

View File

@ -164,6 +164,20 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
'inner_axis': 3,
'expected': [[[1, 2], [3, 4], [5, 6], [7, 8]], [[9, 10], [11, 12]]],
},
{
'testcase_name': 'OuterEqualsInner',
'rt': [[1], [2], [3, 4]],
'outer_axis': 0,
'inner_axis': 0,
'expected': [[1], [2], [3, 4]],
},
{
'testcase_name': 'OuterEqualsInnerWithNegativeAxis',
'rt': [[1], [2], [3, 4]],
'outer_axis': 1,
'inner_axis': -1,
'expected': [[1], [2], [3, 4]],
},
]) # pyformat: disable
def testRaggedMergeDims(self,
rt,
@ -227,33 +241,19 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
'exception': ValueError,
'message': 'inner_axis=-3 out of bounds: expected -2<=inner_axis<2',
},
{
'rt': [[1]],
'outer_axis': 0,
'inner_axis': 0,
'exception': ValueError,
'message': 'Expected outer_axis .* to be less than inner_axis .*',
},
{
'rt': [[1]],
'outer_axis': 1,
'inner_axis': 0,
'exception': ValueError,
'message': 'Expected outer_axis .* to be less than inner_axis .*',
'message': 'Expected outer_axis .* to be less than or equal to .*',
},
{
'rt': [[1]],
'outer_axis': -1,
'inner_axis': -2,
'exception': ValueError,
'message': 'Expected outer_axis .* to be less than inner_axis .*',
},
{
'rt': [[1]],
'outer_axis': 1,
'inner_axis': -1,
'exception': ValueError,
'message': 'Expected outer_axis .* to be less than inner_axis .*',
'message': 'Expected outer_axis .* to be less than or equal to .*',
},
]) # pyformat: disable
def testRaggedMergeDimsError(self,

View File

@ -1426,8 +1426,8 @@ class RaggedTensor(composite_tensor.CompositeTensor,
self.shape.rank,
axis_name="inner_axis",
ndims_name="rank(self)")
if not outer_axis < inner_axis:
raise ValueError("Expected outer_axis (%d) to be less than "
if not outer_axis <= inner_axis:
raise ValueError("Expected outer_axis (%d) to be less than or equal to "
"inner_axis (%d)" % (outer_axis, inner_axis))
return merge_dims(self, outer_axis, inner_axis)

View File

@ -944,8 +944,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
self.shape.rank,
axis_name='inner_axis',
ndims_name='rank(self)')
if not outer_axis < inner_axis:
raise ValueError('Expected outer_axis (%d) to be less than '
if not outer_axis <= inner_axis:
raise ValueError('Expected outer_axis (%d) to be less than or equal to '
'inner_axis (%d)' % (outer_axis, inner_axis))
return _merge_dims(self, outer_axis, inner_axis)

View File

@ -916,8 +916,8 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
def testMergeDimsError(self):
st = StructuredTensor.from_pyval([[[{"a": 5}]]])
with self.assertRaisesRegex(
ValueError,
r"Expected outer_axis \(2\) to be less than inner_axis \(1\)"):
ValueError, r"Expected outer_axis \(2\) to be less than "
r"or equal to inner_axis \(1\)"):
st.merge_dims(2, 1)
def testTupleFieldValue(self):