Introduce non-blocking component function execution.

The current implementation of KernelAndDeviceFunc::Run ties up a thread while the function execution is pending. This leads to distributed deadlock issues in large-scale parameter server training when the number of workers exceeds the thread pool size.

This change leverages the RunComponentFunction request to execute component functions in a non-blocking manner. By avoiding the thread tying issue, it removes the constraint on the number of concurrent component functions to execute in parallel.

PiperOrigin-RevId: 308939721
Change-Id: I086f9ee587c4df76b303158f27c362a9bcb8314c
This commit is contained in:
Haoyu Zhang 2020-04-28 18:38:57 -07:00 committed by TensorFlower Gardener
parent 8c9a4d6bf2
commit 448f351cfe
13 changed files with 416 additions and 75 deletions

View File

@ -550,6 +550,15 @@ Status GetOrCreateKernelAndDevice(
}
}
}
int num_outputs = kernel->num_outputs();
if (num_outputs > *num_retvals) {
return errors::InvalidArgument("Expecting ", num_outputs,
" outputs, but *num_retvals is ",
*num_retvals);
}
*num_retvals = num_outputs;
kernel->Ref(); // Ownership of reference is passed to out_kernel.
out_kernel->reset(kernel.get());
return Status::OK();
@ -587,12 +596,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
int num_outputs = kernel->num_outputs();
if (num_outputs > *num_retvals) {
return errors::InvalidArgument("Expecting ", num_outputs,
" outputs, but *num_retvals is ",
*num_retvals);
}
*num_retvals = num_outputs;
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
GraphCollector* graph_collector = nullptr;
@ -1282,4 +1285,117 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
}
}
namespace {
// Launches `kernel` with `op_inputs` in `ctx` without blocking the calling
// thread. In contrast to `EagerKernelExecute`, which holds the thread until
// the underlying function has finished, this function may return while the
// execution is still in flight. `done` is invoked with the final status: either
// the input-initialization error, the execution error, or the result of
// converting the kernel outputs into `retvals`.
void EagerKernelExecuteAsync(
    EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const core::RefCountPtr<KernelAndDevice> kernel,
    GraphCollector* graph_collector, CancellationManager* cancellation_manager,
    TensorHandle** retvals, int num_outputs, StatusCallback done) {
  // Arguments and outputs are shared with the completion callback so that they
  // stay alive for as long as the execution may touch them.
  auto kernel_args = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
  auto kernel_outputs = std::make_shared<std::vector<Tensor>>(1);

  const Status init_status = kernel_args->Init(ctx, op_inputs, kernel);
  if (!init_status.ok()) {
    done(init_status);
    return;
  }

  // The completion callback owns one extra reference on the kernel and drops
  // it right before reporting the final status.
  KernelAndDevice* const kernel_raw = kernel.get();
  kernel_raw->Ref();

  auto on_finished = [retvals, kernel_args, kernel_outputs, num_outputs, ctx,
                      graph_collector, kernel_raw,
                      done = std::move(done)](const Status& run_status) {
    Status final_status = run_status;
    if (final_status.ok()) {
      if (graph_collector != nullptr) {
        CollectGraphs(ctx);
      }
      DCHECK_EQ(num_outputs, kernel_outputs->size());
      final_status = GetKernelOutputs(kernel_outputs.get(), num_outputs,
                                      retvals, ctx, kernel_raw);
    }
    kernel_raw->Unref();
    done(final_status);
  };

  kernel->RunAsync(ctx->StepContainer(), *kernel_args, kernel_outputs.get(),
                   cancellation_manager, remote_func_params,
                   std::move(on_finished));
}
} // namespace
// Low-level utility to run the eager operation on local devices. Unlike
// `EagerLocalExecute`, which blocks until the op execution finishes, this
// method does not block the thread and may return before the eager operation
// has finished executing. `done` is triggered after execution with its status.
//
// Args:
//   op: the operation to run; it is cleared (`op->Clear()`) right before
//     `done` is invoked from the completion callback.
//   retvals: caller-owned array that receives the output handles. It must stay
//     valid until `done` has been called. On failure every handle that was
//     produced is unreffed and its slot reset to nullptr.
//   num_retvals: in/out size of `retvals`, forwarded to
//     `GetOrCreateKernelAndDevice`.
//   done: invoked exactly once on every path, possibly before this function
//     returns (all early validation errors are reported synchronously).
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
                            int* num_retvals, StatusCallback done) {
  if (VariantDeviceIsCustom(op->Device())) {
    done(errors::Unimplemented(
        "Custom device is not supported in EagerLocalExecuteAsync."));
    return;
  }
  if (!op->IsLocal()) {
    done(errors::InvalidArgument(
        "Remote execution is not supported in async EagerLocalExecuteAsync"));
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation(
      op->op_name(), op->remote_func_params().has_value()
                         ? op->remote_func_params().value().step_id.value_or(0)
                         : 0);
  profiler::TraceMe activity(
      [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
      profiler::TraceMeLevel::kInfo);
  EagerContext& ctx = op->EagerContext();

  core::RefCountPtr<KernelAndDevice> kernel;
  Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
  if (!s.ok()) {
    done(s);
    return;
  }

  int num_outputs = kernel->num_outputs();
  s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphCollector* graph_collector = nullptr;
  if (ctx.ShouldStoreGraphs()) {
    graph_collector = ctx.GetGraphCollector();
  }

  // Start from a clean slate so the failure path below can tell which outputs
  // were actually allocated.
  for (int i = 0; i < num_outputs; ++i) {
    retvals[i] = nullptr;
  }

  EagerKernelExecuteAsync(
      &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
      graph_collector, op->GetCancellationManager(), retvals, num_outputs,
      // `retvals` must be captured by value: the callback can run after this
      // function has returned, at which point a reference to the `retvals`
      // parameter would dangle.
      [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
        op->Clear();
        // Since the operation failed, we need to Unref any outputs if they
        // were allocated, and clear the slots so the caller is not left with
        // stale pointers.
        if (!s.ok()) {
          for (int i = 0; i < num_outputs; ++i) {
            if (retvals[i] != nullptr) {
              retvals[i]->Unref();
              retvals[i] = nullptr;
            }
          }
        }
        done(s);
      });
}
} // namespace tensorflow

View File

@ -63,6 +63,27 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* device, bool mirror,
TensorHandle** result);
// Utility function that executes a fully constructed EagerOperation
// asynchronously on the local task. This function works differently from
// EagerExecute in several ways:
// - It supports local execution only.
// - It returns after launching the eager operation to run asynchronously.
// Different from EagerExecute with async context that appends the operation
// to the end of the eager executor schedule queue, this call bypasses the
// executor logic and directly launches op execution. Ops running through
// this call do NOT have an ordering and can be executed in parallel.
// - It takes a StatusCallback which will be triggered after execution with the
// execution status.
//
// Does not support custom device.
//
// 'retvals' must point to a pre-allocated array of TensorHandle* and
// '*num_retvals' should be set to the size of this array. It is an error if
// the size of 'retvals' is less than the number of outputs. This call sets
// *num_retvals to the number of outputs.
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
int* num_retvals, StatusCallback done);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_

