From 29574450f8701385b821f8b6701e19df1d41dbe5 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Thu, 29 Nov 2018 14:17:23 -0800 Subject: [PATCH] Delegate EagerTensor __bool__ behavior directly to numpy. Numpy also casts non-bools to bool, so after this change EagerTensors that are not dtype bool can be converted to python bools. This also allows any shape EagerTensor with a single element to be converted to a python bool. PiperOrigin-RevId: 223409869 --- tensorflow/python/eager/tensor_test.py | 10 +++++++--- tensorflow/python/framework/ops.py | 8 +------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 8c9d5dabe79..25442ff0485 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -175,9 +175,13 @@ class TFETensorTest(test_util.TensorFlowTestCase): self.assertEqual(dtypes.float64, t.dtype) def testBool(self): - t = _create_tensor(False) - if t: - self.assertFalse(True) + self.assertFalse(bool(_create_tensor(False))) + self.assertFalse(bool(_create_tensor([False]))) + self.assertFalse(bool(_create_tensor([[False]]))) + self.assertFalse(bool(_create_tensor([0]))) + self.assertFalse(bool(_create_tensor([0.]))) + self.assertTrue(bool(_create_tensor([1]))) + self.assertTrue(bool(_create_tensor([1.]))) def testIntDowncast(self): t = _create_tensor(3) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index bd798f9ffa2..b5175d3c93b 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -910,13 +910,7 @@ class _EagerTensorBase(Tensor): return self._copy(context.context(), "GPU:" + str(gpu_index)) def __bool__(self): - if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison - raise ValueError( - "Non-scalar tensor %s cannot be converted to boolean." % repr(self)) - if self.dtype != dtypes.bool: - raise ValueError( - "Non-boolean tensor %s cannot be converted to boolean." % repr(self)) - return bool(self.cpu().numpy()) + return bool(self.numpy()) def __nonzero__(self): return self.__bool__()