Set py_func private executor in C++ instead of Python.

When py_func throws an exception, the resources will be held until the exception is decref'd in C++. If the resource is a remote TensorHandle, it will enqueue a destroy_tensor node into the executor. This change ensures that these destroy_tensor nodes are executed by the py_func private executor.

PiperOrigin-RevId: 263014289
This commit is contained in:
Xiao Yu 2019-08-12 15:08:00 -07:00 committed by TensorFlower Gardener
parent 7685639e2f
commit 970cf44c28
5 changed files with 60 additions and 37 deletions

View File

@ -39,6 +39,7 @@ REGISTER_OP("EagerPyFunc")
.Input("input: Tin") .Input("input: Tin")
.Output("output: Tout") .Output("output: Tout")
.Attr("token: string") .Attr("token: string")
.Attr("is_async: bool=false")
.Attr("Tin: list(type) >= 0") .Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >=0") .Attr("Tout: list(type) >=0")
.SetIsStateful() .SetIsStateful()

View File

@ -61,6 +61,9 @@ struct PyCall {
// True if the call is associated with an EagerPyFunc. // True if the call is associated with an EagerPyFunc.
bool eager = false; bool eager = false;
// True if the call is running under eager async mode.
bool eager_async = false;
// Inputs and outputs of this function invocation. // Inputs and outputs of this function invocation.
std::vector<Tensor> ins; std::vector<Tensor> ins;
std::vector<Tensor> out; std::vector<Tensor> out;
@ -173,12 +176,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
// Prepare the argument. // Prepare the argument.
PyObject* args = nullptr; PyObject* args = nullptr;
TFE_Context* ctx = nullptr;
std::unique_ptr<EagerExecutor> new_executor = nullptr;
EagerExecutor* old_executor = nullptr;
if (call->eager) { if (call->eager) {
// See FuncRegistry._ctx. // See FuncRegistry._ctx.
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer( ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
CHECK_NE(ctx, nullptr); CHECK_NE(ctx, nullptr);
TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args)); TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args));
new_executor.reset(new EagerExecutor(call->eager_async));
old_executor = ctx->context->Executor();
ctx->context->SetExecutorForThread(new_executor.get());
} else { } else {
TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args)); TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
} }
@ -187,31 +196,38 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
// Invokes the trampoline. // Invokes the trampoline.
PyObject* result = PyEval_CallObject(trampoline, args); PyObject* result = PyEval_CallObject(trampoline, args);
Py_DECREF(args); Py_DECREF(args);
Status s = Status::OK();
if (result == nullptr) { if (result == nullptr) {
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_ValueError) || if (PyErr_ExceptionMatches(PyExc_ValueError) ||
PyErr_ExceptionMatches(PyExc_TypeError)) { PyErr_ExceptionMatches(PyExc_TypeError)) {
return errors::InvalidArgument(PyExceptionFetch()); s = errors::InvalidArgument(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) { } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
*out_log_on_error = false; *out_log_on_error = false;
return errors::OutOfRange(PyExceptionFetch()); s = errors::OutOfRange(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) { } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
return errors::ResourceExhausted(PyExceptionFetch()); s = errors::ResourceExhausted(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) { } else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
return errors::Unimplemented(PyExceptionFetch()); s = errors::Unimplemented(PyExceptionFetch());
} else { } else {
// TODO(ebrevdo): Check if exception is an OpError and use the // TODO(ebrevdo): Check if exception is an OpError and use the
// OpError.error_code property to map it back in the Status. // OpError.error_code property to map it back in the Status.
return errors::Unknown(PyExceptionFetch()); s = errors::Unknown(PyExceptionFetch());
} }
} else { } else {
return errors::Internal("Failed to run py callback ", call->token, s = errors::Internal("Failed to run py callback ", call->token,
": see error log."); ": see error log.");
} }
} }
if (new_executor != nullptr) {
s.Update(new_executor->WaitForAllPendingNodes());
ctx->context->SetExecutorForThread(old_executor);
}
TF_RETURN_IF_ERROR(s);
// Process the return values and convert them to TF Tensors. // Process the return values and convert them to TF Tensors.
Status s = Status::OK();
if (PyList_Check(result)) { if (PyList_Check(result)) {
// `result` is a Python list; if this operation is an `EagerPyFunc`, then // `result` is a Python list; if this operation is an `EagerPyFunc`, then
// every item in the list must be an `EagerTensor`; otherwise, every element // every item in the list must be an `EagerTensor`; otherwise, every element
@ -282,6 +298,9 @@ class PyFuncOp : public OpKernel {
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc"; eager_ = type_string() == "EagerPyFunc";
if (eager_) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_async", &eager_async_));
}
} }
bool IsExpensive() override { return true; } bool IsExpensive() override { return true; }
@ -299,6 +318,7 @@ class PyFuncOp : public OpKernel {
"Unrecognized device class: ", ctx->device()->name())); "Unrecognized device class: ", ctx->device()->name()));
return; return;
} }
call.eager_async = eager_async_;
} }
for (int i = 0; i < ctx->num_inputs(); ++i) { for (int i = 0; i < ctx->num_inputs(); ++i) {
@ -357,6 +377,8 @@ class PyFuncOp : public OpKernel {
// i.e., if and only if the eager attribute is set. // i.e., if and only if the eager attribute is set.
bool eager_; bool eager_;
bool eager_async_;
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp); TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
}; };

