This change includes:

1. When computing an eager tensor fails, have EagerTensor raise the proper computation exception instead of ValueError (see the sketch after this message for the user-visible effect).
2. Stop using StreamingEnqueueAsync for destroy-tensor-handle requests.

With this change, turning on streaming RPC no longer breaks the dataset iterator.

PiperOrigin-RevId: 267222175
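
The user-visible effect of point 1, mirrored by the updated test further down, is that a status error hit while computing an eager tensor now surfaces as the matching tf.errors exception rather than a bare ValueError. A minimal sketch of the new behavior (assumes eager execution and a build that includes this change; it pokes at private members, so it is for illustration only):

import tensorflow as tf

v = tf.Variable(42)
try:
  # A resource handle has no numpy representation, so the C layer reports an
  # InvalidArgument status; with this change it surfaces as the typed
  # tf.errors exception instead of a generic ValueError.
  v._handle._numpy()  # pylint: disable=protected-access
except tf.errors.InvalidArgumentError as e:
  print("raised InvalidArgumentError:", e.message)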
Xiao Yu 2019-09-04 13:43:48 -07:00 committed by TensorFlower Gardener
parent bdc3ca8b84
commit 8b0e6baeb3
6 changed files with 50 additions and 31 deletions

View File

@@ -28,17 +28,23 @@ namespace eager {
class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
public:
DestroyTensorHandleNode(std::unique_ptr<EnqueueRequest> request,
EagerClient* eager_client)
EagerClient* eager_client, bool ready)
: tensorflow::AsyncEagerNode(),
request_(std::move(request)),
eager_client_(eager_client) {}
eager_client_(eager_client),
ready_(ready) {}
void RunAsync(StatusCallback done) override {
EnqueueResponse* response = new EnqueueResponse;
eager_client_->StreamingEnqueueAsync(
bool ready = ready_;
// NOTE(fishx): Don't use StreamingEnqueueAsync here. When a
// StreamingEnqueueAsync request fails all following requests will fail as
// well. We don't want this request poison following requests since it is
// safe to ignore a failing destroy tensor handle request.
eager_client_->EnqueueAsync(
request_.get(), response,
[response, done](const tensorflow::Status& s) {
if (!s.ok()) {
[response, ready, done](const tensorflow::Status& s) {
if (!s.ok() && ready) {
LOG(WARNING) << "Ignoring an error encountered when deleting "
"remote tensors handles: "
<< s.ToString();
@@ -59,6 +65,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode {
private:
std::unique_ptr<EnqueueRequest> request_;
EagerClient* eager_client_; // Not owned, and must outlive this node.
bool ready_;
};
} // namespace eager
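
The NOTE in the hunk above carries the reasoning for point 2: on a streaming channel, one failed enqueue poisons every request queued after it, while a failing destroy-tensor-handle request is safe to ignore on its own. A toy Python sketch of that difference (both classes are hypothetical stand-ins, not TensorFlow APIs):

class StreamingChannel:
  """One ordered stream: the first error poisons every later request."""

  def __init__(self):
    self._broken = False

  def enqueue(self, request):
    if self._broken:
      raise RuntimeError("stream already failed, dropping: " + request)
    if request == "destroy_remote_tensor_handle":
      self._broken = True
      raise RuntimeError("request failed: " + request)
    return "ok"


class OneOffChannel:
  """Independent RPCs: a failure only affects the request that failed."""

  def enqueue(self, request):
    if request == "destroy_remote_tensor_handle":
      raise RuntimeError("request failed: " + request)
    return "ok"


for channel_cls in (StreamingChannel, OneOffChannel):
  chan = channel_cls()
  for req in ("destroy_remote_tensor_handle", "dataset_iterator_get_next"):
    try:
      print(channel_cls.__name__, req, chan.enqueue(req))
    except RuntimeError as e:
      print(channel_cls.__name__, req, "FAILED:", e)

With StreamingChannel the later dataset request fails too; with OneOffChannel only the ignorable destroy request fails, which is why the node now uses EnqueueAsync.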

View File

@@ -25,8 +25,8 @@ namespace {
void DestoryRemoteTensorHandle(EagerContext* ctx,
eager::EagerClient* eager_client,
uint64 context_id, uint64 op_id,
int output_num) {
uint64 context_id, uint64 op_id, int output_num,
bool ready) {
if (ctx->GetContextId() != context_id) {
// This means that this tensor was pointing to a remote device, which
// has been changed out from under us. Simply return since there is
@@ -44,7 +44,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
VLOG(3) << "Sending request to delete " << request->DebugString();
std::unique_ptr<EagerNode> node(
absl::make_unique<eager::DestroyTensorHandleNode>(std::move(request),
eager_client));
eager_client, ready));
auto* executor = ctx->Executor();
if (executor->Async()) {
Status status = executor->Add(std::move(node));
@@ -87,7 +87,7 @@ RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
RemoteTensorHandleData::~RemoteTensorHandleData() {
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
output_num_);
output_num_, /*ready=*/true);
ctx_->Unref();
}
@@ -150,7 +150,7 @@ UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
if (delete_remote_tensor_) {
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
output_num_);
output_num_, /*ready=*/false);
}
ctx_->Unref();
}

View File

@@ -580,7 +580,7 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
auto handle = self->handle;
int n = TFE_TensorHandleNumDims(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
@@ -590,7 +590,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
for (int i = 0; i < n; ++i) {
PyObject* dim =
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
@@ -606,7 +606,7 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
@@ -622,7 +622,7 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
@@ -680,14 +680,14 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
return EagerTensorFromHandle(handle);
}
// Function `_numpy`.
// Function `_numpy_internal`.
// Convert an EagerTensor to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
// Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy(EagerTensor* self) {
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
Py_XDECREF(py_array);
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
@@ -754,8 +754,8 @@ static PyMemberDef EagerTensor_members[] = {
#endif
static PyMethodDef EagerTensor_methods[] = {
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
PyDoc_STR("_numpy")},
{"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
PyDoc_STR("_numpy_internal")},
{"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
PyDoc_STR("_datatype_enum")},
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,

View File

@@ -32,6 +32,7 @@ from tensorflow.python.eager import core
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -380,7 +381,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
def test_numpyFailsForResource(self):
v = variables.Variable(42)
with self.assertRaisesRegex(ValueError, "Cannot convert .+ resource"):
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Cannot convert .+ resource"):
v._handle._numpy()
def testMemoryviewFailsForResource(self):

View File

@@ -910,10 +910,21 @@ class _EagerTensorBase(Tensor):
"""Returns the length of the first dimension in the Tensor."""
if not self.shape.ndims:
raise TypeError("Scalar tensor has no `len()`")
return self._shape_tuple()[0]
# pylint: disable=protected-access
try:
return self._shape_tuple()[0]
except core._NotOkStatusException as e:
six.raise_from(core._status_to_exception(e.code, e.message), None)
def _numpy_internal(self):
raise NotImplementedError()
def _numpy(self):
raise NotImplementedError()
# pylint: disable=protected-access
try:
return self._numpy_internal()
except core._NotOkStatusException as e:
six.raise_from(core._status_to_exception(e.code, e.message), None)
@property
def dtype(self):
@@ -1036,9 +1047,14 @@ class _EagerTensorBase(Tensor):
@property
def shape(self):
if self._tensor_shape is None: # pylint: disable=access-member-before-definition
# `_tensor_shape` is declared and defined in the definition of
# `EagerTensor`, in C.
self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
# pylint: disable=protected-access
try:
# `_tensor_shape` is declared and defined in the definition of
# `EagerTensor`, in C.
self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
except core._NotOkStatusException as e:
six.raise_from(core._status_to_exception(e.code, e.message), None)
return self._tensor_shape
def get_shape(self):
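
All of the ops.py hunks above apply the same pattern: catch the low-level core._NotOkStatusException raised by the C layer and re-raise the typed exception produced by core._status_to_exception, passing None to six.raise_from so the internal exception is not chained onto the user-facing one. A standalone sketch of that pattern with stand-in classes (none of these names are TensorFlow APIs):

import six

class NotOkStatus(Exception):
  """Stand-in for the low-level status exception from the C layer."""

  def __init__(self, code, message):
    super(NotOkStatus, self).__init__(message)
    self.code = code
    self.message = message

class InvalidArgument(Exception):
  """Stand-in for the typed, user-facing exception."""

def status_to_exception(code, message):
  # The real helper maps the status code to the matching tf.errors class;
  # this stand-in always returns one type.
  return InvalidArgument("[code %d] %s" % (code, message))

def shape_tuple():
  raise NotOkStatus(3, "Cannot compute the shape of this tensor")

def public_shape():
  try:
    return shape_tuple()
  except NotOkStatus as e:
    # raise_from(..., None) suppresses chaining, so callers only see the
    # typed error, not the internal status exception.
    six.raise_from(status_to_exception(e.code, e.message), None)

public_shape()  # raises InvalidArgument with no chained NotOkStatus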

View File

@@ -656,12 +656,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertEqual(v.handle.op.colocation_groups(),
v.initializer.inputs[1].op.colocation_groups())
def testHandleNumpy(self):
with context.eager_mode():
with self.assertRaises(ValueError):
resource_variable_ops.ResourceVariable(
1.0, name="handle-numpy").handle.numpy()
def testCountUpTo(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable(0, name="upto")