Do not mask errors in Dimension.{is_compatible_with,merge_with}
PiperOrigin-RevId: 276328746 Change-Id: I6d046d460332ebaee7453176202cf75221414af9
This commit is contained in:
parent
9e0e60617c
commit
52354aaa7f
@ -258,10 +258,7 @@ class Dimension(object):
|
|||||||
Returns:
|
Returns:
|
||||||
True if this Dimension and `other` are compatible.
|
True if this Dimension and `other` are compatible.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
other = as_dimension(other)
|
other = as_dimension(other)
|
||||||
except (TypeError, ValueError):
|
|
||||||
return NotImplemented
|
|
||||||
return (self._value is None or other.value is None or
|
return (self._value is None or other.value is None or
|
||||||
self._value == other.value)
|
self._value == other.value)
|
||||||
|
|
||||||
@ -309,10 +306,7 @@ class Dimension(object):
|
|||||||
ValueError: If `self` and `other` are not compatible (see
|
ValueError: If `self` and `other` are not compatible (see
|
||||||
is_compatible_with).
|
is_compatible_with).
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
other = as_dimension(other)
|
other = as_dimension(other)
|
||||||
except (TypeError, ValueError):
|
|
||||||
return NotImplemented
|
|
||||||
self.assert_is_compatible_with(other)
|
self.assert_is_compatible_with(other)
|
||||||
if self._value is None:
|
if self._value is None:
|
||||||
return Dimension(other.value)
|
return Dimension(other.value)
|
||||||
|
@ -169,6 +169,20 @@ class DimensionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsNone(tensor_shape.Dimension(None) != None) # pylint: disable=g-equals-none
|
self.assertIsNone(tensor_shape.Dimension(None) != None) # pylint: disable=g-equals-none
|
||||||
self.assertNotEqual(tensor_shape.Dimension(12), 12.99)
|
self.assertNotEqual(tensor_shape.Dimension(12), 12.99)
|
||||||
|
|
||||||
|
def testIsCompatibleWithError(self):
|
||||||
|
with self.assertRaisesRegex(TypeError, "must be integer or None"):
|
||||||
|
tensor_shape.Dimension(42).is_compatible_with([])
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "must be >= 0"):
|
||||||
|
tensor_shape.Dimension(42).is_compatible_with(-1)
|
||||||
|
|
||||||
|
def testMergeWithError(self):
|
||||||
|
with self.assertRaisesRegex(TypeError, "must be integer or None"):
|
||||||
|
tensor_shape.Dimension(42).merge_with([])
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "must be >= 0"):
|
||||||
|
tensor_shape.Dimension(42).merge_with(-1)
|
||||||
|
|
||||||
def testRepr(self):
|
def testRepr(self):
|
||||||
self.assertEqual(repr(tensor_shape.Dimension(7)), "Dimension(7)")
|
self.assertEqual(repr(tensor_shape.Dimension(7)), "Dimension(7)")
|
||||||
self.assertEqual(repr(tensor_shape.Dimension(None)), "Dimension(None)")
|
self.assertEqual(repr(tensor_shape.Dimension(None)), "Dimension(None)")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user