TFE usability: Implement ndim
for EagerTensor
+NumPy compatibility
Third party libraries like matplotlib often access `ndim`. PiperOrigin-RevId: 182961097
This commit is contained in:
parent
144bfce5be
commit
7da6a83b74
@ -112,6 +112,19 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
numpy_tensor = np.asarray(tensor, dtype=np.int32)
|
||||
self.assertAllEqual(numpy_tensor, [1, 2, 3])
|
||||
|
||||
def testNdimsAgreesWithNumpy(self):
|
||||
numpy_tensor = np.asarray(1.0)
|
||||
tensor = constant_op.constant(numpy_tensor)
|
||||
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
|
||||
|
||||
numpy_tensor = np.asarray([1.0, 2.0, 3.0])
|
||||
tensor = constant_op.constant(numpy_tensor)
|
||||
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
|
||||
|
||||
numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
|
||||
tensor = constant_op.constant(numpy_tensor)
|
||||
self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
|
||||
|
||||
def testCopy(self):
|
||||
t = constant_op.constant(1.0)
|
||||
tt = copy.copy(t)
|
||||
|
@ -775,6 +775,11 @@ class _EagerTensorBase(Tensor):
|
||||
"""The shape of the tensor as a list."""
|
||||
return list(self._shape_tuple())
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
"""Returns the number of Tensor dimensions."""
|
||||
return self.shape.ndims
|
||||
|
||||
def cpu(self):
|
||||
"""A copy of this Tensor with contents backed by host memory."""
|
||||
return self._copy(context.context(), "CPU:0")
|
||||
|
Loading…
Reference in New Issue
Block a user