Set the py_func private executor in C++ instead of Python.

When py_func throws an exception, its resources are held until the exception is decref'd in C++. If one of those resources is a remote TensorHandle, releasing it enqueues a destroy_tensor node into the executor. This change ensures that these destroy_tensor nodes are executed by py_func's private executor.

PiperOrigin-RevId: 263014289
This commit is contained in:
Xiao Yu 2019-08-12 15:08:00 -07:00 committed by TensorFlower Gardener
parent 7685639e2f
commit 970cf44c28
5 changed files with 60 additions and 37 deletions

View File

@ -39,6 +39,7 @@ REGISTER_OP("EagerPyFunc")
.Input("input: Tin")
.Output("output: Tout")
.Attr("token: string")
.Attr("is_async: bool=false")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >=0")
.SetIsStateful()

View File

@ -61,6 +61,9 @@ struct PyCall {
// True if the call is associated with an EagerPyFunc.
bool eager = false;
// True if the call is running under eager async mode.
bool eager_async = false;
// Inputs and outputs of this function invocation.
std::vector<Tensor> ins;
std::vector<Tensor> out;
@ -173,12 +176,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
// Prepare the argument.
PyObject* args = nullptr;
TFE_Context* ctx = nullptr;
std::unique_ptr<EagerExecutor> new_executor = nullptr;
EagerExecutor* old_executor = nullptr;
if (call->eager) {
// See FuncRegistry._ctx.
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
CHECK_NE(ctx, nullptr);
TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args));
new_executor.reset(new EagerExecutor(call->eager_async));
old_executor = ctx->context->Executor();
ctx->context->SetExecutorForThread(new_executor.get());
} else {
TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
}
@ -187,31 +196,38 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
// Invokes the trampoline.
PyObject* result = PyEval_CallObject(trampoline, args);
Py_DECREF(args);
Status s = Status::OK();
if (result == nullptr) {
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
PyErr_ExceptionMatches(PyExc_TypeError)) {
return errors::InvalidArgument(PyExceptionFetch());
s = errors::InvalidArgument(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
*out_log_on_error = false;
return errors::OutOfRange(PyExceptionFetch());
s = errors::OutOfRange(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
return errors::ResourceExhausted(PyExceptionFetch());
s = errors::ResourceExhausted(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
return errors::Unimplemented(PyExceptionFetch());
s = errors::Unimplemented(PyExceptionFetch());
} else {
// TODO(ebrevdo): Check if exception is an OpError and use the
// OpError.error_code property to map it back in the Status.
return errors::Unknown(PyExceptionFetch());
s = errors::Unknown(PyExceptionFetch());
}
} else {
return errors::Internal("Failed to run py callback ", call->token,
": see error log.");
s = errors::Internal("Failed to run py callback ", call->token,
": see error log.");
}
}
if (new_executor != nullptr) {
s.Update(new_executor->WaitForAllPendingNodes());
ctx->context->SetExecutorForThread(old_executor);
}
TF_RETURN_IF_ERROR(s);
// Process the return values and convert them to TF Tensors.
Status s = Status::OK();
if (PyList_Check(result)) {
// `result` is a Python list; if this operation is an `EagerPyFunc`, then
// every item in the list must be an `EagerTensor`; otherwise, every element
@ -282,6 +298,9 @@ class PyFuncOp : public OpKernel {
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc";
if (eager_) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_async", &eager_async_));
}
}
bool IsExpensive() override { return true; }
@ -299,6 +318,7 @@ class PyFuncOp : public OpKernel {
"Unrecognized device class: ", ctx->device()->name()));
return;
}
call.eager_async = eager_async_;
}
for (int i = 0; i < ctx->num_inputs(); ++i) {
@ -357,6 +377,8 @@ class PyFuncOp : public OpKernel {
// i.e., if and only if the eager attribute is set.
bool eager_;
bool eager_async_;
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
};

View File

@ -31,7 +31,6 @@ import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import executor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
@ -102,30 +101,27 @@ class EagerFunc(object):
def __call__(self, device, token, args):
"""Passes `args` to `self._func`, which is executed eagerly."""
func_executor = executor.new_executor(context.is_async())
with context.executor_scope(func_executor):
with context.eager_mode(), backprop.GradientTape() as tape:
# Only watch tensors with a floating dtype.
for tensor in args:
for t in nest.flatten(tensor):
if t.dtype.is_floating:
tape.watch(t)
ret = self._func(*args)
# Use tf.identity to copy the returned tensors to device if necessary.
with ops.device(device):
if isinstance(ret, (tuple, list)):
outputs = [
array_ops.identity(self._convert(x, dtype=dtype))
for (x, dtype) in zip(ret, self._out_dtypes)
]
elif ret is None:
outputs = None
else:
outputs = array_ops.identity(
self._convert(ret, dtype=self._out_dtypes[0]))
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
return outputs
func_executor.wait()
with context.eager_mode(), backprop.GradientTape() as tape:
# Only watch tensors with a floating dtype.
for tensor in args:
for t in nest.flatten(tensor):
if t.dtype.is_floating:
tape.watch(t)
ret = self._func(*args)
# Use tf.identity to copy the returned tensors to device if necessary.
with ops.device(device):
if isinstance(ret, (tuple, list)):
outputs = [
array_ops.identity(self._convert(x, dtype=dtype))
for (x, dtype) in zip(ret, self._out_dtypes)
]
elif ret is None:
outputs = None
else:
outputs = array_ops.identity(
self._convert(ret, dtype=self._out_dtypes[0]))
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
return outputs
class FuncRegistry(object):
@ -290,7 +286,11 @@ def _internal_py_func(func,
if eager:
result = gen_script_ops.eager_py_func(
input=inp, token=token, Tout=Tout, name=name)
input=inp,
token=token,
is_async=context.is_async(),
Tout=Tout,
name=name)
else:
if stateful:
result = gen_script_ops.py_func(

View File

@ -1090,7 +1090,7 @@ tf_module {
}
member_method {
name: "EagerPyFunc"
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "EditDistance"

View File

@ -1090,7 +1090,7 @@ tf_module {
}
member_method {
name: "EagerPyFunc"
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "EditDistance"