This change includes:
1. When computing an eager tensor fails, have EagerTensor raise the proper computation exception instead of a ValueError.
2. Stop using StreamingEnqueueAsync for the destroy-tensor-handle request. With this change, turning on streaming RPC no longer breaks dataset iterators.
PiperOrigin-RevId: 267222175
parent bdc3ca8b84
commit 8b0e6baeb3
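
As a quick illustration of change 1, the sketch below mirrors the updated test further down (test_numpyFailsForResource); it is not part of the commit and assumes eager execution is enabled. Converting a resource tensor to numpy now raises the status-specific errors.InvalidArgumentError instead of a generic ValueError.

from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.ops import variables

with context.eager_mode():
  v = variables.Variable(42)
  try:
    # A resource-dtype tensor has no numpy representation, so this fails.
    v._handle._numpy()
  except errors.InvalidArgumentError:
    # Before this change the same failure surfaced as ValueError.
    pass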
@@ -28,17 +28,23 @@ namespace eager {
 class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
  public:
   DestroyTensorHandleNode(std::unique_ptr<EnqueueRequest> request,
-                          EagerClient* eager_client)
+                          EagerClient* eager_client, bool ready)
       : tensorflow::AsyncEagerNode(),
         request_(std::move(request)),
-        eager_client_(eager_client) {}
+        eager_client_(eager_client),
+        ready_(ready) {}

   void RunAsync(StatusCallback done) override {
     EnqueueResponse* response = new EnqueueResponse;
-    eager_client_->StreamingEnqueueAsync(
+    bool ready = ready_;
+    // NOTE(fishx): Don't use StreamingEnqueueAsync here. When a
+    // StreamingEnqueueAsync request fails all following requests will fail as
+    // well. We don't want this request poison following requests since it is
+    // safe to ignore a failing destroy tensor handle request.
+    eager_client_->EnqueueAsync(
         request_.get(), response,
-        [response, done](const tensorflow::Status& s) {
-          if (!s.ok()) {
+        [response, ready, done](const tensorflow::Status& s) {
+          if (!s.ok() && ready) {
             LOG(WARNING) << "Ignoring an error encountered when deleting "
                             "remote tensors handles: "
                          << s.ToString();
@@ -59,6 +65,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
  private:
   std::unique_ptr<EnqueueRequest> request_;
   EagerClient* eager_client_;  // Not owned, and must outlive this node.
+  bool ready_;
 };

 }  // namespace eager
@@ -25,8 +25,8 @@ namespace {

 void DestoryRemoteTensorHandle(EagerContext* ctx,
                                eager::EagerClient* eager_client,
-                               uint64 context_id, uint64 op_id,
-                               int output_num) {
+                               uint64 context_id, uint64 op_id, int output_num,
+                               bool ready) {
   if (ctx->GetContextId() != context_id) {
     // This means that this tensor was pointing to a remote device, which
     // has been changed out from under us. Simply return since there is
@@ -44,7 +44,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
   VLOG(3) << "Sending request to delete " << request->DebugString();
   std::unique_ptr<EagerNode> node(
       absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request),
-                                                        eager_client));
+                                                        eager_client, ready));
   auto* executor = ctx->Executor();
   if (executor->Async()) {
     Status status = executor->Add(std::move(node));
@@ -87,7 +87,7 @@ RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,

 RemoteTensorHandleData::~RemoteTensorHandleData() {
   DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
-                            output_num_);
+                            output_num_, /*ready=*/true);
   ctx_->Unref();
 }

@@ -150,7 +150,7 @@ UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
 UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
   if (delete_remote_tensor_) {
     DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
-                              output_num_);
+                              output_num_, /*ready=*/false);
   }
   ctx_->Unref();
 }
@@ -580,7 +580,7 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
   auto handle = self->handle;
   int n = TFE_TensorHandleNumDims(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
     // Cleanup self->status before returning.
     TF_SetStatus(self->status, TF_OK, "");
     return nullptr;
@@ -590,7 +590,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
   for (int i = 0; i < n; ++i) {
     PyObject* dim =
         PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
-    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
+    if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
         dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
       // Cleanup self->status before returning.
       TF_SetStatus(self->status, TF_OK, "");
@@ -606,7 +606,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
 // Getter for `_rank`.
 static PyObject* EagerTensor_rank(EagerTensor* self) {
   int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
     // Cleanup self->status before returning.
     TF_SetStatus(self->status, TF_OK, "");
     return nullptr;
@@ -622,7 +622,7 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
   auto handle = self->handle;
   int n = TFE_TensorHandleNumElements(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
     // Cleanup self->status before returning.
     TF_SetStatus(self->status, TF_OK, "");
     return nullptr;
@@ -680,14 +680,14 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
   return EagerTensorFromHandle(handle);
 }

-// Function `_numpy`.
+// Function `_numpy_internal`.
 // Convert an EagerTensor to a Python numpy.ndarray object.
 // The two may share underlying storage so changes to one may reflect in the
 // other.
 // Note that if `self` is not on CPU, we raise an Exception.
-static PyObject* EagerTensor_numpy(EagerTensor* self) {
+static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
   auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
     Py_XDECREF(py_array);
     // Cleanup self->status before returning.
     TF_SetStatus(self->status, TF_OK, "");
@@ -754,8 +754,8 @@ static PyMemberDef EagerTensor_members[] = {
 #endif

 static PyMethodDef EagerTensor_methods[] = {
-    {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
-     PyDoc_STR("_numpy")},
+    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
+     PyDoc_STR("_numpy_internal")},
     {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
      PyDoc_STR("_datatype_enum")},
     {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
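
In the pywrap_tensor.cc hunks above, passing nullptr instead of PyExc_ValueError to MaybeRaiseExceptionFromTFStatus lets the raised Python exception be derived from the TF status code rather than being forced to ValueError. The snippet below is only an illustration of that code-to-exception mapping using the public errors helpers; it is not code from this commit.

from tensorflow.python.framework import errors_impl

# exception_type_from_error_code maps a TF status code to the matching
# Python exception class, e.g. INVALID_ARGUMENT -> InvalidArgumentError.
assert errors_impl.exception_type_from_error_code(
    errors_impl.INVALID_ARGUMENT) is errors_impl.InvalidArgumentError
assert errors_impl.exception_type_from_error_code(
    errors_impl.FAILED_PRECONDITION) is errors_impl.FailedPreconditionError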
@@ -32,6 +32,7 @@ from tensorflow.python.eager import core
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -380,7 +381,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):

   def test_numpyFailsForResource(self):
     v = variables.Variable(42)
-    with self.assertRaisesRegex(ValueError, "Cannot convert .+ resource"):
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Cannot convert .+ resource"):
       v._handle._numpy()

   def testMemoryviewFailsForResource(self):
@@ -910,10 +910,21 @@ class _EagerTensorBase(Tensor):
     """Returns the length of the first dimension in the Tensor."""
     if not self.shape.ndims:
       raise TypeError("Scalar tensor has no `len()`")
-    return self._shape_tuple()[0]
+    # pylint: disable=protected-access
+    try:
+      return self._shape_tuple()[0]
+    except core._NotOkStatusException as e:
+      six.raise_from(core._status_to_exception(e.code, e.message), None)

-  def _numpy(self):
+  def _numpy_internal(self):
     raise NotImplementedError()

+  def _numpy(self):
+    # pylint: disable=protected-access
+    try:
+      return self._numpy_internal()
+    except core._NotOkStatusException as e:
+      six.raise_from(core._status_to_exception(e.code, e.message), None)
+
   @property
   def dtype(self):
@@ -1036,9 +1047,14 @@ class _EagerTensorBase(Tensor):
   @property
   def shape(self):
     if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
-      # `_tensor_shape` is declared and defined in the definition of
-      # `EagerTensor`, in C.
-      self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
+      # pylint: disable=protected-access
+      try:
+        # `_tensor_shape` is declared and defined in the definition of
+        # `EagerTensor`, in C.
+        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
+      except core._NotOkStatusException as e:
+        six.raise_from(core._status_to_exception(e.code, e.message), None)
+
     return self._tensor_shape

   def get_shape(self):
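
The ops.py hunks above repeat one pattern: call the C-implemented method, catch core._NotOkStatusException, and re-raise the status-specific exception via core._status_to_exception. A minimal standalone sketch of that pattern follows; the helper name _reraise_as_tf_error is hypothetical, since the commit inlines the try/except in each method instead.

import six

from tensorflow.python.eager import core


def _reraise_as_tf_error(fn, *args, **kwargs):
  """Illustrative only: run fn and translate TF status failures."""
  # pylint: disable=protected-access
  try:
    return fn(*args, **kwargs)
  except core._NotOkStatusException as e:
    six.raise_from(core._status_to_exception(e.code, e.message), None)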
@@ -656,12 +656,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     self.assertEqual(v.handle.op.colocation_groups(),
                      v.initializer.inputs[1].op.colocation_groups())

-  def testHandleNumpy(self):
-    with context.eager_mode():
-      with self.assertRaises(ValueError):
-        resource_variable_ops.ResourceVariable(
-            1.0, name="handle-numpy").handle.numpy()
-
   def testCountUpTo(self):
     with context.eager_mode():
       v = resource_variable_ops.ResourceVariable(0, name="upto")