TFE usability: Implement ndim for EagerTensor+NumPy compatibility

Third party libraries like matplotlib often access `ndim`.

PiperOrigin-RevId: 182961097
This commit is contained in:
Akshay Agrawal 2018-01-23 11:05:50 -08:00 committed by TensorFlower Gardener
parent 144bfce5be
commit 7da6a83b74
2 changed files with 18 additions and 0 deletions

View File

@ -112,6 +112,19 @@ class TFETensorTest(test_util.TensorFlowTestCase):
numpy_tensor = np.asarray(tensor, dtype=np.int32)
self.assertAllEqual(numpy_tensor, [1, 2, 3])
def testNdimsAgreesWithNumpy(self):
numpy_tensor = np.asarray(1.0)
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
numpy_tensor = np.asarray([1.0, 2.0, 3.0])
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
tensor = constant_op.constant(numpy_tensor)
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
def testCopy(self):
t = constant_op.constant(1.0)
tt = copy.copy(t)

View File

@ -775,6 +775,11 @@ class _EagerTensorBase(Tensor):
"""The shape of the tensor as a list."""
return list(self._shape_tuple())
@property
def ndim(self):
"""Returns the number of Tensor dimensions."""
return self.shape.ndims
def cpu(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.context(), "CPU:0")