View File

@ -31,7 +31,6 @@ import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import executor
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -102,30 +101,27 @@ class EagerFunc(object):
def __call__(self, device, token, args): def __call__(self, device, token, args):
"""Passes `args` to `self._func`, which is executed eagerly.""" """Passes `args` to `self._func`, which is executed eagerly."""
func_executor = executor.new_executor(context.is_async()) with context.eager_mode(), backprop.GradientTape() as tape:
with context.executor_scope(func_executor): # Only watch tensors with a floating dtype.
with context.eager_mode(), backprop.GradientTape() as tape: for tensor in args:
# Only watch tensors with a floating dtype. for t in nest.flatten(tensor):
for tensor in args: if t.dtype.is_floating:
for t in nest.flatten(tensor): tape.watch(t)
if t.dtype.is_floating: ret = self._func(*args)
tape.watch(t) # Use tf.identity to copy the returned tensors to device if necessary.
ret = self._func(*args) with ops.device(device):
# Use tf.identity to copy the returned tensors to device if necessary. if isinstance(ret, (tuple, list)):
with ops.device(device): outputs = [
if isinstance(ret, (tuple, list)): array_ops.identity(self._convert(x, dtype=dtype))
outputs = [ for (x, dtype) in zip(ret, self._out_dtypes)
array_ops.identity(self._convert(x, dtype=dtype)) ]
for (x, dtype) in zip(ret, self._out_dtypes) elif ret is None:
] outputs = None
elif ret is None: else:
outputs = None outputs = array_ops.identity(
else: self._convert(ret, dtype=self._out_dtypes[0]))
outputs = array_ops.identity( tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
self._convert(ret, dtype=self._out_dtypes[0])) return outputs
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
return outputs
func_executor.wait()
class FuncRegistry(object): class FuncRegistry(object):
@ -290,7 +286,11 @@ def _internal_py_func(func,
if eager: if eager:
result = gen_script_ops.eager_py_func( result = gen_script_ops.eager_py_func(
input=inp, token=token, Tout=Tout, name=name) input=inp,
token=token,
is_async=context.is_async(),
Tout=Tout,
name=name)
else: else:
if stateful: if stateful:
result = gen_script_ops.py_func( result = gen_script_ops.py_func(

View File

@ -1090,7 +1090,7 @@ tf_module {
} }
member_method { member_method {
name: "EagerPyFunc" name: "EagerPyFunc"
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
} }
member_method { member_method {
name: "EditDistance" name: "EditDistance"

View File

@ -1090,7 +1090,7 @@ tf_module {
} }
member_method { member_method {
name: "EagerPyFunc" name: "EagerPyFunc"
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
} }
member_method { member_method {
name: "EditDistance" name: "EditDistance"