Introduce non-blocking component function execution.
The current implementation of KernelAndDeviceFunc::Run ties up a thread while the function execution is pending. This leads to distributed deadlock issues in large-scale parameter server training when the number of workers exceeds the thread pool size. This change leverages the RunComponentFunction request to execute component functions in a non-blocking manner. By avoiding the thread tying issue, it removes the constraint on the number of concurrent component functions to execute in parallel. PiperOrigin-RevId: 308939721 Change-Id: I086f9ee587c4df76b303158f27c362a9bcb8314c
This commit is contained in:
parent
8c9a4d6bf2
commit
448f351cfe
@ -550,6 +550,15 @@ Status GetOrCreateKernelAndDevice(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int num_outputs = kernel->num_outputs();
|
||||||
|
if (num_outputs > *num_retvals) {
|
||||||
|
return errors::InvalidArgument("Expecting ", num_outputs,
|
||||||
|
" outputs, but *num_retvals is ",
|
||||||
|
*num_retvals);
|
||||||
|
}
|
||||||
|
*num_retvals = num_outputs;
|
||||||
|
|
||||||
kernel->Ref(); // Ownership of reference is passed to out_kernel.
|
kernel->Ref(); // Ownership of reference is passed to out_kernel.
|
||||||
out_kernel->reset(kernel.get());
|
out_kernel->reset(kernel.get());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -587,12 +596,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
|
GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
|
||||||
|
|
||||||
int num_outputs = kernel->num_outputs();
|
int num_outputs = kernel->num_outputs();
|
||||||
if (num_outputs > *num_retvals) {
|
|
||||||
return errors::InvalidArgument("Expecting ", num_outputs,
|
|
||||||
" outputs, but *num_retvals is ",
|
|
||||||
*num_retvals);
|
|
||||||
}
|
|
||||||
*num_retvals = num_outputs;
|
|
||||||
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
|
TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
|
||||||
|
|
||||||
GraphCollector* graph_collector = nullptr;
|
GraphCollector* graph_collector = nullptr;
|
||||||
@ -1282,4 +1285,117 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Low-level utility function to execute the kernel specified by `kernel` on
|
||||||
|
// `kernel->device()`, with the provided inputs as `op_inputs` in the 'ctx'.
|
||||||
|
// Different from `EagerKernelExecute` that ties up the thread until the
|
||||||
|
// underlying function finishes execute, this function does not block the thread
|
||||||
|
// and could return before the function execution finishes. The provided
|
||||||
|
// `StatusCallback` will be triggered after function execution with its status.
|
||||||
|
void EagerKernelExecuteAsync(
|
||||||
|
EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||||
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
|
||||||
|
const core::RefCountPtr<KernelAndDevice> kernel,
|
||||||
|
GraphCollector* graph_collector, CancellationManager* cancellation_manager,
|
||||||
|
TensorHandle** retvals, int num_outputs, StatusCallback done) {
|
||||||
|
auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
|
||||||
|
auto outputs = std::make_shared<std::vector<Tensor>>(1);
|
||||||
|
|
||||||
|
Status s = inputs->Init(ctx, op_inputs, kernel);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel->Ref(); // Ownership of reference is transferred to the callback
|
||||||
|
kernel->RunAsync(
|
||||||
|
ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
|
||||||
|
remote_func_params,
|
||||||
|
[retvals, inputs, outputs, num_outputs, ctx, graph_collector,
|
||||||
|
kernel_raw = kernel.get(), done = std::move(done)](const Status& s) {
|
||||||
|
auto wrapped_done = [&](const Status& s) {
|
||||||
|
kernel_raw->Unref();
|
||||||
|
done(s);
|
||||||
|
};
|
||||||
|
if (!s.ok()) {
|
||||||
|
wrapped_done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (graph_collector != nullptr) {
|
||||||
|
CollectGraphs(ctx);
|
||||||
|
}
|
||||||
|
DCHECK_EQ(num_outputs, outputs->size());
|
||||||
|
wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
|
||||||
|
kernel_raw));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Low-level utility to run the eager operation on local devices. Different from
|
||||||
|
// `EagerLocalExecute` which blocks and waits for the finishing the op
|
||||||
|
// execution, this method does not block the thread and could return before the
|
||||||
|
// eager operation execution finishes. The provided `StatusCallback` will be
|
||||||
|
// triggered after execution with its status.
|
||||||
|
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
|
||||||
|
int* num_retvals, StatusCallback done) {
|
||||||
|
if (VariantDeviceIsCustom(op->Device())) {
|
||||||
|
done(errors::Unimplemented(
|
||||||
|
"Custom device is not supported in EagerLocalExecuteAsync."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!op->IsLocal()) {
|
||||||
|
done(errors::InvalidArgument(
|
||||||
|
"Remote execution is not supported in async EagerLocalExecuteAsync"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ScopedMemoryDebugAnnotation op_annotation(
|
||||||
|
op->op_name(), op->remote_func_params().has_value()
|
||||||
|
? op->remote_func_params().value().step_id.value_or(0)
|
||||||
|
: 0);
|
||||||
|
profiler::TraceMe activity(
|
||||||
|
[&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
|
||||||
|
profiler::TraceMeLevel::kInfo);
|
||||||
|
EagerContext& ctx = op->EagerContext();
|
||||||
|
|
||||||
|
core::RefCountPtr<KernelAndDevice> kernel;
|
||||||
|
Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_outputs = kernel->num_outputs();
|
||||||
|
s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphCollector* graph_collector = nullptr;
|
||||||
|
if (ctx.ShouldStoreGraphs()) {
|
||||||
|
graph_collector = ctx.GetGraphCollector();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
retvals[i] = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerKernelExecuteAsync(
|
||||||
|
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
|
||||||
|
graph_collector, op->GetCancellationManager(), retvals, num_outputs,
|
||||||
|
[op, num_outputs, &retvals, done = std::move(done)](const Status& s) {
|
||||||
|
op->Clear();
|
||||||
|
// Since the operation failed, we need to Unref any outputs if they were
|
||||||
|
// allocated.
|
||||||
|
if (!s.ok()) {
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
if (retvals[i] != nullptr) {
|
||||||
|
retvals[i]->Unref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
done(s);
|
||||||
|
});
|
||||||
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -63,6 +63,27 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
|||||||
EagerExecutor* executor, Device* device, bool mirror,
|
EagerExecutor* executor, Device* device, bool mirror,
|
||||||
TensorHandle** result);
|
TensorHandle** result);
|
||||||
|
|
||||||
|
// Utility function that executes a fully constructed EagerOperation
|
||||||
|
// asynchronously on the local task. This function works differently from
|
||||||
|
// EagerExecute in several ways:
|
||||||
|
// - It supports local execution only.
|
||||||
|
// - It returns after launching the eager operation to run asynchronously.
|
||||||
|
// Different from EagerExecute with async context that appends the operation
|
||||||
|
// to the end of the eager executor schedule queue, this call bypasses the
|
||||||
|
// executor logic and directly launches op execution. Ops running through
|
||||||
|
// this call do NOT have an ordering and can be executed in parallel.
|
||||||
|
// - It takes a StatusCallback which will be triggered after execution with the
|
||||||
|
// execution status.
|
||||||
|
//
|
||||||
|
// Does not support custom device.
|
||||||
|
//
|
||||||
|
// 'retvals' must point to a pre-allocated array of TensorHandle* and
|
||||||
|
// '*num_retvals' should be set to the size of this array. It is an error if
|
||||||
|
// the size of 'retvals' is less than the number of outputs. This call sets
|
||||||
|
// *num_retvals to the number of outputs.
|
||||||
|
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
|
||||||
|
int* num_retvals, StatusCallback done);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_
|
||||||
|
@ -302,21 +302,37 @@ Status KernelAndDeviceFunc::Run(
|
|||||||
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
|
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
|
||||||
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||||
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) {
|
||||||
std::unique_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
|
Notification n;
|
||||||
|
Status status;
|
||||||
|
RunAsync(step_container, inputs, outputs, cancellation_manager,
|
||||||
|
remote_func_params, [&status, &n](const Status& s) {
|
||||||
|
status = s;
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelAndDeviceFunc::RunAsync(
|
||||||
|
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
|
||||||
|
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||||
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
|
||||||
|
std::function<void(const Status&)> done) {
|
||||||
|
std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
|
||||||
if (remote_func_params.has_value()) {
|
if (remote_func_params.has_value()) {
|
||||||
const EagerRemoteFunctionParams& params = remote_func_params.value();
|
const EagerRemoteFunctionParams& params = remote_func_params.value();
|
||||||
if (params.step_id.has_value()) {
|
if (params.step_id.has_value()) {
|
||||||
// If the function is a remote component of a cross-process function,
|
// If the function is a remote component of a cross-process function,
|
||||||
// re-use the step id as its parent function's.
|
// re-use the step id as its parent function's.
|
||||||
opts = absl::make_unique<FunctionLibraryRuntime::Options>(
|
opts = std::make_shared<FunctionLibraryRuntime::Options>(
|
||||||
params.step_id.value());
|
params.step_id.value());
|
||||||
} else {
|
} else {
|
||||||
opts = absl::make_unique<FunctionLibraryRuntime::Options>();
|
opts = std::make_shared<FunctionLibraryRuntime::Options>();
|
||||||
}
|
}
|
||||||
// Reuse the op id if it exists.
|
// Reuse the op id if it exists.
|
||||||
opts->op_id = params.op_id;
|
opts->op_id = params.op_id;
|
||||||
} else {
|
} else {
|
||||||
opts = absl::make_unique<FunctionLibraryRuntime::Options>();
|
opts = std::make_shared<FunctionLibraryRuntime::Options>();
|
||||||
if (get_op_id_ && is_cross_process_) {
|
if (get_op_id_ && is_cross_process_) {
|
||||||
// If the function is a cross-process function and the remote execution
|
// If the function is a cross-process function and the remote execution
|
||||||
// goes through eager service, create an eager op id for the function.
|
// goes through eager service, create an eager op id for the function.
|
||||||
@ -331,49 +347,43 @@ Status KernelAndDeviceFunc::Run(
|
|||||||
opts->rendezvous = rendezvous;
|
opts->rendezvous = rendezvous;
|
||||||
opts->create_rendezvous = false;
|
opts->create_rendezvous = false;
|
||||||
|
|
||||||
CancellationManager cm;
|
// Create a cancellation manager to be used by FLR options if caller does not
|
||||||
|
// pass in one. If the caller does provide one, pass it to process FLR and the
|
||||||
|
// locally created one will be unused.
|
||||||
|
std::shared_ptr<CancellationManager> local_cm;
|
||||||
if (cancellation_manager) {
|
if (cancellation_manager) {
|
||||||
opts->cancellation_manager = cancellation_manager;
|
opts->cancellation_manager = cancellation_manager;
|
||||||
} else {
|
} else {
|
||||||
opts->cancellation_manager = &cm;
|
local_cm = std::make_shared<CancellationManager>();
|
||||||
|
opts->cancellation_manager = local_cm.get();
|
||||||
}
|
}
|
||||||
opts->allow_dead_tensors = true;
|
opts->allow_dead_tensors = true;
|
||||||
|
|
||||||
opts->step_container =
|
opts->step_container =
|
||||||
step_container == nullptr ? &step_container_ : step_container;
|
step_container == nullptr ? &step_container_ : step_container;
|
||||||
auto step_container_cleanup = gtl::MakeCleanup([step_container, this] {
|
|
||||||
if (step_container == nullptr) {
|
|
||||||
this->step_container_.CleanUp();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
opts->collective_executor =
|
opts->collective_executor =
|
||||||
collective_executor_ ? collective_executor_->get() : nullptr;
|
collective_executor_ ? collective_executor_->get() : nullptr;
|
||||||
|
|
||||||
opts->stats_collector = nullptr;
|
opts->stats_collector = nullptr;
|
||||||
opts->runner = get_runner();
|
opts->runner = get_runner();
|
||||||
|
|
||||||
Notification done;
|
|
||||||
Status status;
|
|
||||||
outputs->clear();
|
outputs->clear();
|
||||||
|
|
||||||
{
|
profiler::TraceMe* activity = new profiler::TraceMe(
|
||||||
profiler::TraceMe activity(
|
[&] {
|
||||||
[&] {
|
return absl::StrCat("FunctionRun#name=", name(), ",id=", opts->step_id,
|
||||||
return absl::StrCat("FunctionRun#name=", name(),
|
"#");
|
||||||
",id=", opts->step_id, "#");
|
},
|
||||||
},
|
profiler::TraceMeLevel::kInfo);
|
||||||
profiler::TraceMeLevel::kInfo);
|
pflr_->Run(*opts, handle_, inputs, outputs,
|
||||||
pflr_->Run(*opts, handle_, inputs, outputs,
|
[opts, rendezvous, local_cm, step_container, this, activity,
|
||||||
[&status, &done](const Status& s) {
|
done = std::move(done)](const Status& s) {
|
||||||
status = s;
|
delete activity;
|
||||||
done.Notify();
|
rendezvous->Unref();
|
||||||
});
|
if (step_container == nullptr) {
|
||||||
done.WaitForNotification();
|
this->step_container_.CleanUp();
|
||||||
}
|
}
|
||||||
|
done(s);
|
||||||
rendezvous->Unref();
|
});
|
||||||
return status;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
|
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
|
||||||
|
@ -124,6 +124,20 @@ class KernelAndDevice : public core::RefCounted {
|
|||||||
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||||
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;
|
||||||
|
|
||||||
|
// Execute kernel asynchronously when applicable. Different from `Run` which
|
||||||
|
// blocks the caller thread and waits for the execution of the op/function,
|
||||||
|
// `RunAsync` could return before finishing the execution. The `done` callback
|
||||||
|
// will be triggered once the op/function execution finishes.
|
||||||
|
// Currently, calling RunAsync on ops might not honor the asynchronicity when
|
||||||
|
// it is called on an instance with only a sync implementation; such instances execute the
|
||||||
|
// kernel synchronously and then call the callback with the return status
|
||||||
|
// from sync execution.
|
||||||
|
virtual void RunAsync(
|
||||||
|
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
|
||||||
|
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||||
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
|
||||||
|
StatusCallback done) = 0;
|
||||||
|
|
||||||
virtual Device* InputDevice(int i) const = 0;
|
virtual Device* InputDevice(int i) const = 0;
|
||||||
virtual Device* OutputDevice(int idx) const = 0;
|
virtual Device* OutputDevice(int idx) const = 0;
|
||||||
// If idx'th output is a resource, returns the device backing the resource.
|
// If idx'th output is a resource, returns the device backing the resource.
|
||||||
@ -187,6 +201,16 @@ class KernelAndDeviceOp final : public KernelAndDevice {
|
|||||||
const absl::optional<EagerRemoteFunctionParams>&
|
const absl::optional<EagerRemoteFunctionParams>&
|
||||||
remote_func_params) override;
|
remote_func_params) override;
|
||||||
|
|
||||||
|
void RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    StatusCallback done) override {
  // There is no asynchronous path here: execute synchronously through Run()
  // and hand its status to `done` before returning to the caller.
  const Status run_status = Run(step_container, inputs, outputs,
                                cancellation_manager, remote_func_params);
  done(run_status);
}
|
||||||
|
|
||||||
const OpKernel* kernel() const override { return kernel_.get(); }
|
const OpKernel* kernel() const override { return kernel_.get(); }
|
||||||
|
|
||||||
Device* InputDevice(int i) const override;
|
Device* InputDevice(int i) const override;
|
||||||
@ -265,6 +289,12 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
|||||||
const absl::optional<EagerRemoteFunctionParams>&
|
const absl::optional<EagerRemoteFunctionParams>&
|
||||||
remote_func_params) override;
|
remote_func_params) override;
|
||||||
|
|
||||||
|
void RunAsync(
|
||||||
|
ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
|
||||||
|
std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
|
||||||
|
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
|
||||||
|
StatusCallback done) override;
|
||||||
|
|
||||||
const OpKernel* kernel() const override { return nullptr; }
|
const OpKernel* kernel() const override { return nullptr; }
|
||||||
|
|
||||||
Device* InputDevice(int i) const override;
|
Device* InputDevice(int i) const override;
|
||||||
|
@ -159,9 +159,10 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
eager::EnqueueRequest* request = new eager::EnqueueRequest;
|
auto request = std::make_shared<RunComponentFunctionRequest>();
|
||||||
|
auto response = std::make_shared<RunComponentFunctionResponse>();
|
||||||
request->set_context_id(context_id_);
|
request->set_context_id(context_id_);
|
||||||
eager::Operation* remote_op = request->add_queue()->mutable_operation();
|
eager::Operation* remote_op = request->mutable_operation();
|
||||||
|
|
||||||
for (const auto& arg : args) {
|
for (const auto& arg : args) {
|
||||||
if (arg.index() == 0) {
|
if (arg.index() == 0) {
|
||||||
@ -188,39 +189,28 @@ void EagerClusterFunctionLibraryRuntime::Run(
|
|||||||
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
|
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
|
||||||
remote_op->set_device(function_data->target);
|
remote_op->set_device(function_data->target);
|
||||||
|
|
||||||
// StreamingEnqueueAsync may introduce a deadlock. When streaming RPC is
|
// Execute component function on remote worker using RunComponentFunction RPC.
|
||||||
// disabled, Run() returns when the remote function execution completes, which
|
// Different from executing remote functions with Enqueue, this method runs
|
||||||
// might be blocked by a non-enqueued function execution.
|
// a function on remote worker without tying up a thread (i.e., pure
|
||||||
EnqueueResponse* response = new EnqueueResponse;
|
// asynchronously).
|
||||||
eager_client->EnqueueAsync(
|
eager_client->RunComponentFunctionAsync(
|
||||||
request, response,
|
request.get(), response.get(),
|
||||||
[request, response, rets, done = std::move(done)](const Status& s) {
|
[request, response, rets, done = std::move(done)](const Status& s) {
|
||||||
Status status = s;
|
if (!s.ok()) {
|
||||||
auto cleanup = gtl::MakeCleanup([request, response, &status, &done] {
|
done(s);
|
||||||
done(status);
|
|
||||||
delete request;
|
|
||||||
delete response;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!status.ok()) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (response->queue_response_size() != 1) {
|
for (const auto& tensor_proto : response->tensor()) {
|
||||||
status.Update(errors::Internal(
|
|
||||||
"Expect that the size of response queue equals 1, but got: ",
|
|
||||||
response->queue_response_size()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (const auto& tensor_proto : response->queue_response(0).tensor()) {
|
|
||||||
Tensor t;
|
Tensor t;
|
||||||
if (t.FromProto(tensor_proto)) {
|
if (t.FromProto(tensor_proto)) {
|
||||||
rets->push_back(std::move(t));
|
rets->push_back(std::move(t));
|
||||||
} else {
|
} else {
|
||||||
status.Update(errors::Internal("Could not convert tensor proto: ",
|
done(errors::Internal("Could not convert tensor proto: ",
|
||||||
tensor_proto.DebugString()));
|
tensor_proto.DebugString()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
done(Status::OK());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ class EagerClient : public core::RefCounted {
|
|||||||
CLIENT_METHOD(UpdateContext);
|
CLIENT_METHOD(UpdateContext);
|
||||||
CLIENT_METHOD(Enqueue);
|
CLIENT_METHOD(Enqueue);
|
||||||
CLIENT_METHOD(WaitQueueDone);
|
CLIENT_METHOD(WaitQueueDone);
|
||||||
|
CLIENT_METHOD(RunComponentFunction);
|
||||||
CLIENT_METHOD(KeepAlive);
|
CLIENT_METHOD(KeepAlive);
|
||||||
CLIENT_METHOD(CloseContext);
|
CLIENT_METHOD(CloseContext);
|
||||||
|
|
||||||
|
@ -92,10 +92,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetEagerOperation(const Operation& operation,
|
Status GetEagerOperationAndNumRetvals(const Operation& operation,
|
||||||
EagerContext* eager_context,
|
EagerContext* eager_context,
|
||||||
EagerExecutor* eager_executor,
|
EagerExecutor* eager_executor,
|
||||||
EagerOperation* eager_op) {
|
EagerOperation* eager_op,
|
||||||
|
int* num_retvals) {
|
||||||
const char* name = operation.name().c_str(); // Shorthand
|
const char* name = operation.name().c_str(); // Shorthand
|
||||||
absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
|
absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
|
||||||
absl::nullopt;
|
absl::nullopt;
|
||||||
@ -138,7 +139,10 @@ Status GetEagerOperation(const Operation& operation,
|
|||||||
for (const auto& attr : operation.attrs()) {
|
for (const auto& attr : operation.attrs()) {
|
||||||
eager_op->MutableAttrs()->Set(attr.first, attr.second);
|
eager_op->MutableAttrs()->Set(attr.first, attr.second);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
|
// TODO(nareshmodi): Consider caching this.
|
||||||
|
return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
|
||||||
|
num_retvals);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
|
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
|
||||||
@ -406,18 +410,78 @@ Status EagerServiceImpl::CreateMasterContext(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EagerServiceImpl::RunComponentFunction(
|
||||||
|
const RunComponentFunctionRequest* request,
|
||||||
|
RunComponentFunctionResponse* response, StatusCallback done) {
|
||||||
|
ServerContext* context = nullptr;
|
||||||
|
Status s = GetServerContext(request->context_id(), &context);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
core::ScopedUnref context_unref(context);
|
||||||
|
|
||||||
|
auto& operation = request->operation();
|
||||||
|
// This codepath should only be triggered for executing component function
|
||||||
|
if (!operation.is_function() || !operation.is_component_function()) {
|
||||||
|
done(errors::Internal(
|
||||||
|
"RunComponentFunction request can only be used to execute "
|
||||||
|
"component functions."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerContext* eager_context = context->Context();
|
||||||
|
EagerExecutor* eager_executor = &eager_context->Executor();
|
||||||
|
|
||||||
|
EagerOperation* op = new EagerOperation(eager_context);
|
||||||
|
int* num_retvals = new int(0);
|
||||||
|
s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
|
||||||
|
op, num_retvals);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!op->IsLocal()) {
|
||||||
|
done(errors::Internal(
|
||||||
|
"Received RunComponentFunction request with remote function device. "));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
|
||||||
|
VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
|
||||||
|
<< operation.id();
|
||||||
|
|
||||||
|
context->Ref();
|
||||||
|
EagerLocalExecuteAsync(
|
||||||
|
op, retvals->data(), num_retvals,
|
||||||
|
[op, op_id = operation.id(), num_retvals, retvals, response,
|
||||||
|
eager_context, context, done = std::move(done)](const Status& status) {
|
||||||
|
auto wrapped_done = [&](const Status& status) {
|
||||||
|
context->Unref();
|
||||||
|
done(status);
|
||||||
|
delete op;
|
||||||
|
delete num_retvals;
|
||||||
|
delete retvals;
|
||||||
|
};
|
||||||
|
if (!status.ok()) {
|
||||||
|
wrapped_done(status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
wrapped_done(AddOpRetvalsToResponse(
|
||||||
|
eager_context, op_id, *num_retvals, retvals->data(),
|
||||||
|
[response] { return response->add_tensor(); },
|
||||||
|
[response] { return response->add_shape(); }));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||||
EagerContext* eager_context,
|
EagerContext* eager_context,
|
||||||
EagerExecutor* eager_executor,
|
EagerExecutor* eager_executor,
|
||||||
QueueResponse* queue_response) {
|
QueueResponse* queue_response) {
|
||||||
tensorflow::EagerOperation op(eager_context);
|
tensorflow::EagerOperation op(eager_context);
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
GetEagerOperation(operation, eager_context, eager_executor, &op));
|
|
||||||
|
|
||||||
int num_retvals = 0;
|
int num_retvals = 0;
|
||||||
// TODO(nareshmodi): Consider caching this.
|
TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
|
||||||
TF_RETURN_IF_ERROR(GetNumRetvals(eager_context, operation.name(),
|
operation, eager_context, eager_executor, &op, &num_retvals));
|
||||||
operation.attrs(), &num_retvals));
|
|
||||||
|
|
||||||
absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
|
absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
|
||||||
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
|
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
|
||||||
|
@ -96,6 +96,10 @@ class EagerServiceImpl {
|
|||||||
Status WaitQueueDone(const WaitQueueDoneRequest* request,
|
Status WaitQueueDone(const WaitQueueDoneRequest* request,
|
||||||
WaitQueueDoneResponse* response);
|
WaitQueueDoneResponse* response);
|
||||||
|
|
||||||
|
void RunComponentFunction(const RunComponentFunctionRequest* request,
|
||||||
|
RunComponentFunctionResponse* response,
|
||||||
|
StatusCallback done);
|
||||||
|
|
||||||
Status KeepAlive(const KeepAliveRequest* request,
|
Status KeepAlive(const KeepAliveRequest* request,
|
||||||
KeepAliveResponse* response);
|
KeepAliveResponse* response);
|
||||||
|
|
||||||
|
@ -90,6 +90,12 @@ class FakeEagerClient : public EagerClient {
|
|||||||
CLIENT_METHOD(CloseContext);
|
CLIENT_METHOD(CloseContext);
|
||||||
#undef CLIENT_METHOD
|
#undef CLIENT_METHOD
|
||||||
|
|
||||||
|
void RunComponentFunctionAsync(const RunComponentFunctionRequest* request,
|
||||||
|
RunComponentFunctionResponse* response,
|
||||||
|
StatusCallback done) override {
|
||||||
|
impl_->RunComponentFunction(request, response, std::move(done));
|
||||||
|
}
|
||||||
|
|
||||||
void StreamingEnqueueAsync(const EnqueueRequest* request,
|
void StreamingEnqueueAsync(const EnqueueRequest* request,
|
||||||
EnqueueResponse* response,
|
EnqueueResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
@ -702,7 +708,7 @@ TEST_F(FunctionWithRemoteInputsTest,
|
|||||||
CheckOutputTensorAndClose(outputs.at(0));
|
CheckOutputTensorAndClose(outputs.at(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test executes a remote function through KernelAndDeviceFunc.
|
// Test executes a remote function through KernelAndDeviceFunc::Run.
|
||||||
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
|
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
|
||||||
Init();
|
Init();
|
||||||
Device* local_device;
|
Device* local_device;
|
||||||
@ -747,6 +753,58 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
|
|||||||
CheckOutputsAndClose(op_id);
|
CheckOutputsAndClose(op_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test executes a remote function through KernelAndDeviceFunc::RunAsync.
|
||||||
|
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
|
||||||
|
Init();
|
||||||
|
Device* local_device;
|
||||||
|
TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
|
||||||
|
std::vector<Device*> input_dev_ptrs;
|
||||||
|
input_dev_ptrs.push_back(local_device);
|
||||||
|
FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
|
||||||
|
EagerContext* ctx = nullptr;
|
||||||
|
TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
|
||||||
|
core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
|
||||||
|
const int64 op_id = 2;
|
||||||
|
kernel.reset(new KernelAndDeviceFunc(
|
||||||
|
flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
|
||||||
|
/*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
|
||||||
|
[ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
|
||||||
|
[=]() { return op_id; }));
|
||||||
|
|
||||||
|
// Instantiate MatMulFunction on remote_device.
|
||||||
|
const NodeDef node_def = MatMulFunctionNodeDef();
|
||||||
|
TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr));
|
||||||
|
|
||||||
|
// Run MatMulFunction on remote_device.
|
||||||
|
gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
|
||||||
|
RemoteTensorHandle input;
|
||||||
|
input.set_op_id(1);
|
||||||
|
input.set_output_num(0);
|
||||||
|
input.set_op_device(local_device_);
|
||||||
|
input.set_device(local_device_);
|
||||||
|
std::vector<RemoteTensorHandle> remote_handles = {input};
|
||||||
|
TestExecuteNodeArgs inputs(
|
||||||
|
std::move(input_tensors),
|
||||||
|
[&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
|
||||||
|
*handle = remote_handles.at(index);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
|
||||||
|
Status status;
|
||||||
|
Notification n;
|
||||||
|
kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
|
||||||
|
/*cancellation_manager=*/nullptr,
|
||||||
|
/*remote_func_params=*/absl::nullopt,
|
||||||
|
[&status, &n](const Status& s) {
|
||||||
|
status = s;
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
TF_ASSERT_OK(status);
|
||||||
|
CheckOutputsAndClose(op_id);
|
||||||
|
}
|
||||||
|
|
||||||
// Test creates a context and attempts to send a tensor (using the RPC), and
|
// Test creates a context and attempts to send a tensor (using the RPC), and
|
||||||
// then use the tensor.
|
// then use the tensor.
|
||||||
TEST_F(EagerServiceImplTest, SendTensorTest) {
|
TEST_F(EagerServiceImplTest, SendTensorTest) {
|
||||||
|
@ -134,6 +134,7 @@ class GrpcEagerClient : public EagerClient {
|
|||||||
CLIENT_METHOD(UpdateContext);
|
CLIENT_METHOD(UpdateContext);
|
||||||
CLIENT_METHOD(Enqueue);
|
CLIENT_METHOD(Enqueue);
|
||||||
CLIENT_METHOD(WaitQueueDone);
|
CLIENT_METHOD(WaitQueueDone);
|
||||||
|
CLIENT_METHOD(RunComponentFunction);
|
||||||
CLIENT_METHOD(KeepAlive);
|
CLIENT_METHOD(KeepAlive);
|
||||||
|
|
||||||
#undef CLIENT_METHOD
|
#undef CLIENT_METHOD
|
||||||
|
@ -52,6 +52,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
|
|||||||
ENQUEUE_REQUEST(UpdateContext);
|
ENQUEUE_REQUEST(UpdateContext);
|
||||||
ENQUEUE_REQUEST(Enqueue);
|
ENQUEUE_REQUEST(Enqueue);
|
||||||
ENQUEUE_REQUEST(WaitQueueDone);
|
ENQUEUE_REQUEST(WaitQueueDone);
|
||||||
|
ENQUEUE_REQUEST(RunComponentFunction);
|
||||||
ENQUEUE_REQUEST(KeepAlive);
|
ENQUEUE_REQUEST(KeepAlive);
|
||||||
ENQUEUE_REQUEST(CloseContext);
|
ENQUEUE_REQUEST(CloseContext);
|
||||||
#undef ENQUEUE_REQUEST
|
#undef ENQUEUE_REQUEST
|
||||||
|
@ -72,6 +72,23 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
|
|||||||
HANDLER(CloseContext);
|
HANDLER(CloseContext);
|
||||||
#undef HANDLER
|
#undef HANDLER
|
||||||
|
|
||||||
|
// Serves a single RunComponentFunction call and then enqueues a request for
// the next one.
//
// The call is not executed inline: a closure is scheduled on
// env_->compute_pool, which invokes local_impl_.RunComponentFunction with the
// call's request/response and a completion callback that sends the status
// (converted with ToGrpcStatus) back on `call`.
void RunComponentFunctionHandler(
    EagerCall<RunComponentFunctionRequest, RunComponentFunctionResponse>*
        call) {
  using ComponentFunctionCall =
      Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
           RunComponentFunctionRequest, RunComponentFunctionResponse>;

  // Work item executed on the compute pool for this call.
  auto run_on_pool = [this, call]() {
    // Completion callback: reply to the caller with the final status.
    auto reply = [call](const Status& status) {
      call->SendResponse(ToGrpcStatus(status));
    };
    local_impl_.RunComponentFunction(&call->request, &call->response, reply);
  };
  env_->compute_pool->Schedule(run_on_pool);

  // Enqueue a request for the next RunComponentFunction call, with this
  // method registered as its handler. Cancellation support is disabled.
  ComponentFunctionCall::EnqueueRequest(
      &service_, cq_.get(),
      &grpc::EagerService::AsyncService::RequestRunComponentFunction,
      &GrpcEagerServiceImpl::RunComponentFunctionHandler,
      /*supports_cancel=*/false);
}
|
||||||
|
|
||||||
// Called when a new request has been received as part of a StreamingEnqueue
|
// Called when a new request has been received as part of a StreamingEnqueue
|
||||||
// call.
|
// call.
|
||||||
// StreamingEnqueueHandler gets the request from the `call` and fills the
|
// StreamingEnqueueHandler gets the request from the `call` and fills the
|
||||||
|
@ -173,6 +173,18 @@ message WaitQueueDoneResponse {
|
|||||||
// propagate some stats.
|
// propagate some stats.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Request message of the EagerService.RunComponentFunction RPC.
message RunComponentFunctionRequest {
|
||||||
|
// NOTE(review): presumably identifies the eager context on the remote server
// in which the operation should run -- confirm against the service impl.
fixed64 context_id = 1;
|
||||||
|
|
||||||
|
// The operation to execute. Per the RunComponentFunction RPC comment, this
// request type should only be used for executing component functions.
Operation operation = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response message of the EagerService.RunComponentFunction RPC.
message RunComponentFunctionResponse {
|
||||||
|
// NOTE(review): presumably the shapes of the executed function's outputs --
// confirm against the service implementation.
repeated TensorShapeProto shape = 1;
|
||||||
|
|
||||||
|
// NOTE(review): presumably the output tensor values returned to the caller --
// confirm against the service implementation.
repeated TensorProto tensor = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message KeepAliveRequest {
|
message KeepAliveRequest {
|
||||||
fixed64 context_id = 1;
|
fixed64 context_id = 1;
|
||||||
}
|
}
|
||||||
@ -272,6 +284,22 @@ service EagerService {
|
|||||||
// in the stream so far.
|
// in the stream so far.
|
||||||
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);
|
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);
|
||||||
|
|
||||||
|
// This takes an Eager operation and executes it in async mode on the remote
|
||||||
|
// server. Different from EnqueueRequest, ops/functions sent through this
|
||||||
|
// type of requests are allowed to execute in parallel and no ordering is
|
||||||
|
// preserved by RPC stream or executor.
|
||||||
|
// This request type should only be used for executing component functions.
|
||||||
|
// Ordering of component functions should be enforced by their corresponding
|
||||||
|
// main functions. The runtime ensures the following invariants for component
|
||||||
|
// functions (CFs) and their main functions (MFs):
|
||||||
|
// (1) MF1 -> MF2 ==> CF1 -> CF2 ("->" indicates order of execution);
|
||||||
|
// (2) MF1 || MF2 ==> CF1 || CF2 ("||" indicates possible parallel execution);
|
||||||
|
// (3) For CF1 and CF2 that come from the same MF, CF1 || CF2
|
||||||
|
// For executing ops/main functions, use Enqueue or StreamingEnqueue instead
|
||||||
|
// for correct ordering.
|
||||||
|
rpc RunComponentFunction(RunComponentFunctionRequest)
|
||||||
|
returns (RunComponentFunctionResponse);
|
||||||
|
|
||||||
// Contexts are always created with a deadline and no RPCs within a deadline
|
// Contexts are always created with a deadline and no RPCs within a deadline
|
||||||
// will trigger a context garbage collection. KeepAlive calls can be used to
|
// will trigger a context garbage collection. KeepAlive calls can be used to
|
||||||
// delay this. It can also be used to validate the existence of a context ID
|
// delay this. It can also be used to validate the existence of a context ID
|
||||||
|
Loading…
Reference in New Issue
Block a user