EagerTensor_numpy now returns a view for supported dtypes
Two notes: * Prior to this change it always produced a copy. See comment in EagerTensor_numpy for details. * EagerTensor.numpy still returns a copy to ensure no change of behavior. This is likely to change in the followup CL. PiperOrigin-RevId: 246378787
This commit is contained in:
parent
73a5fa2ae1
commit
7caec689ac
@ -443,6 +443,7 @@ cc_library(
|
|||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
":safe_ptr",
|
":safe_ptr",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:c_api_internal",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -627,7 +627,36 @@ static PyObject* EagerTensor_numpy(EagerTensor* self) {
|
|||||||
PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
|
PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HACK(slebedev): The following explains why TensorToNdarray never
|
||||||
|
// reuses the storage.
|
||||||
|
//
|
||||||
|
// TF_TensorToPyArray copies the storage unless its
|
||||||
|
// refcount is 1. For DT_STRING and DT_RESOURCE TF_TensorFromTensor
|
||||||
|
// has to copy so the refcount of the original storage is unchanged.
|
||||||
|
// However, if the storage can be reused by TF_TensorFromTensor its
|
||||||
|
// refcount is +1'd and hence TF_TensorToPyArray no longer can reuse it.
|
||||||
|
//
|
||||||
|
// Here we attempt a direct conversion without an intermediate TF_Tensor
|
||||||
|
// and fall-back to the slow path on failure.
|
||||||
PyObject* ret = nullptr;
|
PyObject* ret = nullptr;
|
||||||
|
if (t->dtype() != tensorflow::DT_STRING &&
|
||||||
|
t->dtype() != tensorflow::DT_RESOURCE) {
|
||||||
|
tensorflow::gtl::InlinedVector<npy_intp, 4> dims(t->dims());
|
||||||
|
for (int d = 0; d < t->dims(); ++d) {
|
||||||
|
dims[d] = t->dim_size(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* copy = new tensorflow::Tensor(*t);
|
||||||
|
char* data = const_cast<char*>(copy->tensor_data().data());
|
||||||
|
if (tensorflow::ArrayFromMemory(
|
||||||
|
dims.size(), dims.data(), data, t->dtype(), [copy] { delete copy; },
|
||||||
|
&ret)
|
||||||
|
.ok()) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
||||||
if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
|
if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
|
||||||
Py_XDECREF(ret);
|
Py_XDECREF(ret);
|
||||||
|
@ -487,6 +487,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
ValueError, "non-rectangular Python sequence"):
|
ValueError, "non-rectangular Python sequence"):
|
||||||
constant_op.constant(l)
|
constant_op.constant(l)
|
||||||
|
|
||||||
|
def test_numpyIsView(self):
|
||||||
|
t = constant_op.constant([0.0])
|
||||||
|
t._numpy()[0] = 42.0
|
||||||
|
self.assertAllClose(t, constant_op.constant([42.0]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -758,7 +758,8 @@ class _EagerTensorBase(Tensor):
|
|||||||
"""
|
"""
|
||||||
if self.dtype == dtypes.resource:
|
if self.dtype == dtypes.resource:
|
||||||
raise ValueError("Resource handles are not convertible to numpy.")
|
raise ValueError("Resource handles are not convertible to numpy.")
|
||||||
return self._cpu_nograd()._numpy() # pylint: disable=protected-access
|
maybe_arr = self._cpu_nograd()._numpy() # pylint: disable=protected-access
|
||||||
|
return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
|
||||||
|
|
||||||
# __int__, __float__ and __index__ may copy the tensor to CPU and
|
# __int__, __float__ and __index__ may copy the tensor to CPU and
|
||||||
# only work for scalars; values are cast as per numpy.
|
# only work for scalars; values are cast as per numpy.
|
||||||
@ -772,7 +773,7 @@ class _EagerTensorBase(Tensor):
|
|||||||
return int(self.numpy())
|
return int(self.numpy())
|
||||||
|
|
||||||
def __array__(self, dtype=None):
|
def __array__(self, dtype=None):
|
||||||
return np.array(self.numpy(), dtype=dtype)
|
return np.asarray(self.numpy(), dtype=dtype)
|
||||||
|
|
||||||
def __format__(self, format_spec):
|
def __format__(self, format_spec):
|
||||||
return self.numpy().__format__(format_spec)
|
return self.numpy().__format__(format_spec)
|
||||||
|
Loading…
Reference in New Issue
Block a user