EagerTensor.__index__ is now consistent with np.array.__index__
Previously, EagerTensor allowed lookups with non-scalar tensors, e.g. >>> index = tf.constant([0]) >>> [42][index] 42 If this change breaks your code, apply tf.squeeze to the index tensor. For the above example this would look like >>> [42][tf.squeeze(index)] 42 PiperOrigin-RevId: 249245276
This commit is contained in:
parent
2fad0483f2
commit
85429cebb0
tensorflow/python
@ -21,8 +21,10 @@ from __future__ import print_function
|
||||
import copy
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
@ -199,6 +201,16 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(bool(_create_tensor([1])))
|
||||
self.assertTrue(bool(_create_tensor([1.])))
|
||||
|
||||
@unittest.skipUnless(six.PY2, "long has been removed in PY3")
|
||||
def testLong(self):
|
||||
self.assertEqual(long(_create_tensor(long(42))), 42)
|
||||
|
||||
def testIndex(self):
|
||||
self.assertEqual([42][_create_tensor(0)], 42)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_ = [42][_create_tensor([0])]
|
||||
|
||||
def testIntDowncast(self):
|
||||
t = _create_tensor(3)
|
||||
self.assertEqual(dtypes.int32, t.dtype)
|
||||
|
@ -763,14 +763,21 @@ class _EagerTensorBase(Tensor):
|
||||
|
||||
# __int__, __float__ and __index__ may copy the tensor to CPU and
|
||||
# only work for scalars; values are cast as per numpy.
|
||||
# TODO(slebedev): avoid redundant copy in all of the following methods.
|
||||
def __int__(self):
|
||||
return int(self.numpy())
|
||||
|
||||
def __long__(self):
|
||||
return long(self.numpy())
|
||||
|
||||
def __float__(self):
|
||||
return float(self.numpy())
|
||||
|
||||
def __index__(self):
|
||||
return int(self.numpy())
|
||||
maybe_arr = self.numpy()
|
||||
if isinstance(maybe_arr, np.ndarray):
|
||||
return maybe_arr.__index__()
|
||||
return int(maybe_arr) # Must be a NumPy scalar.
|
||||
|
||||
def __array__(self, dtype=None):
|
||||
return np.asarray(self.numpy(), dtype=dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user