Set py_func private executor in C++ instead of Python.
When py_func throws an exception, the resources will be held until the exception is decrefed in C++. If a resource is a remote TensorHandle, its release will enqueue a destroy_tensor node into the executor. This change ensures that these destroy_tensor nodes are executed by the py_func private executor. PiperOrigin-RevId: 263014289
This commit is contained in:
parent
7685639e2f
commit
970cf44c28
@ -39,6 +39,7 @@ REGISTER_OP("EagerPyFunc")
|
||||
.Input("input: Tin")
|
||||
.Output("output: Tout")
|
||||
.Attr("token: string")
|
||||
.Attr("is_async: bool=false")
|
||||
.Attr("Tin: list(type) >= 0")
|
||||
.Attr("Tout: list(type) >=0")
|
||||
.SetIsStateful()
|
||||
|
@ -61,6 +61,9 @@ struct PyCall {
|
||||
// True if the call is associated with an EagerPyFunc.
|
||||
bool eager = false;
|
||||
|
||||
// True if the call is running under eager async mode.
|
||||
bool eager_async = false;
|
||||
|
||||
// Inputs and outputs of this function invocation.
|
||||
std::vector<Tensor> ins;
|
||||
std::vector<Tensor> out;
|
||||
@ -173,12 +176,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
|
||||
// Prepare the argument.
|
||||
PyObject* args = nullptr;
|
||||
TFE_Context* ctx = nullptr;
|
||||
std::unique_ptr<EagerExecutor> new_executor = nullptr;
|
||||
EagerExecutor* old_executor = nullptr;
|
||||
if (call->eager) {
|
||||
// See FuncRegistry._ctx.
|
||||
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
|
||||
ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
|
||||
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
|
||||
CHECK_NE(ctx, nullptr);
|
||||
TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args));
|
||||
new_executor.reset(new EagerExecutor(call->eager_async));
|
||||
old_executor = ctx->context->Executor();
|
||||
ctx->context->SetExecutorForThread(new_executor.get());
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
|
||||
}
|
||||
@ -187,31 +196,38 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
// Invokes the trampoline.
|
||||
PyObject* result = PyEval_CallObject(trampoline, args);
|
||||
Py_DECREF(args);
|
||||
Status s = Status::OK();
|
||||
if (result == nullptr) {
|
||||
if (PyErr_Occurred()) {
|
||||
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
|
||||
PyErr_ExceptionMatches(PyExc_TypeError)) {
|
||||
return errors::InvalidArgument(PyExceptionFetch());
|
||||
s = errors::InvalidArgument(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
|
||||
*out_log_on_error = false;
|
||||
return errors::OutOfRange(PyExceptionFetch());
|
||||
s = errors::OutOfRange(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
|
||||
return errors::ResourceExhausted(PyExceptionFetch());
|
||||
s = errors::ResourceExhausted(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
|
||||
return errors::Unimplemented(PyExceptionFetch());
|
||||
s = errors::Unimplemented(PyExceptionFetch());
|
||||
} else {
|
||||
// TODO(ebrevdo): Check if exception is an OpError and use the
|
||||
// OpError.error_code property to map it back in the Status.
|
||||
return errors::Unknown(PyExceptionFetch());
|
||||
s = errors::Unknown(PyExceptionFetch());
|
||||
}
|
||||
} else {
|
||||
return errors::Internal("Failed to run py callback ", call->token,
|
||||
": see error log.");
|
||||
s = errors::Internal("Failed to run py callback ", call->token,
|
||||
": see error log.");
|
||||
}
|
||||
}
|
||||
|
||||
if (new_executor != nullptr) {
|
||||
s.Update(new_executor->WaitForAllPendingNodes());
|
||||
ctx->context->SetExecutorForThread(old_executor);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
// Process the return values and convert them to TF Tensors.
|
||||
Status s = Status::OK();
|
||||
if (PyList_Check(result)) {
|
||||
// `result` is a Python list; if this operation is an `EagerPyFunc`, then
|
||||
// every item in the list must be an `EagerTensor`; otherwise, every element
|
||||
@ -282,6 +298,9 @@ class PyFuncOp : public OpKernel {
|
||||
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
|
||||
eager_ = type_string() == "EagerPyFunc";
|
||||
if (eager_) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_async", &eager_async_));
|
||||
}
|
||||
}
|
||||
|
||||
bool IsExpensive() override { return true; }
|
||||
@ -299,6 +318,7 @@ class PyFuncOp : public OpKernel {
|
||||
"Unrecognized device class: ", ctx->device()->name()));
|
||||
return;
|
||||
}
|
||||
call.eager_async = eager_async_;
|
||||
}
|
||||
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
@ -357,6 +377,8 @@ class PyFuncOp : public OpKernel {
|
||||
// i.e., if and only if the eager attribute is set.
|
||||
bool eager_;
|
||||
|
||||
bool eager_async_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
|
||||
};
|
||||
|
||||
|
@ -31,7 +31,6 @@ import six
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import executor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
@ -102,30 +101,27 @@ class EagerFunc(object):
|
||||
def __call__(self, device, token, args):
|
||||
"""Passes `args` to `self._func`, which is executed eagerly."""
|
||||
|
||||
func_executor = executor.new_executor(context.is_async())
|
||||
with context.executor_scope(func_executor):
|
||||
with context.eager_mode(), backprop.GradientTape() as tape:
|
||||
# Only watch tensors with a floating dtype.
|
||||
for tensor in args:
|
||||
for t in nest.flatten(tensor):
|
||||
if t.dtype.is_floating:
|
||||
tape.watch(t)
|
||||
ret = self._func(*args)
|
||||
# Use tf.identity to copy the returned tensors to device if necessary.
|
||||
with ops.device(device):
|
||||
if isinstance(ret, (tuple, list)):
|
||||
outputs = [
|
||||
array_ops.identity(self._convert(x, dtype=dtype))
|
||||
for (x, dtype) in zip(ret, self._out_dtypes)
|
||||
]
|
||||
elif ret is None:
|
||||
outputs = None
|
||||
else:
|
||||
outputs = array_ops.identity(
|
||||
self._convert(ret, dtype=self._out_dtypes[0]))
|
||||
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
|
||||
return outputs
|
||||
func_executor.wait()
|
||||
with context.eager_mode(), backprop.GradientTape() as tape:
|
||||
# Only watch tensors with a floating dtype.
|
||||
for tensor in args:
|
||||
for t in nest.flatten(tensor):
|
||||
if t.dtype.is_floating:
|
||||
tape.watch(t)
|
||||
ret = self._func(*args)
|
||||
# Use tf.identity to copy the returned tensors to device if necessary.
|
||||
with ops.device(device):
|
||||
if isinstance(ret, (tuple, list)):
|
||||
outputs = [
|
||||
array_ops.identity(self._convert(x, dtype=dtype))
|
||||
for (x, dtype) in zip(ret, self._out_dtypes)
|
||||
]
|
||||
elif ret is None:
|
||||
outputs = None
|
||||
else:
|
||||
outputs = array_ops.identity(
|
||||
self._convert(ret, dtype=self._out_dtypes[0]))
|
||||
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class FuncRegistry(object):
|
||||
@ -290,7 +286,11 @@ def _internal_py_func(func,
|
||||
|
||||
if eager:
|
||||
result = gen_script_ops.eager_py_func(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
input=inp,
|
||||
token=token,
|
||||
is_async=context.is_async(),
|
||||
Tout=Tout,
|
||||
name=name)
|
||||
else:
|
||||
if stateful:
|
||||
result = gen_script_ops.py_func(
|
||||
|
@ -1090,7 +1090,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EagerPyFunc"
|
||||
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EditDistance"
|
||||
|
@ -1090,7 +1090,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EagerPyFunc"
|
||||
argspec: "args=[\'input\', \'token\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'token\', \'Tout\', \'is_async\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EditDistance"
|
||||
|
Loading…
x
Reference in New Issue
Block a user