Delegate EagerTensor __bool__ behavior directly to numpy.

Numpy also casts non-bools to bool, so after this change EagerTensors that are not dtype bool can be converted to python bools. This also allows any shape EagerTensor with a single element to be converted to a python bool.

PiperOrigin-RevId: 223409869
This commit is contained in:
Akshay Modi 2018-11-29 14:17:23 -08:00 committed by TensorFlower Gardener
parent 5e6092315a
commit 29574450f8
2 changed files with 8 additions and 10 deletions

View File

@ -175,9 +175,13 @@ class TFETensorTest(test_util.TensorFlowTestCase):
self.assertEqual(dtypes.float64, t.dtype)
def testBool(self):
t = _create_tensor(False)
if t:
self.assertFalse(True)
self.assertFalse(bool(_create_tensor(False)))
self.assertFalse(bool(_create_tensor([False])))
self.assertFalse(bool(_create_tensor([[False]])))
self.assertFalse(bool(_create_tensor([0])))
self.assertFalse(bool(_create_tensor([0.])))
self.assertTrue(bool(_create_tensor([1])))
self.assertTrue(bool(_create_tensor([1.])))
def testIntDowncast(self):
t = _create_tensor(3)

View File

@ -910,13 +910,7 @@ class _EagerTensorBase(Tensor):
return self._copy(context.context(), "GPU:" + str(gpu_index))
def __bool__(self):
if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison
raise ValueError(
"Non-scalar tensor %s cannot be converted to boolean." % repr(self))
if self.dtype != dtypes.bool:
raise ValueError(
"Non-boolean tensor %s cannot be converted to boolean." % repr(self))
return bool(self.cpu().numpy())
return bool(self.numpy())
def __nonzero__(self):
return self.__bool__()