EagerTensor.__index__ is now consistent with np.array.__index__
Previously, EagerTensor allowed lookups with non-scalar tensors, e.g. >>> index = tf.constant([0]) >>> [42][index] 42 If this change breaks your code, apply tf.squeeze to the index tensor. For the above example this would look like >>> [42][tf.squeeze(index)] 42 PiperOrigin-RevId: 249245276
This commit is contained in:
parent
2fad0483f2
commit
85429cebb0
@ -21,8 +21,10 @@ from __future__ import print_function
|
|||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -199,6 +201,16 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(bool(_create_tensor([1])))
|
self.assertTrue(bool(_create_tensor([1])))
|
||||||
self.assertTrue(bool(_create_tensor([1.])))
|
self.assertTrue(bool(_create_tensor([1.])))
|
||||||
|
|
||||||
|
@unittest.skipUnless(six.PY2, "long has been removed in PY3")
|
||||||
|
def testLong(self):
|
||||||
|
self.assertEqual(long(_create_tensor(long(42))), 42)
|
||||||
|
|
||||||
|
def testIndex(self):
|
||||||
|
self.assertEqual([42][_create_tensor(0)], 42)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
_ = [42][_create_tensor([0])]
|
||||||
|
|
||||||
def testIntDowncast(self):
|
def testIntDowncast(self):
|
||||||
t = _create_tensor(3)
|
t = _create_tensor(3)
|
||||||
self.assertEqual(dtypes.int32, t.dtype)
|
self.assertEqual(dtypes.int32, t.dtype)
|
||||||
|
@ -763,14 +763,21 @@ class _EagerTensorBase(Tensor):
|
|||||||
|
|
||||||
# __int__, __float__ and __index__ may copy the tensor to CPU and
|
# __int__, __float__ and __index__ may copy the tensor to CPU and
|
||||||
# only work for scalars; values are cast as per numpy.
|
# only work for scalars; values are cast as per numpy.
|
||||||
|
# TODO(slebedev): avoid redundant copy in all of the following methods.
|
||||||
def __int__(self):
|
def __int__(self):
|
||||||
return int(self.numpy())
|
return int(self.numpy())
|
||||||
|
|
||||||
|
def __long__(self):
|
||||||
|
return long(self.numpy())
|
||||||
|
|
||||||
def __float__(self):
|
def __float__(self):
|
||||||
return float(self.numpy())
|
return float(self.numpy())
|
||||||
|
|
||||||
def __index__(self):
|
def __index__(self):
|
||||||
return int(self.numpy())
|
maybe_arr = self.numpy()
|
||||||
|
if isinstance(maybe_arr, np.ndarray):
|
||||||
|
return maybe_arr.__index__()
|
||||||
|
return int(maybe_arr) # Must be a NumPy scalar.
|
||||||
|
|
||||||
def __array__(self, dtype=None):
|
def __array__(self, dtype=None):
|
||||||
return np.asarray(self.numpy(), dtype=dtype)
|
return np.asarray(self.numpy(), dtype=dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user