Set the py_func private executor in C++ instead of Python.

When py_func throws an exception, its resources are held until the exception object is decref'd in C++. If such a resource is a remote TensorHandle, releasing it enqueues a destroy_tensor node into the executor. This change ensures that these destroy_tensor nodes are executed by the py_func private executor.

PiperOrigin-RevId: 263014289
parent 7685639e2f
commit 970cf44c28
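The C++ hunks below implement a swap/wait/restore pattern in DoCallPyFunc: create a private EagerExecutor, install it for the calling thread, run the Python trampoline, wait for all pending nodes (which is where the destroy_tensor nodes enqueued by a dropped exception get drained), then restore the previous executor. Below is a minimal, self-contained sketch of that pattern; Executor, Context, and RunTrampoline are illustrative stand-ins, not TensorFlow's real classes or signatures.

// Minimal sketch of the swap/wait/restore pattern used by DoCallPyFunc.
// `Executor`, `Context`, and `RunTrampoline` are illustrative stand-ins,
// not TensorFlow's actual types.
#include <functional>
#include <iostream>
#include <memory>
#include <queue>

// A toy executor that queues cleanup work (e.g. destroy_tensor nodes).
class Executor {
 public:
  explicit Executor(bool async) : async_(async) {}
  void Enqueue(std::function<void()> node) { pending_.push(std::move(node)); }
  // Drain everything that was enqueued while this executor was installed.
  void WaitForAllPendingNodes() {
    while (!pending_.empty()) {
      pending_.front()();
      pending_.pop();
    }
  }

 private:
  bool async_;
  std::queue<std::function<void()>> pending_;
};

// A toy context holding the current thread's executor.
class Context {
 public:
  Executor* executor() const { return executor_; }
  void SetExecutorForThread(Executor* e) { executor_ = e; }

 private:
  Executor* executor_ = nullptr;
};

// Stand-in for invoking the Python trampoline; dropping an exception that
// holds remote tensor handles would enqueue destroy_tensor work like this.
void RunTrampoline(Context* ctx) {
  ctx->executor()->Enqueue([] { std::cout << "destroy_tensor executed\n"; });
}

int main() {
  Context ctx;
  Executor default_executor(/*async=*/false);
  ctx.SetExecutorForThread(&default_executor);

  // The pattern from the diff: install a private executor, call into
  // Python, wait for its pending nodes, then restore the old executor.
  auto private_executor = std::make_unique<Executor>(/*async=*/true);
  Executor* old_executor = ctx.executor();
  ctx.SetExecutorForThread(private_executor.get());

  RunTrampoline(&ctx);

  private_executor->WaitForAllPendingNodes();
  ctx.SetExecutorForThread(old_executor);
  return 0;
}

The actual change wires this same sequence through SetExecutorForThread and WaitForAllPendingNodes, as shown in the DoCallPyFunc hunks below.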
@@ -39,6 +39,7 @@ REGISTER_OP("EagerPyFunc")
     .Input("input: Tin")
     .Output("output: Tout")
     .Attr("token: string")
+    .Attr("is_async: bool=false")
     .Attr("Tin: list(type) >= 0")
     .Attr("Tout: list(type) >=0")
     .SetIsStateful()
@@ -61,6 +61,9 @@ struct PyCall {
   // True if the call is associated with an EagerPyFunc.
   bool eager = false;
 
+  // True if the call is running under eager async mode.
+  bool eager_async = false;
+
   // Inputs and outputs of this function invocation.
   std::vector<Tensor> ins;
   std::vector<Tensor> out;
@@ -173,12 +176,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
 
   // Prepare the argument.
   PyObject* args = nullptr;
+  TFE_Context* ctx = nullptr;
+  std::unique_ptr<EagerExecutor> new_executor = nullptr;
+  EagerExecutor* old_executor = nullptr;
   if (call->eager) {
     // See FuncRegistry._ctx.
-    TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
+    ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
         PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
     CHECK_NE(ctx, nullptr);
     TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args));
+    new_executor.reset(new EagerExecutor(call->eager_async));
+    old_executor = ctx->context->Executor();
+    ctx->context->SetExecutorForThread(new_executor.get());
   } else {
     TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
   }
@@ -187,31 +196,38 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
   // Invokes the trampoline.
   PyObject* result = PyEval_CallObject(trampoline, args);
   Py_DECREF(args);
+  Status s = Status::OK();
   if (result == nullptr) {
     if (PyErr_Occurred()) {
       if (PyErr_ExceptionMatches(PyExc_ValueError) ||
           PyErr_ExceptionMatches(PyExc_TypeError)) {
-        return errors::InvalidArgument(PyExceptionFetch());
+        s = errors::InvalidArgument(PyExceptionFetch());
       } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
         *out_log_on_error = false;
-        return errors::OutOfRange(PyExceptionFetch());
+        s = errors::OutOfRange(PyExceptionFetch());
       } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
-        return errors::ResourceExhausted(PyExceptionFetch());
+        s = errors::ResourceExhausted(PyExceptionFetch());
      } else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
-        return errors::Unimplemented(PyExceptionFetch());
+        s = errors::Unimplemented(PyExceptionFetch());
       } else {
         // TODO(ebrevdo): Check if exception is an OpError and use the
         // OpError.error_code property to map it back in the Status.
-        return errors::Unknown(PyExceptionFetch());
+        s = errors::Unknown(PyExceptionFetch());
       }
     } else {
-      return errors::Internal("Failed to run py callback ", call->token,
-                              ": see error log.");
+      s = errors::Internal("Failed to run py callback ", call->token,
+                           ": see error log.");
     }
   }
 
+  if (new_executor != nullptr) {
+    s.Update(new_executor->WaitForAllPendingNodes());
+    ctx->context->SetExecutorForThread(old_executor);
+  }
+
+  TF_RETURN_IF_ERROR(s);
+
   // Process the return values and convert them to TF Tensors.
-  Status s = Status::OK();
   if (PyList_Check(result)) {
     // `result` is a Python list; if this operation is an `EagerPyFunc`, then
     // every item in the list must be an `EagerTensor`; otherwise, every element
@@ -282,6 +298,9 @@ class PyFuncOp : public OpKernel {
   explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
     eager_ = type_string() == "EagerPyFunc";
+    if (eager_) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("is_async", &eager_async_));
+    }
   }
 
   bool IsExpensive() override { return true; }
@@ -299,6 +318,7 @@ class PyFuncOp : public OpKernel {
             "Unrecognized device class: ", ctx->device()->name()));
         return;
       }
+      call.eager_async = eager_async_;
     }
 
     for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -357,6 +377,8 @@ class PyFuncOp : public OpKernel {
   // i.e., if and only if the eager attribute is set.
   bool eager_;
 
+  bool eager_async_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
 };
 
@@ -31,7 +31,6 @@ import six
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import executor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
@@ -102,30 +101,27 @@ class EagerFunc(object):
   def __call__(self, device, token, args):
     """Passes `args` to `self._func`, which is executed eagerly."""
 
-    func_executor = executor.new_executor(context.is_async())
-    with context.executor_scope(func_executor):
-      with context.eager_mode(), backprop.GradientTape() as tape:
-        # Only watch tensors with a floating dtype.
-        for tensor in args:
-          for t in nest.flatten(tensor):
-            if t.dtype.is_floating:
-              tape.watch(t)
-        ret = self._func(*args)
-        # Use tf.identity to copy the returned tensors to device if necessary.
-        with ops.device(device):
-          if isinstance(ret, (tuple, list)):
-            outputs = [
-                array_ops.identity(self._convert(x, dtype=dtype))
-                for (x, dtype) in zip(ret, self._out_dtypes)
-            ]
-          elif ret is None:
-            outputs = None
-          else:
-            outputs = array_ops.identity(
-                self._convert(ret, dtype=self._out_dtypes[0]))
-        tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
-        return outputs
-    func_executor.wait()
+    with context.eager_mode(), backprop.GradientTape() as tape:
+      # Only watch tensors with a floating dtype.
+      for tensor in args:
+        for t in nest.flatten(tensor):
+          if t.dtype.is_floating:
+            tape.watch(t)
+      ret = self._func(*args)
+      # Use tf.identity to copy the returned tensors to device if necessary.
+      with ops.device(device):
+        if isinstance(ret, (tuple, list)):
+          outputs = [
+              array_ops.identity(self._convert(x, dtype=dtype))
+              for (x, dtype) in zip(ret, self._out_dtypes)
+          ]
+        elif ret is None:
+          outputs = None
+        else:
+          outputs = array_ops.identity(
+              self._convert(ret, dtype=self._out_dtypes[0]))
+      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
+      return outputs
 
 
 class FuncRegistry(object):
@@ -290,7 +286,11 @@ def _internal_py_func(func,
 
   if eager:
     result = gen_script_ops.eager_py_func(
-        input=inp, token=token, Tout=Tout, name=name)
+        input=inp,
+        token=token,
+        is_async=context.is_async(),
+        Tout=Tout,
+        name=name)
   else:
     if stateful:
       result = gen_script_ops.py_func(
@@ -1090,7 +1090,7 @@ tf_module {
   }
   member_method {
     name: "EagerPyFunc"
-    argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
     name: "EditDistance"
@@ -1090,7 +1090,7 @@ tf_module {
   }
   member_method {
     name: "EagerPyFunc"
-    argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
     name: "EditDistance"