This change includes:
1. When computation of an eager tensor fails, have the EagerTensor raise a proper computation exception instead of ValueError. 2. Stop using StreamingEnqueueAsync for the destroy-tensor-handle request. With this change, turning on streaming RPC no longer breaks the dataset iterator. PiperOrigin-RevId: 267222175
This commit is contained in:
parent
bdc3ca8b84
commit
8b0e6baeb3
@ -28,17 +28,23 @@ namespace eager {
|
|||||||
class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
|
class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
|
||||||
public:
|
public:
|
||||||
DestroyTensorHandleNode(std::unique_ptr<EnqueueRequest> request,
|
DestroyTensorHandleNode(std::unique_ptr<EnqueueRequest> request,
|
||||||
EagerClient* eager_client)
|
EagerClient* eager_client, bool ready)
|
||||||
: tensorflow::AsyncEagerNode(),
|
: tensorflow::AsyncEagerNode(),
|
||||||
request_(std::move(request)),
|
request_(std::move(request)),
|
||||||
eager_client_(eager_client) {}
|
eager_client_(eager_client),
|
||||||
|
ready_(ready) {}
|
||||||
|
|
||||||
void RunAsync(StatusCallback done) override {
|
void RunAsync(StatusCallback done) override {
|
||||||
EnqueueResponse* response = new EnqueueResponse;
|
EnqueueResponse* response = new EnqueueResponse;
|
||||||
eager_client_->StreamingEnqueueAsync(
|
bool ready = ready_;
|
||||||
|
// NOTE(fishx): Don't use StreamingEnqueueAsync here. When a
|
||||||
|
// StreamingEnqueueAsync request fails all following requests will fail as
|
||||||
|
// well. We don't want this request poison following requests since it is
|
||||||
|
// safe to ignore a failing destroy tensor handle request.
|
||||||
|
eager_client_->EnqueueAsync(
|
||||||
request_.get(), response,
|
request_.get(), response,
|
||||||
[response, done](const tensorflow::Status& s) {
|
[response, ready, done](const tensorflow::Status& s) {
|
||||||
if (!s.ok()) {
|
if (!s.ok() && ready) {
|
||||||
LOG(WARNING) << "Ignoring an error encountered when deleting "
|
LOG(WARNING) << "Ignoring an error encountered when deleting "
|
||||||
"remote tensors handles: "
|
"remote tensors handles: "
|
||||||
<< s.ToString();
|
<< s.ToString();
|
||||||
@ -59,6 +65,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
|
|||||||
private:
|
private:
|
||||||
std::unique_ptr<EnqueueRequest> request_;
|
std::unique_ptr<EnqueueRequest> request_;
|
||||||
EagerClient* eager_client_; // Not owned, and must outlive this node.
|
EagerClient* eager_client_; // Not owned, and must outlive this node.
|
||||||
|
bool ready_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace eager
|
} // namespace eager
|
||||||
|
@ -25,8 +25,8 @@ namespace {
|
|||||||
|
|
||||||
void DestoryRemoteTensorHandle(EagerContext* ctx,
|
void DestoryRemoteTensorHandle(EagerContext* ctx,
|
||||||
eager::EagerClient* eager_client,
|
eager::EagerClient* eager_client,
|
||||||
uint64 context_id, uint64 op_id,
|
uint64 context_id, uint64 op_id, int output_num,
|
||||||
int output_num) {
|
bool ready) {
|
||||||
if (ctx->GetContextId() != context_id) {
|
if (ctx->GetContextId() != context_id) {
|
||||||
// This means that this tensor was pointing to a remote device, which
|
// This means that this tensor was pointing to a remote device, which
|
||||||
// has been changed out from under us. Simply return since there is
|
// has been changed out from under us. Simply return since there is
|
||||||
@ -44,7 +44,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
|
|||||||
VLOG(3) << "Sending request to delete " << request->DebugString();
|
VLOG(3) << "Sending request to delete " << request->DebugString();
|
||||||
std::unique_ptr<EagerNode> node(
|
std::unique_ptr<EagerNode> node(
|
||||||
absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request),
|
absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request),
|
||||||
eager_client));
|
eager_client, ready));
|
||||||
auto* executor = ctx->Executor();
|
auto* executor = ctx->Executor();
|
||||||
if (executor->Async()) {
|
if (executor->Async()) {
|
||||||
Status status = executor->Add(std::move(node));
|
Status status = executor->Add(std::move(node));
|
||||||
@ -87,7 +87,7 @@ RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
|
|||||||
|
|
||||||
RemoteTensorHandleData::~RemoteTensorHandleData() {
|
RemoteTensorHandleData::~RemoteTensorHandleData() {
|
||||||
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
|
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
|
||||||
output_num_);
|
output_num_, /*ready=*/true);
|
||||||
ctx_->Unref();
|
ctx_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
|
|||||||
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
|
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
|
||||||
if (delete_remote_tensor_) {
|
if (delete_remote_tensor_) {
|
||||||
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
|
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
|
||||||
output_num_);
|
output_num_, /*ready=*/false);
|
||||||
}
|
}
|
||||||
ctx_->Unref();
|
ctx_->Unref();
|
||||||
}
|
}
|
||||||
|
@ -580,7 +580,7 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
|
|||||||
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
||||||
auto handle = self->handle;
|
auto handle = self->handle;
|
||||||
int n = TFE_TensorHandleNumDims(handle, self->status);
|
int n = TFE_TensorHandleNumDims(handle, self->status);
|
||||||
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
|
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
TF_SetStatus(self->status, TF_OK, "");
|
TF_SetStatus(self->status, TF_OK, "");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -590,7 +590,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
|||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
PyObject* dim =
|
PyObject* dim =
|
||||||
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
|
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
|
||||||
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
|
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
|
||||||
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
|
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
TF_SetStatus(self->status, TF_OK, "");
|
TF_SetStatus(self->status, TF_OK, "");
|
||||||
@ -606,7 +606,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
|||||||
// Getter for `_rank`.
|
// Getter for `_rank`.
|
||||||
static PyObject* EagerTensor_rank(EagerTensor* self) {
|
static PyObject* EagerTensor_rank(EagerTensor* self) {
|
||||||
int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
|
int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
|
||||||
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
|
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
TF_SetStatus(self->status, TF_OK, "");
|
TF_SetStatus(self->status, TF_OK, "");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -622,7 +622,7 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
|
|||||||
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
|
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
|
||||||
auto handle = self->handle;
|
auto handle = self->handle;
|
||||||
int n = TFE_TensorHandleNumElements(handle, self->status);
|
int n = TFE_TensorHandleNumElements(handle, self->status);
|
||||||
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
|
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
TF_SetStatus(self->status, TF_OK, "");
|
TF_SetStatus(self->status, TF_OK, "");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -680,14 +680,14 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
|
|||||||
return EagerTensorFromHandle(handle);
|
return EagerTensorFromHandle(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function `_numpy`.
|
// Function `_numpy_internal`.
|
||||||
// Convert an EagerTensor to a Python numpy.ndarray object.
|
// Convert an EagerTensor to a Python numpy.ndarray object.
|
||||||
// The two may share underlying storage so changes to one may reflect in the
|
// The two may share underlying storage so changes to one may reflect in the
|
||||||
// other.
|
// other.
|
||||||
// Note that if `self` is not on CPU, we raise an Exception.
|
// Note that if `self` is not on CPU, we raise an Exception.
|
||||||
static PyObject* EagerTensor_numpy(EagerTensor* self) {
|
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
|
||||||
auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
|
auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
|
||||||
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
|
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
|
||||||
Py_XDECREF(py_array);
|
Py_XDECREF(py_array);
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
TF_SetStatus(self->status, TF_OK, "");
|
TF_SetStatus(self->status, TF_OK, "");
|
||||||
@ -754,8 +754,8 @@ static PyMemberDef EagerTensor_members[] = {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
static PyMethodDef EagerTensor_methods[] = {
|
static PyMethodDef EagerTensor_methods[] = {
|
||||||
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
|
{"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
|
||||||
PyDoc_STR("_numpy")},
|
PyDoc_STR("_numpy_internal")},
|
||||||
{"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
|
{"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
|
||||||
PyDoc_STR("_datatype_enum")},
|
PyDoc_STR("_datatype_enum")},
|
||||||
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.eager import core
|
|||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -380,7 +381,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def test_numpyFailsForResource(self):
|
def test_numpyFailsForResource(self):
|
||||||
v = variables.Variable(42)
|
v = variables.Variable(42)
|
||||||
with self.assertRaisesRegex(ValueError, "Cannot convert .+ resource"):
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Cannot convert .+ resource"):
|
||||||
v._handle._numpy()
|
v._handle._numpy()
|
||||||
|
|
||||||
def testMemoryviewFailsForResource(self):
|
def testMemoryviewFailsForResource(self):
|
||||||
|
@ -910,10 +910,21 @@ class _EagerTensorBase(Tensor):
|
|||||||
"""Returns the length of the first dimension in the Tensor."""
|
"""Returns the length of the first dimension in the Tensor."""
|
||||||
if not self.shape.ndims:
|
if not self.shape.ndims:
|
||||||
raise TypeError("Scalar tensor has no `len()`")
|
raise TypeError("Scalar tensor has no `len()`")
|
||||||
return self._shape_tuple()[0]
|
# pylint: disable=protected-access
|
||||||
|
try:
|
||||||
|
return self._shape_tuple()[0]
|
||||||
|
except core._NotOkStatusException as e:
|
||||||
|
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||||
|
|
||||||
|
def _numpy_internal(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _numpy(self):
|
def _numpy(self):
|
||||||
raise NotImplementedError()
|
# pylint: disable=protected-access
|
||||||
|
try:
|
||||||
|
return self._numpy_internal()
|
||||||
|
except core._NotOkStatusException as e:
|
||||||
|
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
@ -1036,9 +1047,14 @@ class _EagerTensorBase(Tensor):
|
|||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
if self._tensor_shape is None: # pylint: disable=access-member-before-definition
|
if self._tensor_shape is None: # pylint: disable=access-member-before-definition
|
||||||
# `_tensor_shape` is declared and defined in the definition of
|
# pylint: disable=protected-access
|
||||||
# `EagerTensor`, in C.
|
try:
|
||||||
self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
|
# `_tensor_shape` is declared and defined in the definition of
|
||||||
|
# `EagerTensor`, in C.
|
||||||
|
self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
|
||||||
|
except core._NotOkStatusException as e:
|
||||||
|
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||||
|
|
||||||
return self._tensor_shape
|
return self._tensor_shape
|
||||||
|
|
||||||
def get_shape(self):
|
def get_shape(self):
|
||||||
|
@ -656,12 +656,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||||||
self.assertEqual(v.handle.op.colocation_groups(),
|
self.assertEqual(v.handle.op.colocation_groups(),
|
||||||
v.initializer.inputs[1].op.colocation_groups())
|
v.initializer.inputs[1].op.colocation_groups())
|
||||||
|
|
||||||
def testHandleNumpy(self):
|
|
||||||
with context.eager_mode():
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
resource_variable_ops.ResourceVariable(
|
|
||||||
1.0, name="handle-numpy").handle.numpy()
|
|
||||||
|
|
||||||
def testCountUpTo(self):
|
def testCountUpTo(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
v = resource_variable_ops.ResourceVariable(0, name="upto")
|
v = resource_variable_ops.ResourceVariable(0, name="upto")
|
||||||
|
Loading…
Reference in New Issue
Block a user