From 8b0e6baeb342a41b32d221e499f39ce6f14b19ba Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Wed, 4 Sep 2019 13:43:48 -0700 Subject: [PATCH] This change includes: 1. When failing to compute an eager tensor, let the EagerTensor throw a proper computation exception instead of ValueError. 2. Stop using StreamingEnqueueAsync for the destroy tensor handle request. With this change, turning on streaming rpc won't break dataset iterators anymore. PiperOrigin-RevId: 267222175 --- .../eager/destroy_tensor_handle_node.h | 17 ++++++++---- .../eager/remote_tensor_handle_data.cc | 10 +++---- tensorflow/python/eager/pywrap_tensor.cc | 18 ++++++------- tensorflow/python/eager/tensor_test.py | 4 ++- tensorflow/python/framework/ops.py | 26 +++++++++++++++---- .../resource_variable_ops_test.py | 6 ----- 6 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h index b3d482dc0c8..6f1a4fb6f2d 100644 --- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h +++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h @@ -28,17 +28,23 @@ namespace eager { class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { public: DestroyTensorHandleNode(std::unique_ptr request, - EagerClient* eager_client) + EagerClient* eager_client, bool ready) : tensorflow::AsyncEagerNode(), request_(std::move(request)), - eager_client_(eager_client) {} + eager_client_(eager_client), + ready_(ready) {} void RunAsync(StatusCallback done) override { EnqueueResponse* response = new EnqueueResponse; - eager_client_->StreamingEnqueueAsync( + bool ready = ready_; + // NOTE(fishx): Don't use StreamingEnqueueAsync here. When a + // StreamingEnqueueAsync request fails, all following requests will fail as + // well. We don't want this request to poison following requests since it is + // safe to ignore a failing destroy tensor handle request. 
+ eager_client_->EnqueueAsync( request_.get(), response, - [response, done](const tensorflow::Status& s) { - if (!s.ok()) { + [response, ready, done](const tensorflow::Status& s) { + if (!s.ok() && ready) { LOG(WARNING) << "Ignoring an error encountered when deleting " "remote tensors handles: " << s.ToString(); @@ -59,6 +65,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { private: std::unique_ptr request_; EagerClient* eager_client_; // Not owned, and must outlive this node. + bool ready_; }; } // namespace eager diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index 85ad20e51d9..b906bd0bdc7 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -25,8 +25,8 @@ namespace { void DestoryRemoteTensorHandle(EagerContext* ctx, eager::EagerClient* eager_client, - uint64 context_id, uint64 op_id, - int output_num) { + uint64 context_id, uint64 op_id, int output_num, + bool ready) { if (ctx->GetContextId() != context_id) { // This means that this tensor was pointing to a remote device, which // has been changed out from under us. 
Simply return since there is @@ -44,7 +44,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx, VLOG(3) << "Sending request to delete " << request->DebugString(); std::unique_ptr node( absl::make_unique(std::move(request), - eager_client)); + eager_client, ready)); auto* executor = ctx->Executor(); if (executor->Async()) { Status status = executor->Add(std::move(node)); @@ -87,7 +87,7 @@ RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num, RemoteTensorHandleData::~RemoteTensorHandleData() { DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_, - output_num_); + output_num_, /*ready=*/true); ctx_->Unref(); } @@ -150,7 +150,7 @@ UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData( UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() { if (delete_remote_tensor_) { DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_, - output_num_); + output_num_, /*ready=*/false); } ctx_->Unref(); } diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b81eddac077..3e921793747 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -580,7 +580,7 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { auto handle = self->handle; int n = TFE_TensorHandleNumDims(handle, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { // Cleanup self->status before returning. 
TF_SetStatus(self->status, TF_OK, ""); return nullptr; @@ -590,7 +590,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { for (int i = 0; i < n; ++i) { PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status)); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) || + if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) || dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); @@ -606,7 +606,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { // Getter for `_rank`. static PyObject* EagerTensor_rank(EagerTensor* self) { int num_dims = TFE_TensorHandleNumDims(self->handle, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); return nullptr; @@ -622,7 +622,7 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { static PyObject* EagerTensor_num_elements(EagerTensor* self) { auto handle = self->handle; int n = TFE_TensorHandleNumElements(handle, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); return nullptr; @@ -680,14 +680,14 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, return EagerTensorFromHandle(handle); } -// Function `_numpy`. +// Function `_numpy_internal`. // Convert an EagerTensor to a Python numpy.ndarray object. // The two may share underlying storage so changes to one may reflect in the // other. // Note that if `self` is not on CPU, we raise an Exception. 
-static PyObject* EagerTensor_numpy(EagerTensor* self) { +static PyObject* EagerTensor_numpy_internal(EagerTensor* self) { auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { Py_XDECREF(py_array); // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); @@ -754,8 +754,8 @@ static PyMemberDef EagerTensor_members[] = { #endif static PyMethodDef EagerTensor_methods[] = { - {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS, - PyDoc_STR("_numpy")}, + {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS, + PyDoc_STR("_numpy_internal")}, {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS, PyDoc_STR("_datatype_enum")}, {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS, diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 74b4b438e0f..eb336aad90f 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -32,6 +32,7 @@ from tensorflow.python.eager import core from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -380,7 +381,8 @@ class TFETensorTest(test_util.TensorFlowTestCase): def test_numpyFailsForResource(self): v = variables.Variable(42) - with self.assertRaisesRegex(ValueError, "Cannot convert .+ resource"): + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Cannot convert .+ resource"): v._handle._numpy() def testMemoryviewFailsForResource(self): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index cc5c5044927..7b25cfaff72 
100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -910,10 +910,21 @@ class _EagerTensorBase(Tensor): """Returns the length of the first dimension in the Tensor.""" if not self.shape.ndims: raise TypeError("Scalar tensor has no `len()`") - return self._shape_tuple()[0] + # pylint: disable=protected-access + try: + return self._shape_tuple()[0] + except core._NotOkStatusException as e: + six.raise_from(core._status_to_exception(e.code, e.message), None) + + def _numpy_internal(self): + raise NotImplementedError() def _numpy(self): - raise NotImplementedError() + # pylint: disable=protected-access + try: + return self._numpy_internal() + except core._NotOkStatusException as e: + six.raise_from(core._status_to_exception(e.code, e.message), None) @property def dtype(self): @@ -1036,9 +1047,14 @@ class _EagerTensorBase(Tensor): @property def shape(self): if self._tensor_shape is None: # pylint: disable=access-member-before-definition - # `_tensor_shape` is declared and defined in the definition of - # `EagerTensor`, in C. - self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple()) + # pylint: disable=protected-access + try: + # `_tensor_shape` is declared and defined in the definition of + # `EagerTensor`, in C. 
+ self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple()) + except core._NotOkStatusException as e: + six.raise_from(core._status_to_exception(e.code, e.message), None) + return self._tensor_shape def get_shape(self): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 70c6c7ecfbc..5077badc619 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -656,12 +656,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(v.handle.op.colocation_groups(), v.initializer.inputs[1].op.colocation_groups()) - def testHandleNumpy(self): - with context.eager_mode(): - with self.assertRaises(ValueError): - resource_variable_ops.ResourceVariable( - 1.0, name="handle-numpy").handle.numpy() - def testCountUpTo(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(0, name="upto")