From 448f351cfe5f009d028be71fe34705a16e339898 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Tue, 28 Apr 2020 18:38:57 -0700 Subject: [PATCH] Introduce non-blocking component function execution. The current implementation of KernelAndDeviceFunc::Run ties up a thread while the function execution is pending. This leads to distributed deadlock issues in large-scale parameter server training when the number of workers exceed the thread pool size. This change leverages the RunComponentFunction request to execute component functions in a non-blocking manner. By avoiding the thread tying issue, it removes the constraint on the number of concurrent component functions to execute in parallel. PiperOrigin-RevId: 308939721 Change-Id: I086f9ee587c4df76b303158f27c362a9bcb8314c --- .../core/common_runtime/eager/execute.cc | 128 +++++++++++++++++- .../core/common_runtime/eager/execute.h | 21 +++ .../common_runtime/eager/kernel_and_device.cc | 74 +++++----- .../common_runtime/eager/kernel_and_device.h | 30 ++++ .../eager/cluster_function_library_runtime.cc | 40 ++---- .../distributed_runtime/eager/eager_client.h | 1 + .../eager/eager_service_impl.cc | 86 ++++++++++-- .../eager/eager_service_impl.h | 4 + .../eager/eager_service_impl_test.cc | 60 +++++++- .../rpc/eager/grpc_eager_client.cc | 1 + .../rpc/eager/grpc_eager_service_impl.cc | 1 + .../rpc/eager/grpc_eager_service_impl.h | 17 +++ tensorflow/core/protobuf/eager_service.proto | 28 ++++ 13 files changed, 416 insertions(+), 75 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index b62f5ac63d9..a4e7f2f3304 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -550,6 +550,15 @@ Status GetOrCreateKernelAndDevice( } } } + + int num_outputs = kernel->num_outputs(); + if (num_outputs > *num_retvals) { + return errors::InvalidArgument("Expecting ", num_outputs, + " outputs, but *num_retvals is 
", + *num_retvals); + } + *num_retvals = num_outputs; + kernel->Ref(); // Ownership of reference is passed to out_kernel. out_kernel->reset(kernel.get()); return Status::OK(); @@ -587,12 +596,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel)); int num_outputs = kernel->num_outputs(); - if (num_outputs > *num_retvals) { - return errors::InvalidArgument("Expecting ", num_outputs, - " outputs, but *num_retvals is ", - *num_retvals); - } - *num_retvals = num_outputs; TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel)); GraphCollector* graph_collector = nullptr; @@ -1282,4 +1285,117 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, } } +namespace { +// Low-level utility function to execute the kernel specified by `kernel` on +// `kernel->device()`, with the provided inputs as `op_inputs` in the 'ctx'. +// Different from `EagerKernelExecute` that ties up the thread until the +// underlying function finishes execute, this function does not block the thread +// and could return before the function execution finishes. The provided +// `StatusCallback` will be triggered after function execution with its status. 
+void EagerKernelExecuteAsync( + EagerContext* ctx, const absl::InlinedVector& op_inputs, + const absl::optional& remote_func_params, + const core::RefCountPtr kernel, + GraphCollector* graph_collector, CancellationManager* cancellation_manager, + TensorHandle** retvals, int num_outputs, StatusCallback done) { + auto inputs = std::make_shared(op_inputs.size()); + auto outputs = std::make_shared>(1); + + Status s = inputs->Init(ctx, op_inputs, kernel); + if (!s.ok()) { + done(s); + return; + } + + kernel->Ref(); // Ownership of reference is transferred to the callback + kernel->RunAsync( + ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager, + remote_func_params, + [retvals, inputs, outputs, num_outputs, ctx, graph_collector, + kernel_raw = kernel.get(), done = std::move(done)](const Status& s) { + auto wrapped_done = [&](const Status& s) { + kernel_raw->Unref(); + done(s); + }; + if (!s.ok()) { + wrapped_done(s); + return; + } + if (graph_collector != nullptr) { + CollectGraphs(ctx); + } + DCHECK_EQ(num_outputs, outputs->size()); + wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx, + kernel_raw)); + }); +} +} // namespace + +// Low-level utility to run the eager operation on local devices. Different from +// `EagerLocalExecute` which blocks and waits for the op execution to finish, +// this method does not block the thread and could return before the +// eager operation execution finishes. The provided `StatusCallback` will be +// triggered after execution with its status. 
+void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, + int* num_retvals, StatusCallback done) { + if (VariantDeviceIsCustom(op->Device())) { + done(errors::Unimplemented( + "Custom device is not supported in EagerLocalExecuteAsync.")); + return; + } + if (!op->IsLocal()) { + done(errors::InvalidArgument( + "Remote execution is not supported in async EagerLocalExecuteAsync")); + return; + } + + ScopedMemoryDebugAnnotation op_annotation( + op->op_name(), op->remote_func_params().has_value() + ? op->remote_func_params().value().step_id.value_or(0) + : 0); + profiler::TraceMe activity( + [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); }, + profiler::TraceMeLevel::kInfo); + EagerContext& ctx = op->EagerContext(); + + core::RefCountPtr kernel; + Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); + if (!s.ok()) { + done(s); + return; + } + + int num_outputs = kernel->num_outputs(); + s = ValidateInputTypeAndPlacement(&ctx, op, kernel); + if (!s.ok()) { + done(s); + return; + } + + GraphCollector* graph_collector = nullptr; + if (ctx.ShouldStoreGraphs()) { + graph_collector = ctx.GetGraphCollector(); + } + + for (int i = 0; i < num_outputs; ++i) { + retvals[i] = nullptr; + } + + EagerKernelExecuteAsync( + &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), + graph_collector, op->GetCancellationManager(), retvals, num_outputs, + [op, num_outputs, retvals, done = std::move(done)](const Status& s) { + op->Clear(); + // Since the operation failed, we need to Unref any outputs if they were + // allocated. 
+ if (!s.ok()) { + for (int i = 0; i < num_outputs; ++i) { + if (retvals[i] != nullptr) { + retvals[i]->Unref(); + } + } + } + done(s); + }); +} } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.h b/tensorflow/core/common_runtime/eager/execute.h index 8ed8b9555e3..2224981db94 100644 --- a/tensorflow/core/common_runtime/eager/execute.h +++ b/tensorflow/core/common_runtime/eager/execute.h @@ -63,6 +63,27 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, EagerExecutor* executor, Device* device, bool mirror, TensorHandle** result); +// Utility function that executes a fully constructed EagerOperation +// asynchronously on the local task. This function works differently from +// EagerExecute in several ways: +// - It supports local execution only. +// - It returns after launching the eager operation to run asynchronously. +// Different from EagerExecute with async context that appends the operation +// to the end of the eager executor schedule queue, this call bypasses the +// executor logic and directly launches op execution. Ops running through +// this call do NOT have an ordering and can be executed in parallel. +// - It takes a StatusCallback which will be triggered after execution with the +// execution status. +// +// Does not support custom device. +// +// 'retvals' must point to a pre-allocated array of TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. 
+void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, + int* num_retvals, StatusCallback done); + } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_ diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 1eabd5c7eee..c9ff9e506b8 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -302,21 +302,37 @@ Status KernelAndDeviceFunc::Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, const absl::optional& remote_func_params) { - std::unique_ptr opts = nullptr; + Notification n; + Status status; + RunAsync(step_container, inputs, outputs, cancellation_manager, + remote_func_params, [&status, &n](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + return status; +} + +void KernelAndDeviceFunc::RunAsync( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, CancellationManager* cancellation_manager, + const absl::optional& remote_func_params, + std::function done) { + std::shared_ptr opts = nullptr; if (remote_func_params.has_value()) { const EagerRemoteFunctionParams& params = remote_func_params.value(); if (params.step_id.has_value()) { // If the function is a remote component of a cross-process function, // re-use the step id as its parent function's. - opts = absl::make_unique( + opts = std::make_shared( params.step_id.value()); } else { - opts = absl::make_unique(); + opts = std::make_shared(); } // Reuse the op id if it exists. opts->op_id = params.op_id; } else { - opts = absl::make_unique(); + opts = std::make_shared(); if (get_op_id_ && is_cross_process_) { // If the function is a cross-process function and the remote execution // goes through eager service, create an eager op id for the function. 
@@ -331,49 +347,43 @@ Status KernelAndDeviceFunc::Run( opts->rendezvous = rendezvous; opts->create_rendezvous = false; - CancellationManager cm; + // Create a cancellation manager to be used by FLR options if caller does not + // pass in one. If the caller does provide one, pass it to process FLR and the + // locally created one will be unused. + std::shared_ptr local_cm; if (cancellation_manager) { opts->cancellation_manager = cancellation_manager; } else { - opts->cancellation_manager = &cm; + local_cm = std::make_shared(); + opts->cancellation_manager = local_cm.get(); } opts->allow_dead_tensors = true; - opts->step_container = step_container == nullptr ? &step_container_ : step_container; - auto step_container_cleanup = gtl::MakeCleanup([step_container, this] { - if (step_container == nullptr) { - this->step_container_.CleanUp(); - } - }); - opts->collective_executor = collective_executor_ ? collective_executor_->get() : nullptr; opts->stats_collector = nullptr; opts->runner = get_runner(); - Notification done; - Status status; outputs->clear(); - { - profiler::TraceMe activity( - [&] { - return absl::StrCat("FunctionRun#name=", name(), - ",id=", opts->step_id, "#"); - }, - profiler::TraceMeLevel::kInfo); - pflr_->Run(*opts, handle_, inputs, outputs, - [&status, &done](const Status& s) { - status = s; - done.Notify(); - }); - done.WaitForNotification(); - } - - rendezvous->Unref(); - return status; + profiler::TraceMe* activity = new profiler::TraceMe( + [&] { + return absl::StrCat("FunctionRun#name=", name(), ",id=", opts->step_id, + "#"); + }, + profiler::TraceMeLevel::kInfo); + pflr_->Run(*opts, handle_, inputs, outputs, + [opts, rendezvous, local_cm, step_container, this, activity, + done = std::move(done)](const Status& s) { + delete activity; + rendezvous->Unref(); + if (step_container == nullptr) { + this->step_container_.CleanUp(); + } + done(s); + }); } tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const { diff --git 
a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 641e4e79a6b..0597dc0aa2e 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -124,6 +124,20 @@ class KernelAndDevice : public core::RefCounted { std::vector* outputs, CancellationManager* cancellation_manager, const absl::optional& remote_func_params) = 0; + // Execute kernel asynchronously when applicable. Different from `Run` which + // blocks the caller thread and waits for the execution of the op/function, + // `RunAsync` could return before finishing the execution. The `done` callback + // will be triggered once the op/function execution finishes. + // Currently, calling RunAsync on ops might not honor the asynchronicity when + // it is called on an instance with only sync implementation, execute the + // kernel synchronously and then call the callback with the return status + // from sync execution. + virtual void RunAsync( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, CancellationManager* cancellation_manager, + const absl::optional& remote_func_params, + StatusCallback done) = 0; + virtual Device* InputDevice(int i) const = 0; virtual Device* OutputDevice(int idx) const = 0; // If idx'th output is a resource, returns the device backing the resource. 
@@ -187,6 +201,16 @@ class KernelAndDeviceOp final : public KernelAndDevice { const absl::optional& remote_func_params) override; + void RunAsync( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, CancellationManager* cancellation_manager, + const absl::optional& remote_func_params, + StatusCallback done) override { + // Trivial async implementation on top of the sync version + done(Run(step_container, inputs, outputs, cancellation_manager, + remote_func_params)); + } + const OpKernel* kernel() const override { return kernel_.get(); } Device* InputDevice(int i) const override; @@ -265,6 +289,12 @@ class KernelAndDeviceFunc final : public KernelAndDevice { const absl::optional& remote_func_params) override; + void RunAsync( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, CancellationManager* cancellation_manager, + const absl::optional& remote_func_params, + StatusCallback done) override; + const OpKernel* kernel() const override { return nullptr; } Device* InputDevice(int i) const override; diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index 10924857ac5..ec129173833 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -159,9 +159,10 @@ void EagerClusterFunctionLibraryRuntime::Run( return; } - eager::EnqueueRequest* request = new eager::EnqueueRequest; + auto request = std::make_shared(); + auto response = std::make_shared(); request->set_context_id(context_id_); - eager::Operation* remote_op = request->add_queue()->mutable_operation(); + eager::Operation* remote_op = request->mutable_operation(); for (const auto& arg : args) { if (arg.index() == 0) { @@ -188,39 +189,28 @@ void EagerClusterFunctionLibraryRuntime::Run( 
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); remote_op->set_device(function_data->target); - // StreamingEnqueueAsync may introduce a deadlock. When streaming RPC is - // disabled, Run() returns when the remote function execution completes, which - // might be blocked by a non-enqueued function execution. - EnqueueResponse* response = new EnqueueResponse; - eager_client->EnqueueAsync( - request, response, + // Execute component function on remote worker using RunComponentFunction RPC. + // Different from executing remote functions with Enqueue, this method runs + // a function on remote worker without tying up a thread (i.e., pure + // asynchronously). + eager_client->RunComponentFunctionAsync( + request.get(), response.get(), [request, response, rets, done = std::move(done)](const Status& s) { - Status status = s; - auto cleanup = gtl::MakeCleanup([request, response, &status, &done] { - done(status); - delete request; - delete response; - }); - - if (!status.ok()) { + if (!s.ok()) { + done(s); return; } - if (response->queue_response_size() != 1) { - status.Update(errors::Internal( - "Expect that the size of response queue equals 1, but got: ", - response->queue_response_size())); - return; - } - for (const auto& tensor_proto : response->queue_response(0).tensor()) { + for (const auto& tensor_proto : response->tensor()) { Tensor t; if (t.FromProto(tensor_proto)) { rets->push_back(std::move(t)); } else { - status.Update(errors::Internal("Could not convert tensor proto: ", - tensor_proto.DebugString())); + done(errors::Internal("Could not convert tensor proto: ", + tensor_proto.DebugString())); return; } } + done(Status::OK()); }); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 5f260e477d6..9ca802d8a72 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -38,6 +38,7 @@ class 
EagerClient : public core::RefCounted { CLIENT_METHOD(UpdateContext); CLIENT_METHOD(Enqueue); CLIENT_METHOD(WaitQueueDone); + CLIENT_METHOD(RunComponentFunction); CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index d4cd11e12ad..95131150d3d 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -92,10 +92,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, return Status::OK(); } -Status GetEagerOperation(const Operation& operation, - EagerContext* eager_context, - EagerExecutor* eager_executor, - EagerOperation* eager_op) { +Status GetEagerOperationAndNumRetvals(const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + EagerOperation* eager_op, + int* num_retvals) { const char* name = operation.name().c_str(); // Shorthand absl::optional remote_func_params = absl::nullopt; @@ -138,7 +139,10 @@ Status GetEagerOperation(const Operation& operation, for (const auto& attr : operation.attrs()) { eager_op->MutableAttrs()->Set(attr.first, attr.second); } - return Status::OK(); + + // TODO(nareshmodi): Consider caching this. 
+ return GetNumRetvals(eager_context, operation.name(), operation.attrs(), + num_retvals); } Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { @@ -406,18 +410,83 @@ Status EagerServiceImpl::CreateMasterContext( return Status::OK(); } +void EagerServiceImpl::RunComponentFunction( + const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, StatusCallback done) { + ServerContext* context = nullptr; + Status s = GetServerContext(request->context_id(), &context); + if (!s.ok()) { + done(s); + return; + } + core::ScopedUnref context_unref(context); + + auto& operation = request->operation(); + // This codepath should only be triggered for executing component function + if (!operation.is_function() || !operation.is_component_function()) { + done(errors::Internal( + "RunComponentFunction request can only be used to execute " + "component functions.")); + return; + } + + EagerContext* eager_context = context->Context(); + EagerExecutor* eager_executor = &eager_context->Executor(); + + EagerOperation* op = new EagerOperation(eager_context); + int* num_retvals = new int(0); + s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor, + op, num_retvals); + if (!s.ok()) { + // `op` and `num_retvals` are not yet owned by the callback; free them. + delete op; + delete num_retvals; + done(s); + return; + } + if (!op->IsLocal()) { + delete op; + delete num_retvals; + done(errors::Internal( + "Received RunComponentFunction request with remote function device. 
")); + return; + } + + auto* retvals = new absl::FixedArray(*num_retvals); + VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op " + << operation.id(); + + context->Ref(); + EagerLocalExecuteAsync( + op, retvals->data(), num_retvals, + [op, op_id = operation.id(), num_retvals, retvals, response, + eager_context, context, done = std::move(done)](const Status& status) { + auto wrapped_done = [&](const Status& status) { + context->Unref(); + done(status); + delete op; + delete num_retvals; + delete retvals; + }; + if (!status.ok()) { + wrapped_done(status); + return; + } + wrapped_done(AddOpRetvalsToResponse( + eager_context, op_id, *num_retvals, retvals->data(), + [response] { return response->add_tensor(); }, + [response] { return response->add_shape(); })); + }); +} + Status EagerServiceImpl::ExecuteOp(const Operation& operation, EagerContext* eager_context, EagerExecutor* eager_executor, QueueResponse* queue_response) { tensorflow::EagerOperation op(eager_context); - TF_RETURN_IF_ERROR( - GetEagerOperation(operation, eager_context, eager_executor, &op)); - int num_retvals = 0; - // TODO(nareshmodi): Consider caching this. 
- TF_RETURN_IF_ERROR(GetNumRetvals(eager_context, operation.name(), - operation.attrs(), &num_retvals)); + TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals( + operation, eager_context, eager_executor, &op, &num_retvals)); absl::FixedArray retvals(num_retvals); VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 113b4ccad79..06d4c36b61c 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -96,6 +96,10 @@ class EagerServiceImpl { Status WaitQueueDone(const WaitQueueDoneRequest* request, WaitQueueDoneResponse* response); + void RunComponentFunction(const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, + StatusCallback done); + Status KeepAlive(const KeepAliveRequest* request, KeepAliveResponse* response); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 50f822bf468..76ca5c318fb 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -90,6 +90,12 @@ class FakeEagerClient : public EagerClient { CLIENT_METHOD(CloseContext); #undef CLIENT_METHOD + void RunComponentFunctionAsync(const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, + StatusCallback done) override { + impl_->RunComponentFunction(request, response, std::move(done)); + } + void StreamingEnqueueAsync(const EnqueueRequest* request, EnqueueResponse* response, StatusCallback done) override { @@ -702,7 +708,7 @@ TEST_F(FunctionWithRemoteInputsTest, CheckOutputTensorAndClose(outputs.at(0)); } -// Test executes a remote function through KernelAndDeviceFunc. 
+// Test executes a remote function through KernelAndDeviceFunc::Run. TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { Init(); Device* local_device; @@ -747,6 +753,58 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { CheckOutputsAndClose(op_id); } +// Test executes a remote function through KernelAndDeviceFunc::RunAsync. +TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { + Init(); + Device* local_device; + TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device)); + std::vector input_dev_ptrs; + input_dev_ptrs.push_back(local_device); + FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_); + EagerContext* ctx = nullptr; + TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx)); + core::RefCountPtr kernel = nullptr; + const int64 op_id = 2; + kernel.reset(new KernelAndDeviceFunc( + flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr, + /*collective_executor=*/nullptr, local_device, fdef_.signature().name(), + [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); }, + [=]() { return op_id; })); + + // Instantiate MatMulFunction on remote_device. + const NodeDef node_def = MatMulFunctionNodeDef(); + TF_ASSERT_OK(kernel->InstantiateFunc(node_def, nullptr)); + + // Run MatMulFunction on remote_device. 
+ gtl::InlinedVector input_tensors = {TensorValue()}; + RemoteTensorHandle input; + input.set_op_id(1); + input.set_output_num(0); + input.set_op_device(local_device_); + input.set_device(local_device_); + std::vector remote_handles = {input}; + TestExecuteNodeArgs inputs( + std::move(input_tensors), + [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status { + *handle = remote_handles.at(index); + return Status::OK(); + }); + std::vector outputs; + + Status status; + Notification n; + kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs, + /*cancellation_manager=*/nullptr, + /*remote_func_params=*/absl::nullopt, + [&status, &n](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_ASSERT_OK(status); + CheckOutputsAndClose(op_id); +} + // Test creates a context and attempts to send a tensor (using the RPC), and // then use the tensor. TEST_F(EagerServiceImplTest, SendTensorTest) { diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index cbbd76b42ad..de4f36ea24d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -134,6 +134,7 @@ class GrpcEagerClient : public EagerClient { CLIENT_METHOD(UpdateContext); CLIENT_METHOD(Enqueue); CLIENT_METHOD(WaitQueueDone); + CLIENT_METHOD(RunComponentFunction); CLIENT_METHOD(KeepAlive); #undef CLIENT_METHOD diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index c74c648b985..27c3f30c9ab 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -52,6 +52,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(UpdateContext); ENQUEUE_REQUEST(Enqueue); 
ENQUEUE_REQUEST(WaitQueueDone); + ENQUEUE_REQUEST(RunComponentFunction); ENQUEUE_REQUEST(KeepAlive); ENQUEUE_REQUEST(CloseContext); #undef ENQUEUE_REQUEST diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index 167a4cf2703..d95589704b1 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -72,6 +72,23 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface { HANDLER(CloseContext); #undef HANDLER + void RunComponentFunctionHandler( + EagerCall* + call) { + env_->compute_pool->Schedule([this, call]() { + local_impl_.RunComponentFunction( + &call->request, &call->response, + [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); }); + }); + Call:: + EnqueueRequest( + &service_, cq_.get(), + &grpc::EagerService::AsyncService::RequestRunComponentFunction, + &GrpcEagerServiceImpl::RunComponentFunctionHandler, + /*supports_cancel=*/false); + } + // Called when a new request has been received as part of a StreamingEnqueue // call. // StreamingEnqueueHandler gets the request from the `call` and fills the diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 7be7199f10c..e9e21777d3f 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -173,6 +173,18 @@ message WaitQueueDoneResponse { // propagate some stats. } +message RunComponentFunctionRequest { + fixed64 context_id = 1; + + Operation operation = 2; +} + +message RunComponentFunctionResponse { + repeated TensorShapeProto shape = 1; + + repeated TensorProto tensor = 2; +} + message KeepAliveRequest { fixed64 context_id = 1; } @@ -272,6 +284,22 @@ service EagerService { // in the stream so far. 
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse); + // This takes an Eager operation and executes it in async mode on the remote + // server. Different from EnqueueRequest, ops/functions sent through this + // type of requests are allowed to execute in parallel and no ordering is + // preserved by RPC stream or executor. + // This request type should only be used for executing component functions. + // Ordering of component functions should be enforced by their corresponding + // main functions. The runtime ensures the following invariants for component + // functions (CFs) and their main functions (MFs): + // (1) MF1 -> MF2 ==> CF1 -> CF2 ("->" indicates order of execution); + // (2) MF1 || MF2 ==> CF1 || CF2 ("||" indicates possible parallel execution); + // (3) For CF1 and CF2 that come from the same MF, CF1 || CF2 + // For executing ops/main functions, use Enqueue or StreamingEnqueue instead + // for correct ordering. + rpc RunComponentFunction(RunComponentFunctionRequest) + returns (RunComponentFunctionResponse); + // Contexts are always created with a deadline and no RPCs within a deadline // will trigger a context garbage collection. KeepAlive calls can be used to // delay this. It can also be used to validate the existence of a context ID