Enable __array__ access on EagerTensors (forwarding to _numpy)
PiperOrigin-RevId: 309881480 Change-Id: I095a3ab0e3a991074d5f3949501302792cc9fd3b
This commit is contained in:
parent
c265e89afa
commit
f5ad74a467
@ -1067,15 +1067,17 @@ class _EagerTensorBase(Tensor):
|
|||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e:
|
||||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||||
|
|
||||||
|
def __array__(self):
|
||||||
|
return self._numpy()
|
||||||
|
|
||||||
def _numpy_internal(self):
|
def _numpy_internal(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _numpy(self):
|
def _numpy(self):
|
||||||
# pylint: disable=protected-access
|
|
||||||
try:
|
try:
|
||||||
return self._numpy_internal()
|
return self._numpy_internal()
|
||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e: # pylint: disable=protected-access
|
||||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
six.raise_from(core._status_to_exception(e.code, e.message), None) # pylint: disable=protected-access
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user