Do not mask errors in Dimension.{is_compatible_with,merge_with}

PiperOrigin-RevId: 276328746
Change-Id: I6d046d460332ebaee7453176202cf75221414af9
This commit is contained in:
Sergei Lebedev 2019-10-23 12:33:12 -07:00 committed by TensorFlower Gardener
parent 9e0e60617c
commit 52354aaa7f
2 changed files with 16 additions and 8 deletions

View File

@ -258,10 +258,7 @@ class Dimension(object):
Returns:
True if this Dimension and `other` are compatible.
"""
try:
other = as_dimension(other)
except (TypeError, ValueError):
return NotImplemented
other = as_dimension(other)
return (self._value is None or other.value is None or
self._value == other.value)
@ -309,10 +306,7 @@ class Dimension(object):
ValueError: If `self` and `other` are not compatible (see
is_compatible_with).
"""
try:
other = as_dimension(other)
except (TypeError, ValueError):
return NotImplemented
other = as_dimension(other)
self.assert_is_compatible_with(other)
if self._value is None:
return Dimension(other.value)

View File

@ -169,6 +169,20 @@ class DimensionTest(test_util.TensorFlowTestCase):
self.assertIsNone(tensor_shape.Dimension(None) != None) # pylint: disable=g-equals-none
self.assertNotEqual(tensor_shape.Dimension(12), 12.99)
def testIsCompatibleWithError(self):
with self.assertRaisesRegex(TypeError, "must be integer or None"):
tensor_shape.Dimension(42).is_compatible_with([])
with self.assertRaisesRegex(ValueError, "must be >= 0"):
tensor_shape.Dimension(42).is_compatible_with(-1)
def testMergeWithError(self):
with self.assertRaisesRegex(TypeError, "must be integer or None"):
tensor_shape.Dimension(42).merge_with([])
with self.assertRaisesRegex(ValueError, "must be >= 0"):
tensor_shape.Dimension(42).merge_with(-1)
def testRepr(self):
self.assertEqual(repr(tensor_shape.Dimension(7)), "Dimension(7)")
self.assertEqual(repr(tensor_shape.Dimension(None)), "Dimension(None)")