diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index b5833718c79..77f756d76fe 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -21,8 +21,10 @@ from __future__ import print_function import copy import re import sys +import unittest import numpy as np +import six from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -199,6 +201,16 @@ class TFETensorTest(test_util.TensorFlowTestCase): self.assertTrue(bool(_create_tensor([1]))) self.assertTrue(bool(_create_tensor([1.]))) + @unittest.skipUnless(six.PY2, "long has been removed in PY3") + def testLong(self): + self.assertEqual(long(_create_tensor(long(42))), 42) + + def testIndex(self): + self.assertEqual([42][_create_tensor(0)], 42) + + with self.assertRaises(TypeError): + _ = [42][_create_tensor([0])] + def testIntDowncast(self): t = _create_tensor(3) self.assertEqual(dtypes.int32, t.dtype) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 30c820a6020..1b2b5198540 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -763,14 +763,21 @@ class _EagerTensorBase(Tensor): # __int__, __float__ and __index__ may copy the tensor to CPU and # only work for scalars; values are cast as per numpy. + # TODO(slebedev): avoid redundant copy in all of the following methods. def __int__(self): return int(self.numpy()) + def __long__(self): + return long(self.numpy()) + def __float__(self): return float(self.numpy()) def __index__(self): - return int(self.numpy()) + maybe_arr = self.numpy() + if isinstance(maybe_arr, np.ndarray): + return maybe_arr.__index__() + return int(maybe_arr) # Must be a NumPy scalar. def __array__(self, dtype=None): return np.asarray(self.numpy(), dtype=dtype)