Server-side cancellation support for distributed function execution.
1. Thread the RPC cancel signal through the eager service RunComponentFunction calls; 2. Always pass the cancellation manager to the underlying executor (instead of only passing when `is_eager` is true, i.e., pure eager ops). With this we do not need to cancel the rendezvous from the process FLR; instead the ExecutorState takes care of it when op fails. 3. Do not mark all statuses as derived when aborting rendezvous or triggering cancellation. This usually results in the original errors buried as one of the derived errors. PiperOrigin-RevId: 313814162 Change-Id: Ia866f5f522a0b1aa54e9dce7b9cc0bcf7682136a
This commit is contained in:
parent
58f1e31019
commit
356121e563
tensorflow/core
@ -1459,7 +1459,7 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
|
||||
EagerKernelExecuteAsync(
|
||||
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
|
||||
graph_collector, op->GetCancellationManager(), retvals, num_outputs,
|
||||
[op, num_outputs, &retvals, done = std::move(done)](const Status& s) {
|
||||
[op, num_outputs, retvals, done = std::move(done)](const Status& s) {
|
||||
op->Clear();
|
||||
// Since the operation failed, we need to Unref any outputs if they were
|
||||
// allocated.
|
||||
|
@ -236,7 +236,6 @@ Status KernelAndDeviceOp::Run(
|
||||
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
|
||||
OpKernelContext::Params params;
|
||||
params.is_eager = true;
|
||||
params.device = device_;
|
||||
params.frame_iter = FrameAndIter(0, 0);
|
||||
params.inputs = inputs.GetTensorValues();
|
||||
|
@ -56,6 +56,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/context.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
@ -66,6 +67,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
|
||||
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -1054,10 +1056,12 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
|
||||
// aborting all other execution in the step.
|
||||
abort_run = true;
|
||||
|
||||
// If execution has been cancelled, mark any new errors as being
|
||||
// derived. This ensures any errors triggered by cancellation are marked
|
||||
// as derived.
|
||||
if (cancellation_manager_ && cancellation_manager_->IsCancelled()) {
|
||||
// If execution has been cancelled, mark cancelled or aborted errors as
|
||||
// being derived. Note that the original node that fails might also
|
||||
// trigger cancellation, and here we make sure the original error is
|
||||
// exposed to users and not buried as a derived error.
|
||||
if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
|
||||
(errors::IsCancelled(s) || errors::IsAborted(s))) {
|
||||
status_ = StatusGroup::MakeDerived(s);
|
||||
} else {
|
||||
status_ = s;
|
||||
|
@ -1055,29 +1055,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
||||
local_cm = std::make_shared<CancellationManager>();
|
||||
cm = local_cm.get();
|
||||
}
|
||||
auto token = cm->get_cancellation_token();
|
||||
const auto cancelled_error = errors::Cancelled(
|
||||
"ProcessFunctionLibraryRuntime::RunMultiDevice was cancelled.");
|
||||
const bool already_cancelled = !cm->RegisterCallback(
|
||||
token,
|
||||
[rendez = opts.rendezvous, n_func = data->glue_.size(), cancelled_error] {
|
||||
// Abort rendezvous only if there are more than one component functions
|
||||
// to avoid reporting cancellation error directly to PartitionedCallOps
|
||||
// that launch a single component function.
|
||||
if (rendez && n_func > 1) {
|
||||
rendez->StartAbort(cancelled_error);
|
||||
}
|
||||
});
|
||||
if (already_cancelled) {
|
||||
done(cancelled_error);
|
||||
return;
|
||||
}
|
||||
|
||||
auto* refcounted_done = new ReffedStatusCallback(
|
||||
[cm, token, local_cm, done = std::move(done)](const Status& s) {
|
||||
cm->TryDeregisterCallback(token);
|
||||
done(s);
|
||||
});
|
||||
auto* refcounted_done = new ReffedStatusCallback(std::move(done));
|
||||
for (int i = 0; i < data->glue_.size(); ++i) {
|
||||
refcounted_done->Ref();
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -373,11 +374,14 @@ void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
|
||||
|
||||
void BaseRemoteRendezvous::StartAbort(const Status& s) {
|
||||
CHECK(!s.ok());
|
||||
// Use a "derived" status as the status for the rendezvous. Derived
|
||||
// status messages are ignored when aggregating errors across devices: this
|
||||
// allows us to prefer our original status message over any cancellation
|
||||
// related errors.
|
||||
Status derived_status = StatusGroup::MakeDerived(s);
|
||||
// If the status passed in is a cancelled or aborted error, mark it as
|
||||
// "derived" for the rendezvous. Derived status messages are ignored when
|
||||
// aggregating errors across devices: this allows us to prefer our original
|
||||
// status message over any cancellation related errors.
|
||||
Status derived_status = s;
|
||||
if (errors::IsCancelled(s) || errors::IsAborted(s)) {
|
||||
derived_status = StatusGroup::MakeDerived(s);
|
||||
}
|
||||
|
||||
local_->StartAbort(derived_status);
|
||||
{
|
||||
|
@ -411,7 +411,7 @@ Status EagerServiceImpl::CreateMasterContext(
|
||||
}
|
||||
|
||||
void EagerServiceImpl::RunComponentFunction(
|
||||
const RunComponentFunctionRequest* request,
|
||||
CallOptions* call_opts, const RunComponentFunctionRequest* request,
|
||||
RunComponentFunctionResponse* response, StatusCallback done) {
|
||||
ServerContext* context = nullptr;
|
||||
Status s = GetServerContext(request->context_id(), &context);
|
||||
@ -451,11 +451,17 @@ void EagerServiceImpl::RunComponentFunction(
|
||||
VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
|
||||
<< operation.id();
|
||||
|
||||
auto cm = std::make_shared<CancellationManager>();
|
||||
op->SetCancellationManager(cm.get());
|
||||
call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
|
||||
|
||||
context->Ref();
|
||||
EagerLocalExecuteAsync(
|
||||
op, retvals->data(), num_retvals,
|
||||
[op, op_id = operation.id(), num_retvals, retvals, response,
|
||||
eager_context, context, done = std::move(done)](const Status& status) {
|
||||
[op, op_id = operation.id(), num_retvals, retvals, cm, call_opts,
|
||||
response, eager_context, context,
|
||||
done = std::move(done)](const Status& status) {
|
||||
call_opts->ClearCancelCallback();
|
||||
auto wrapped_done = [&](const Status& status) {
|
||||
context->Unref();
|
||||
done(status);
|
||||
|
@ -96,7 +96,8 @@ class EagerServiceImpl {
|
||||
Status WaitQueueDone(const WaitQueueDoneRequest* request,
|
||||
WaitQueueDoneResponse* response);
|
||||
|
||||
void RunComponentFunction(const RunComponentFunctionRequest* request,
|
||||
void RunComponentFunction(CallOptions* call_opts,
|
||||
const RunComponentFunctionRequest* request,
|
||||
RunComponentFunctionResponse* response,
|
||||
StatusCallback done);
|
||||
|
||||
|
@ -15,10 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -39,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/eager_service.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
@ -94,7 +91,7 @@ class FakeEagerClient : public EagerClient {
|
||||
const RunComponentFunctionRequest* request,
|
||||
RunComponentFunctionResponse* response,
|
||||
StatusCallback done) override {
|
||||
impl_->RunComponentFunction(request, response, std::move(done));
|
||||
impl_->RunComponentFunction(call_opts, request, response, std::move(done));
|
||||
}
|
||||
|
||||
void StreamingEnqueueAsync(const EnqueueRequest* request,
|
||||
@ -177,14 +174,11 @@ void SetTensorProto(TensorProto* tensor_proto) {
|
||||
TF_DeleteTensor(t);
|
||||
}
|
||||
|
||||
void AddOperationToEnqueueRequest(
|
||||
int64 id, const string& name,
|
||||
void BuildOperation(
|
||||
Operation* operation, int64 id, const string& name,
|
||||
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
|
||||
inputs,
|
||||
const std::unordered_map<string, AttrValue>& attrs, const string& device,
|
||||
EnqueueRequest* request) {
|
||||
auto* operation = request->add_queue()->mutable_operation();
|
||||
|
||||
const std::unordered_map<string, AttrValue>& attrs, const string& device) {
|
||||
operation->set_id(id);
|
||||
operation->set_name(name);
|
||||
operation->set_device(device);
|
||||
@ -209,6 +203,28 @@ void AddOperationToEnqueueRequest(
|
||||
}
|
||||
}
|
||||
|
||||
void AddOperationToEnqueueRequest(
|
||||
int64 id, const string& name,
|
||||
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
|
||||
inputs,
|
||||
const std::unordered_map<string, AttrValue>& attrs, const string& device,
|
||||
EnqueueRequest* request) {
|
||||
auto* operation = request->add_queue()->mutable_operation();
|
||||
BuildOperation(operation, id, name, inputs, attrs, device);
|
||||
}
|
||||
|
||||
void AddOperationToRunComponentFunctionRequest(
|
||||
int64 id, const string& name,
|
||||
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
|
||||
inputs,
|
||||
const std::unordered_map<string, AttrValue>& attrs, const string& device,
|
||||
RunComponentFunctionRequest* request) {
|
||||
auto* operation = request->mutable_operation();
|
||||
operation->set_is_function(true);
|
||||
operation->set_is_component_function(true);
|
||||
BuildOperation(operation, id, name, inputs, attrs, device);
|
||||
}
|
||||
|
||||
tensorflow::NodeDef MatMulFunctionNodeDef() {
|
||||
tensorflow::NodeDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
@ -299,6 +315,69 @@ tensorflow::FunctionDef MatMulNestedFunction() {
|
||||
return def;
|
||||
}
|
||||
|
||||
tensorflow::FunctionDef SingleRecvNodeFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'SingleRecvNodeFunction'"
|
||||
" input_arg {"
|
||||
" name: 'a'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'recv_tensor'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'recv_node'"
|
||||
" op: '_Recv'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'client_terminated'"
|
||||
" value {"
|
||||
" b: true"
|
||||
" }"
|
||||
" }"
|
||||
" attr {"
|
||||
" key: 'recv_device'"
|
||||
" value {"
|
||||
" s: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" }"
|
||||
" }"
|
||||
" attr {"
|
||||
" key: 'send_device'"
|
||||
" value {"
|
||||
" s: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" }"
|
||||
" }"
|
||||
" attr {"
|
||||
" key: 'send_device_incarnation'"
|
||||
" value {"
|
||||
" i: 1"
|
||||
" }"
|
||||
" }"
|
||||
" attr {"
|
||||
" key: 'tensor_name'"
|
||||
" value {"
|
||||
" s: 't0'"
|
||||
" }"
|
||||
" }"
|
||||
" attr {"
|
||||
" key: 'tensor_type'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'recv_tensor'"
|
||||
" value: 'recv_node:tensor:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def;
|
||||
}
|
||||
|
||||
// Test creates a context and attempts to execute some ops.
|
||||
TEST_F(EagerServiceImplTest, BasicTest) {
|
||||
TestEagerServiceImpl eager_service_impl(&worker_env_);
|
||||
@ -462,6 +541,97 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
|
||||
TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
|
||||
&close_context_response));
|
||||
}
|
||||
|
||||
// Creates a context and attempts to execute a component function.
|
||||
void TestComponentFunction(const RegisterFunctionOp& register_op,
|
||||
const string& function_name,
|
||||
const bool test_cancel) {
|
||||
TestEagerServiceImpl eager_service_impl(&worker_env_);
|
||||
uint64 context_id = random::New64();
|
||||
|
||||
// Create context.
|
||||
CreateContextRequest request;
|
||||
request.mutable_server_def()->set_job_name("localhost");
|
||||
request.mutable_server_def()->set_task_index(0);
|
||||
request.set_context_id(context_id);
|
||||
CreateContextResponse response;
|
||||
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
|
||||
|
||||
// Register function.
|
||||
EnqueueRequest enqueue_request;
|
||||
enqueue_request.set_context_id(context_id);
|
||||
*enqueue_request.add_queue()->mutable_register_function() = register_op;
|
||||
EnqueueResponse enqueue_response;
|
||||
TF_ASSERT_OK(
|
||||
eager_service_impl.Enqueue(&enqueue_request, &enqueue_response));
|
||||
|
||||
// First run an op to generate input for function.
|
||||
EnqueueRequest remote_enqueue_request;
|
||||
remote_enqueue_request.set_context_id(context_id);
|
||||
EnqueueResponse remote_enqueue_response;
|
||||
|
||||
std::unordered_map<string, AttrValue> const_attrs;
|
||||
AttrValue val;
|
||||
val.set_type(tensorflow::DataType::DT_FLOAT);
|
||||
const_attrs.insert({"dtype", val});
|
||||
val.Clear();
|
||||
SetTensorProto(val.mutable_tensor());
|
||||
const_attrs.insert({"value", val});
|
||||
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
|
||||
&remote_enqueue_response));
|
||||
|
||||
// Run function with input from the previous op.
|
||||
RunComponentFunctionRequest run_comp_func_request;
|
||||
run_comp_func_request.set_context_id(context_id);
|
||||
RunComponentFunctionResponse run_comp_func_response;
|
||||
AddOperationToRunComponentFunctionRequest(
|
||||
2, function_name, {std::make_pair(1, 0)},
|
||||
std::unordered_map<string, AttrValue>(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", &run_comp_func_request);
|
||||
|
||||
CallOptions call_opts;
|
||||
Notification n;
|
||||
Status status;
|
||||
eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
|
||||
&run_comp_func_response,
|
||||
[&status, &n](const Status& s) {
|
||||
status.Update(s);
|
||||
n.Notify();
|
||||
});
|
||||
if (test_cancel) {
|
||||
call_opts.StartCancel();
|
||||
}
|
||||
n.WaitForNotification();
|
||||
if (test_cancel) {
|
||||
EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
|
||||
} else {
|
||||
TF_ASSERT_OK(status);
|
||||
// Retrieve the output.
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* tensor_handle;
|
||||
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
|
||||
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
|
||||
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
||||
|
||||
auto actual = t->flat<float>();
|
||||
EXPECT_EQ(4, actual.size());
|
||||
|
||||
EXPECT_EQ(7, actual(0));
|
||||
EXPECT_EQ(10, actual(1));
|
||||
EXPECT_EQ(15, actual(2));
|
||||
EXPECT_EQ(22, actual(3));
|
||||
}
|
||||
|
||||
CloseContextRequest close_context_request;
|
||||
close_context_request.set_context_id(context_id);
|
||||
close_context_request.set_context_view_id(0);
|
||||
CloseContextResponse close_context_response;
|
||||
TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
|
||||
&close_context_response));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
|
||||
@ -483,6 +653,18 @@ TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
|
||||
TestFunction(register_op, "MatMulNestedFunction");
|
||||
}
|
||||
|
||||
TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
|
||||
RegisterFunctionOp register_op;
|
||||
*register_op.mutable_function_def() = MatMulFunction();
|
||||
TestComponentFunction(register_op, "MatMulFunction", false);
|
||||
}
|
||||
|
||||
TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
|
||||
RegisterFunctionOp register_op;
|
||||
*register_op.mutable_function_def() = SingleRecvNodeFunction();
|
||||
TestComponentFunction(register_op, "SingleRecvNodeFunction", true);
|
||||
}
|
||||
|
||||
class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
public:
|
||||
FunctionWithRemoteInputsTest()
|
||||
|
@ -76,9 +76,13 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
|
||||
EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
|
||||
call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
local_impl_.RunComponentFunction(
|
||||
&call->request, &call->response,
|
||||
[call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
|
||||
auto call_opts = std::make_shared<CallOptions>();
|
||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||
local_impl_.RunComponentFunction(call_opts.get(), &call->request,
|
||||
&call->response,
|
||||
[call, call_opts](const Status& s) {
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
});
|
||||
Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
|
||||
RunComponentFunctionRequest, RunComponentFunctionResponse>::
|
||||
@ -86,7 +90,7 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
|
||||
&service_, cq_.get(),
|
||||
&grpc::EagerService::AsyncService::RequestRunComponentFunction,
|
||||
&GrpcEagerServiceImpl::RunComponentFunctionHandler,
|
||||
/*supports_cancel=*/false);
|
||||
/*supports_cancel=*/true);
|
||||
}
|
||||
|
||||
// Called when a new request has been received as part of a StreamingEnqueue
|
||||
|
@ -597,9 +597,6 @@ class OpKernelContext {
|
||||
// The step being executed.
|
||||
int64 step_id = 0;
|
||||
|
||||
// True if the op is created by eager runtime.
|
||||
bool is_eager = false;
|
||||
|
||||
// The op kernel being computed.
|
||||
OpKernel* op_kernel = nullptr;
|
||||
|
||||
@ -718,8 +715,6 @@ class OpKernelContext {
|
||||
|
||||
int64 step_id() const { return params_->step_id; }
|
||||
|
||||
bool is_eager() const { return params_->is_eager; }
|
||||
|
||||
const OpKernel& op_kernel() const { return *params_->op_kernel; }
|
||||
|
||||
// Input/output signature.
|
||||
|
@ -193,12 +193,7 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->output_alloc_attr(0);
|
||||
if (ctx->is_eager()) {
|
||||
// NOTE(fishx): Only set cancellation_manager in eager mode. Because in
|
||||
// Tensorflow 1.x, session (or graph_mgr) will abort the underlying
|
||||
// rendezvous if it encounters any error.
|
||||
args.cancellation_manager = ctx->cancellation_manager();
|
||||
}
|
||||
args.cancellation_manager = ctx->cancellation_manager();
|
||||
|
||||
FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
|
||||
if (frame_iter == FrameAndIter(0, 0)) {
|
||||
|
Loading…
Reference in New Issue
Block a user