Add bool support for Dimension.

PiperOrigin-RevId: 350393553
Change-Id: Ife4004b377cd9dfab1322a08d6e7c4e74fc5416c
This commit is contained in:
A. Unique TensorFlower 2021-01-06 11:27:39 -08:00 committed by TensorFlower Gardener
parent e2395993d0
commit 0108d77292
2 changed files with 12 additions and 0 deletions

View File

@ -234,6 +234,10 @@ class Dimension(object):
return None
return self._value != other.value
def __bool__(self):
"""Equivalent to `bool(self.value)`."""
return bool(self._value)
def __int__(self):
return self._value

View File

@ -195,6 +195,14 @@ class DimensionTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError):
tensor_shape.Dimension(dtypes.string)
def testBool(self):
one = tensor_shape.Dimension(1)
zero = tensor_shape.Dimension(0)
has_none = tensor_shape.Dimension(None)
self.assertTrue(one)
self.assertFalse(zero)
self.assertFalse(has_none)
def testMod(self):
four = tensor_shape.Dimension(4)
nine = tensor_shape.Dimension(9)