This CL allows EagerTensors to support len()
like numpy arrays. This should allow matplotlib plotting to work correctly with EagerTensors.
PiperOrigin-RevId: 221335277
This commit is contained in:
parent
5b1e910313
commit
4b543f449c
@ -128,6 +128,23 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
tensor = constant_op.constant(numpy_tensor)
|
tensor = constant_op.constant(numpy_tensor)
|
||||||
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
|
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
|
||||||
|
|
||||||
|
def testLenAgreesWithNumpy(self):
|
||||||
|
numpy_tensor = np.asarray(1.0)
|
||||||
|
tensor = constant_op.constant(numpy_tensor)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
len(numpy_tensor)
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, r"Scalar tensor has no `len[(][)]`"):
|
||||||
|
len(tensor)
|
||||||
|
|
||||||
|
numpy_tensor = np.asarray([1.0, 2.0, 3.0])
|
||||||
|
tensor = constant_op.constant(numpy_tensor)
|
||||||
|
self.assertAllEqual(len(numpy_tensor), len(tensor))
|
||||||
|
|
||||||
|
numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
|
||||||
|
tensor = constant_op.constant(numpy_tensor)
|
||||||
|
self.assertAllEqual(len(numpy_tensor), len(tensor))
|
||||||
|
|
||||||
def testCopy(self):
|
def testCopy(self):
|
||||||
t = constant_op.constant(1.0)
|
t = constant_op.constant(1.0)
|
||||||
tt = copy.copy(t)
|
tt = copy.copy(t)
|
||||||
|
@ -890,6 +890,12 @@ class _EagerTensorBase(Tensor):
|
|||||||
"""Returns the number of Tensor dimensions."""
|
"""Returns the number of Tensor dimensions."""
|
||||||
return self.shape.ndims
|
return self.shape.ndims
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""Returns the length of the first dimension in the Tensor."""
|
||||||
|
if not self.shape.ndims:
|
||||||
|
raise TypeError("Scalar tensor has no `len()`")
|
||||||
|
return self._shape_tuple()[0]
|
||||||
|
|
||||||
def _cpu_nograd(self):
|
def _cpu_nograd(self):
|
||||||
"""A copy of this Tensor with contents backed by host memory.
|
"""A copy of this Tensor with contents backed by host memory.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user