View File

@ -302,21 +302,37 @@ Status KernelAndDeviceFunc::Run(
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
std::unique_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
Notification n;
Status status;
RunAsync(step_container, inputs, outputs, cancellation_manager,
remote_func_params, [&status, &n](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
void KernelAndDeviceFunc::RunAsync(
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
std::function<void(const Status&)> done) {
std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
if (remote_func_params.has_value()) {
const EagerRemoteFunctionParams& params = remote_func_params.value();
if (params.step_id.has_value()) {
// If the function is a remote component of a cross-process function,
// re-use the step id as its parent function's.
opts = absl::make_unique<FunctionLibraryRuntime::Options>(
opts = std::make_shared<FunctionLibraryRuntime::Options>(
params.step_id.value());
} else {
opts = absl::make_unique<FunctionLibraryRuntime::Options>();
opts = std::make_shared<FunctionLibraryRuntime::Options>();
}
// Reuse the op id if it exists.
opts->op_id = params.op_id;
} else {
opts = absl::make_unique<FunctionLibraryRuntime::Options>();
opts = std::make_shared<FunctionLibraryRuntime::Options>();
if (get_op_id_ && is_cross_process_) {
// If the function is a cross-process function and the remote execution
// goes through eager service, create an eager op id for the function.
@ -331,49 +347,43 @@ Status KernelAndDeviceFunc::Run(
opts->rendezvous = rendezvous;
opts->create_rendezvous = false;
CancellationManager cm;
// Create a cancellation manager to be used by FLR options if caller does not
// pass in one. If the caller does provide one, pass it to process FLR and the
// locally created one will be unused.
std::shared_ptr<CancellationManager> local_cm;
if (cancellation_manager) {
opts->cancellation_manager = cancellation_manager;
} else {
opts->cancellation_manager = &cm;
local_cm = std::make_shared<CancellationManager>();
opts->cancellation_manager = local_cm.get();
}
opts->allow_dead_tensors = true;
opts->step_container =
step_container == nullptr ? &step_container_ : step_container;
auto step_container_cleanup = gtl::MakeCleanup([step_container, this] {
if (step_container == nullptr) {
this->step_container_.CleanUp();
}
});
opts->collective_executor =
collective_executor_ ? collective_executor_->get() : nullptr;
opts->stats_collector = nullptr;
opts->runner = get_runner();
Notification done;
Status status;
outputs->clear();
{
profiler::TraceMe activity(
[&] {
return absl::StrCat("FunctionRun#name=", name(),
",id=", opts->step_id, "#");
},
profiler::TraceMeLevel::kInfo);
pflr_->Run(*opts, handle_, inputs, outputs,
[&status, &done](const Status& s) {
status = s;
done.Notify();
});
done.WaitForNotification();
}
rendezvous->Unref();
return status;
profiler::TraceMe* activity = new profiler::TraceMe(
[&] {
return absl::StrCat("FunctionRun#name=", name(), ",id=", opts->step_id,
"#");
},
profiler::TraceMeLevel::kInfo);
pflr_->Run(*opts, handle_, inputs, outputs,
[opts, rendezvous, local_cm, step_container, this, activity,
done = std::move(done)](const Status& s) {
delete activity;
rendezvous->Unref();
if (step_container == nullptr) {
this->step_container_.CleanUp();
}
done(s);
});
}
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {

View File

@ -124,6 +124,20 @@ class KernelAndDevice : public core::RefCounted {
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;
// Execute kernel asynchronously when applicable. Different from `Run` which
// blocks the caller thread and waits for the execution of the op/function,
// `RunAsync` could return before finishing the execution. The `done` callback
// will be triggered once the op/function execution finishes.
// Currently, calling RunAsync on ops might not honor the asynchronicity: when
// it is called on an instance with only a sync implementation, it executes
// the kernel synchronously and then calls the callback with the return status
// from the sync execution.
virtual void RunAsync(
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
StatusCallback done) = 0;
virtual Device* InputDevice(int i) const = 0;
virtual Device* OutputDevice(int idx) const = 0;
// If idx'th output is a resource, returns the device backing the resource.
@ -187,6 +201,16 @@ class KernelAndDeviceOp final : public KernelAndDevice {
const absl::optional<EagerRemoteFunctionParams>&
remote_func_params) override;
void RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    StatusCallback done) override {
  // There is no dedicated asynchronous path here: run the blocking version to
  // completion on the calling thread, then report its status through `done`.
  const Status sync_status = Run(step_container, inputs, outputs,
                                 cancellation_manager, remote_func_params);
  done(sync_status);
}
const OpKernel* kernel() const override { return kernel_.get(); }
Device* InputDevice(int i) const override;
@ -265,6 +289,12 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
const absl::optional<EagerRemoteFunctionParams>&
remote_func_params) override;
void RunAsync(
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
StatusCallback done) override;
const OpKernel* kernel() const override { return nullptr; }
Device* InputDevice(int i) const override;

View File

@ -159,9 +159,10 @@ void EagerClusterFunctionLibraryRuntime::Run(
return;
}
eager::EnqueueRequest* request = new eager::EnqueueRequest;
auto request = std::make_shared<RunComponentFunctionRequest>();
auto response = std::make_shared<RunComponentFunctionResponse>();
request->set_context_id(context_id_);
eager::Operation* remote_op = request->add_queue()->mutable_operation();
eager::Operation* remote_op = request->mutable_operation();
for (const auto& arg : args) {
if (arg.index() == 0) {
@ -188,39 +189,28 @@ void EagerClusterFunctionLibraryRuntime::Run(
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(function_data->target);
// StreamingEnqueueAsync may introduce a deadlock. When streaming RPC is
// disabled, Run() returns when the remote function execution completes, which
// might be blocked by a non-enqueued function execution.
EnqueueResponse* response = new EnqueueResponse;
eager_client->EnqueueAsync(
request, response,
// Execute component function on remote worker using RunComponentFunction RPC.
// Different from executing remote functions with Enqueue, this method runs
// a function on remote worker without tying up a thread (i.e., purely
// asynchronously).
eager_client->RunComponentFunctionAsync(
request.get(), response.get(),
[request, response, rets, done = std::move(done)](const Status& s) {
Status status = s;
auto cleanup = gtl::MakeCleanup([request, response, &status, &done] {
done(status);
delete request;
delete response;
});
if (!status.ok()) {
if (!s.ok()) {
done(s);
return;
}
if (response->queue_response_size() != 1) {
status.Update(errors::Internal(
"Expect that the size of response queue equals 1, but got: ",
response->queue_response_size()));
return;
}
for (const auto& tensor_proto : response->queue_response(0).tensor()) {
for (const auto& tensor_proto : response->tensor()) {
Tensor t;
if (t.FromProto(tensor_proto)) {
rets->push_back(std::move(t));
} else {
status.Update(errors::Internal("Could not convert tensor proto: ",
tensor_proto.DebugString()));
done(errors::Internal("Could not convert tensor proto: ",
tensor_proto.DebugString()));
return;
}
}
done(Status::OK());
});
}

View File

@ -38,6 +38,7 @@ class EagerClient : public core::RefCounted {
CLIENT_METHOD(UpdateContext);
CLIENT_METHOD(Enqueue);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(RunComponentFunction);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);

View File

@ -92,10 +92,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
return Status::OK();
}
Status GetEagerOperation(const Operation& operation,
EagerContext* eager_context,
EagerExecutor* eager_executor,
EagerOperation* eager_op) {
Status GetEagerOperationAndNumRetvals(const Operation& operation,
EagerContext* eager_context,
EagerExecutor* eager_executor,
EagerOperation* eager_op,
int* num_retvals) {
const char* name = operation.name().c_str(); // Shorthand
absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
absl::nullopt;
@ -138,7 +139,10 @@ Status GetEagerOperation(const Operation& operation,
for (const auto& attr : operation.attrs()) {
eager_op->MutableAttrs()->Set(attr.first, attr.second);
}
return Status::OK();
// TODO(nareshmodi): Consider caching this.
return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
num_retvals);
}
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
@ -406,18 +410,78 @@ Status EagerServiceImpl::CreateMasterContext(
return Status::OK();
}
void EagerServiceImpl::RunComponentFunction(
const RunComponentFunctionRequest* request,
RunComponentFunctionResponse* response, StatusCallback done) {
ServerContext* context = nullptr;
Status s = GetServerContext(request->context_id(), &context);
if (!s.ok()) {
done(s);
return;
}
core::ScopedUnref context_unref(context);
auto& operation = request->operation();
// This codepath should only be triggered for executing component function
if (!operation.is_function() || !operation.is_component_function()) {
done(errors::Internal(
"RunComponentFunction request can only be used to execute "
"component functions."));
return;
}
EagerContext* eager_context = context->Context();
EagerExecutor* eager_executor = &eager_context->Executor();
EagerOperation* op = new EagerOperation(eager_context);
int* num_retvals = new int(0);
s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
op, num_retvals);
if (!s.ok()) {
done(s);
return;
}
if (!op->IsLocal()) {
done(errors::Internal(
"Received RunComponentFunction request with remote function device. "));
return;
}
auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
<< operation.id();
context->Ref();
EagerLocalExecuteAsync(
op, retvals->data(), num_retvals,
[op, op_id = operation.id(), num_retvals, retvals, response,
eager_context, context, done = std::move(done)](const Status& status) {
auto wrapped_done = [&](const Status& status) {
context->Unref();
done(status);
delete op;
delete num_retvals;
delete retvals;
};
if (!status.ok()) {
wrapped_done(status);
return;
}
wrapped_done(AddOpRetvalsToResponse(
eager_context, op_id, *num_retvals, retvals->data(),
[response] { return response->add_tensor(); },
[response] { return response->add_shape(); }));
});
}
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
EagerContext* eager_context,
EagerExecutor* eager_executor,
QueueResponse* queue_response) {
tensorflow::EagerOperation op(eager_context);
TF_RETURN_IF_ERROR(
GetEagerOperation(operation, eager_context, eager_executor, &op));
int num_retvals = 0;
// TODO(nareshmodi): Consider caching this.
TF_RETURN_IF_ERROR(GetNumRetvals(eager_context, operation.name(),
operation.attrs(), &num_retvals));
TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
operation, eager_context, eager_executor, &op, &num_retvals));
absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();

View File

@ -96,6 +96,10 @@ class EagerServiceImpl {
Status WaitQueueDone(const WaitQueueDoneRequest* request,
WaitQueueDoneResponse* response);
void RunComponentFunction(const RunComponentFunctionRequest* request,
RunComponentFunctionResponse* response,
StatusCallback done);
Status KeepAlive(const KeepAliveRequest* request,
KeepAliveResponse* response);

View File

@ -90,6 +90,12 @@ class FakeEagerClient : public EagerClient {
CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD
// Forwards the request directly to `impl_->RunComponentFunction`, handing over
// `done` so the service implementation decides when it is invoked.
// NOTE(review): `impl_` is presumably the in-process EagerServiceImpl under
// test — confirm against the class definition.
void RunComponentFunctionAsync(const RunComponentFunctionRequest* request,
RunComponentFunctionResponse* response,
StatusCallback done) override {
impl_->RunComponentFunction(request, response, std::move(done));
}
void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
@ -702,7 +708,7 @@ TEST_F(FunctionWithRemoteInputsTest,
CheckOutputTensorAndClose(outputs.at(0));
}
// Test executes a remote function through KernelAndDeviceFunc.
// Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
Init();
Device* local_device;
@ -747,6 +753,58 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
CheckOutputsAndClose(op_id);
}
// Runs a remote function via KernelAndDeviceFunc::RunAsync and waits for the
// completion callback before checking the outputs.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
  Init();

  Device* local_dev;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_dev));
  std::vector<Device*> input_devices;
  input_devices.push_back(local_dev);

  FunctionLibraryRuntime* remote_flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));

  const int64 op_id = 2;
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  kernel.reset(new KernelAndDeviceFunc(
      remote_flr, eager_pflr_.get(), std::move(input_devices), {},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_dev, fdef_.signature().name(),
      [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef matmul_node = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc(matmul_node, nullptr));

  // Describe the single remote input consumed by MatMulFunction.
  RemoteTensorHandle remote_input;
  remote_input.set_op_id(1);
  remote_input.set_output_num(0);
  remote_input.set_op_device(local_device_);
  remote_input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {remote_input};

  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });

  // Run MatMulFunction on remote_device and block until `done` fires.
  std::vector<Tensor> outputs;
  Status run_status;
  Notification finished;
  kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
                   /*cancellation_manager=*/nullptr,
                   /*remote_func_params=*/absl::nullopt,
                   [&run_status, &finished](const Status& s) {
                     run_status = s;
                     finished.Notify();
                   });
  finished.WaitForNotification();
  TF_ASSERT_OK(run_status);

  CheckOutputsAndClose(op_id);
}
// Test creates a context and attempts to send a tensor (using the RPC), and
// then use the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {

View File

@ -134,6 +134,7 @@ class GrpcEagerClient : public EagerClient {
CLIENT_METHOD(UpdateContext);
CLIENT_METHOD(Enqueue);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(RunComponentFunction);
CLIENT_METHOD(KeepAlive);
#undef CLIENT_METHOD

View File

@ -52,6 +52,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
ENQUEUE_REQUEST(UpdateContext);
ENQUEUE_REQUEST(Enqueue);
ENQUEUE_REQUEST(WaitQueueDone);
ENQUEUE_REQUEST(RunComponentFunction);
ENQUEUE_REQUEST(KeepAlive);
ENQUEUE_REQUEST(CloseContext);
#undef ENQUEUE_REQUEST

View File

@ -72,6 +72,23 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
HANDLER(CloseContext);
#undef HANDLER
// Handles one RunComponentFunction call: schedules the execution on the
// compute pool, answering `call` once the service implementation reports a
// status, and then enqueues a request for the next incoming call.
void RunComponentFunctionHandler(
    EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
        call) {
  using ComponentFunctionCall =
      Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
           RunComponentFunctionRequest, RunComponentFunctionResponse>;

  auto run_and_respond = [this, call]() {
    auto send_response = [call](const Status& s) {
      call->SendResponse(ToGrpcStatus(s));
    };
    local_impl_.RunComponentFunction(&call->request, &call->response,
                                     std::move(send_response));
  };
  env_->compute_pool->Schedule(std::move(run_and_respond));

  ComponentFunctionCall::EnqueueRequest(
      &service_, cq_.get(),
      &grpc::EagerService::AsyncService::RequestRunComponentFunction,
      &GrpcEagerServiceImpl::RunComponentFunctionHandler,
      /*supports_cancel=*/false);
}
// Called when a new request has been received as part of a StreamingEnqueue
// call.
// StreamingEnqueueHandler gets the request from the `call` and fills the

View File

@ -173,6 +173,18 @@ message WaitQueueDoneResponse {
// propagate some stats.
}
// Request to execute a single component function on the receiving worker.
message RunComponentFunctionRequest {
// Identifies the server-side context in which the operation is executed.
fixed64 context_id = 1;
// The operation to run. The server rejects it unless it is marked as both a
// function and a component function.
Operation operation = 2;
}
// Return values of a component function executed via RunComponentFunction.
// NOTE(review): which outputs are reported as `shape` and which as `tensor`
// is decided by the server's AddOpRetvalsToResponse — confirm there.
message RunComponentFunctionResponse {
repeated TensorShapeProto shape = 1;
// Output tensors; the client converts each entry into a Tensor return value.
repeated TensorProto tensor = 2;
}
message KeepAliveRequest {
fixed64 context_id = 1;
}
@ -272,6 +284,22 @@ service EagerService {
// in the stream so far.
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);
// This takes an Eager operation and executes it in async mode on the remote
// server. Different from EnqueueRequest, ops/functions sent through this
// type of requests are allowed to execute in parallel and no ordering is
// preserved by RPC stream or executor.
// This request type should only be used for executing component functions.
// Ordering of component functions should be enforced by their corresponding
// main functions. The runtime ensures the following invariants for component
// functions (CFs) and their main functions (MFs):
// (1) MF1 -> MF2 ==> CF1 -> CF2 ("->" indicates order of execution);
// (2) MF1 || MF2 ==> CF1 || CF2 ("||" indicates possible parallel execution);
// (3) For CF1 and CF2 that come from the same MF, CF1 || CF2
// For executing ops/main functions, use Enqueue or StreamingEnqueue instead
// for correct ordering.
rpc RunComponentFunction(RunComponentFunctionRequest)
returns (RunComponentFunctionResponse);
// Contexts are always created with a deadline and no RPCs within a deadline
// will trigger a context garbage collection. KeepAlive calls can be used to
// delay this. It can also be used to validate the existence of a context ID