Delegate EagerTensor __bool__ behavior directly to numpy.
Numpy also casts non-bools to bool, so after this change EagerTensors that are not dtype bool can be converted to python bools. This also allows any shape EagerTensor with a single element to be converted to a python bool. PiperOrigin-RevId: 223409869
This commit is contained in:
parent
5e6092315a
commit
29574450f8
@ -175,9 +175,13 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(dtypes.float64, t.dtype)
|
||||
|
||||
def testBool(self):
|
||||
t = _create_tensor(False)
|
||||
if t:
|
||||
self.assertFalse(True)
|
||||
self.assertFalse(bool(_create_tensor(False)))
|
||||
self.assertFalse(bool(_create_tensor([False])))
|
||||
self.assertFalse(bool(_create_tensor([[False]])))
|
||||
self.assertFalse(bool(_create_tensor([0])))
|
||||
self.assertFalse(bool(_create_tensor([0.])))
|
||||
self.assertTrue(bool(_create_tensor([1])))
|
||||
self.assertTrue(bool(_create_tensor([1.])))
|
||||
|
||||
def testIntDowncast(self):
|
||||
t = _create_tensor(3)
|
||||
|
@ -910,13 +910,7 @@ class _EagerTensorBase(Tensor):
|
||||
return self._copy(context.context(), "GPU:" + str(gpu_index))
|
||||
|
||||
def __bool__(self):
|
||||
if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison
|
||||
raise ValueError(
|
||||
"Non-scalar tensor %s cannot be converted to boolean." % repr(self))
|
||||
if self.dtype != dtypes.bool:
|
||||
raise ValueError(
|
||||
"Non-boolean tensor %s cannot be converted to boolean." % repr(self))
|
||||
return bool(self.cpu().numpy())
|
||||
return bool(self.numpy())
|
||||
|
||||
def __nonzero__(self):
|
||||
return self.__bool__()
|
||||
|
Loading…
Reference in New Issue
Block a user