From 7da6a83b74f467929032dce95794ef0197d46b20 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Tue, 23 Jan 2018 11:05:50 -0800 Subject: [PATCH] TFE usability: Implement `ndim` for `EagerTensor`+NumPy compatibility Third party libraries like matplotlib often access `ndim`. PiperOrigin-RevId: 182961097 --- tensorflow/python/eager/tensor_test.py | 13 +++++++++++++ tensorflow/python/framework/ops.py | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 2568d3dc054..0bd5a5dbafd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -112,6 +112,19 @@ class TFETensorTest(test_util.TensorFlowTestCase): numpy_tensor = np.asarray(tensor, dtype=np.int32) self.assertAllEqual(numpy_tensor, [1, 2, 3]) + def testNdimsAgreesWithNumpy(self): + numpy_tensor = np.asarray(1.0) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + + numpy_tensor = np.asarray([1.0, 2.0, 3.0]) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + + numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + tensor = constant_op.constant(numpy_tensor) + self.assertAllEqual(numpy_tensor.ndim, tensor.ndim) + def testCopy(self): t = constant_op.constant(1.0) tt = copy.copy(t) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0eb06ae9132..2489982d934 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -775,6 +775,11 @@ class _EagerTensorBase(Tensor): """The shape of the tensor as a list.""" return list(self._shape_tuple()) + @property + def ndim(self): + """Returns the number of Tensor dimensions.""" + return self.shape.ndims + def cpu(self): """A copy of this Tensor with contents backed by host memory.""" return self._copy(context.context(), "CPU:0")