This CL allows EagerTensors to support len() like numpy arrays. This should allow matplotlib plotting to work correctly with EagerTensors.

PiperOrigin-RevId: 221335277
This commit is contained in:
A. Unique TensorFlower 2018-11-13 14:31:34 -08:00 committed by TensorFlower Gardener
parent 5b1e910313
commit 4b543f449c
2 changed files with 23 additions and 0 deletions

View File

@ -128,6 +128,23 @@ class TFETensorTest(test_util.TensorFlowTestCase):
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
def testLenAgreesWithNumpy(self):
numpy_tensor = np.asarray(1.0)
tensor = constant_op.constant(numpy_tensor)
with self.assertRaises(TypeError):
len(numpy_tensor)
with self.assertRaisesRegexp(
TypeError, r"Scalar tensor has no `len[(][)]`"):
len(tensor)
numpy_tensor = np.asarray([1.0, 2.0, 3.0])
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(len(numpy_tensor), len(tensor))
numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(len(numpy_tensor), len(tensor))
def testCopy(self):
t = constant_op.constant(1.0)
tt = copy.copy(t)

View File

@ -890,6 +890,12 @@ class _EagerTensorBase(Tensor):
"""Returns the number of Tensor dimensions."""
return self.shape.ndims
def __len__(self):
"""Returns the length of the first dimension in the Tensor."""
if not self.shape.ndims:
raise TypeError("Scalar tensor has no `len()`")
return self._shape_tuple()[0]
def _cpu_nograd(self):
"""A copy of this Tensor with contents backed by host memory.