Merge register function logic into EagerService.Enqueue. Then we can execute RegisterFunction asynchronously via StreamingEnqueue.
PiperOrigin-RevId: 272528172
This commit is contained in:
parent
fe15ce0d73
commit
7bdc261c65
@ -44,7 +44,6 @@ EAGER_CLIENT_METHOD(Enqueue);
|
||||
EAGER_CLIENT_METHOD(WaitQueueDone);
|
||||
EAGER_CLIENT_METHOD(KeepAlive);
|
||||
EAGER_CLIENT_METHOD(CloseContext);
|
||||
EAGER_CLIENT_METHOD(RegisterFunction);
|
||||
#undef EAGER_CLIENT_METHOD
|
||||
|
||||
#define WORKER_CLIENT_METHOD(method) \
|
||||
|
@ -69,10 +69,6 @@ class XrtGrpcEagerClient {
|
||||
void CloseContextAsync(const eager::CloseContextRequest* request,
|
||||
eager::CloseContextResponse* response,
|
||||
StatusCallback done, CallOptions* call_opts = nullptr);
|
||||
void RegisterFunctionAsync(const eager::RegisterFunctionRequest* request,
|
||||
eager::RegisterFunctionResponse* response,
|
||||
StatusCallback done,
|
||||
CallOptions* call_opts = nullptr);
|
||||
|
||||
// The following two methods are actually from the WorkerService API, not
|
||||
// EagerService, but are necessary for using remote Eager, and we include them
|
||||
|
@ -381,14 +381,15 @@ std::shared_ptr<XrtRecvTensorFuture> XrtTfContext::RecvTensor(
|
||||
}
|
||||
|
||||
Status XrtTfContext::RegisterFunction(const FunctionDef& def) {
|
||||
eager::RegisterFunctionRequest request;
|
||||
eager::EnqueueRequest request;
|
||||
request.set_context_id(context_id_);
|
||||
*request.mutable_function_def() = def;
|
||||
auto* register_function = request.add_queue()->mutable_register_function();
|
||||
*register_function->mutable_function_def() = def;
|
||||
|
||||
eager::RegisterFunctionResponse response;
|
||||
eager::EnqueueResponse response;
|
||||
Status status;
|
||||
absl::Notification done;
|
||||
eager_client_->RegisterFunctionAsync(&request, &response, [&](Status s) {
|
||||
eager_client_->EnqueueAsync(&request, &response, [&](Status s) {
|
||||
status = s;
|
||||
done.Notify();
|
||||
});
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
// Required for IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
@ -398,39 +399,31 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
|
||||
// Only client context can register function on remote worker context.
|
||||
if (remote_device_manager_ == nullptr) return Status::OK();
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
|
||||
std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
|
||||
request->set_context_id(GetContextId());
|
||||
|
||||
eager::RegisterFunctionRequest request;
|
||||
request.set_context_id(GetContextId());
|
||||
*request.mutable_function_def() = fdef;
|
||||
StripDefaultAttributes(*OpRegistry::Global(),
|
||||
request.mutable_function_def()->mutable_node_def());
|
||||
std::vector<eager::RegisterFunctionResponse> responses(
|
||||
remote_contexts_.size());
|
||||
std::vector<Status> statuses(remote_contexts_.size());
|
||||
eager::RegisterFunctionOp* register_function =
|
||||
request->add_queue()->mutable_register_function();
|
||||
*register_function->mutable_function_def() = fdef;
|
||||
StripDefaultAttributes(
|
||||
*OpRegistry::Global(),
|
||||
register_function->mutable_function_def()->mutable_node_def());
|
||||
|
||||
int i = 0;
|
||||
for (const auto& target : remote_contexts_) {
|
||||
eager::EagerClient* eager_client;
|
||||
statuses[i] = remote_eager_workers_->GetClient(target, &eager_client);
|
||||
if (!statuses[i].ok()) {
|
||||
blocking_counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
|
||||
|
||||
eager_client->RegisterFunctionAsync(
|
||||
&request, &responses[i],
|
||||
[i, &statuses, &blocking_counter](const Status& status) {
|
||||
statuses[i] = status;
|
||||
blocking_counter.DecrementCount();
|
||||
eager::EnqueueResponse* response = new eager::EnqueueResponse();
|
||||
eager_client->StreamingEnqueueAsync(
|
||||
request.get(), response, [request, response](const Status& status) {
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Failed to register function remotely due to "
|
||||
<< status.error_message()
|
||||
<< "\nThis shouldn't happen, please file a bug to "
|
||||
"tensorflow team.";
|
||||
}
|
||||
delete response;
|
||||
});
|
||||
|
||||
i++;
|
||||
}
|
||||
blocking_counter.Wait();
|
||||
|
||||
for (int i = 0; i < remote_contexts_.size(); i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return Status::OK();
|
||||
@ -441,40 +434,34 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
|
||||
const std::vector<string>& remote_workers) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// Register multiple functions on selected remote workers.
|
||||
int num_requests = function_defs.size() * remote_workers.size();
|
||||
BlockingCounter counter(num_requests);
|
||||
std::vector<Status> statuses(num_requests);
|
||||
uint64 context_id = GetContextId();
|
||||
for (int i = 0; i < remote_workers.size(); i++) {
|
||||
eager::EagerClient* eager_client;
|
||||
Status s =
|
||||
remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
|
||||
if (!s.ok()) {
|
||||
for (int j = 0; j < function_defs.size(); j++) {
|
||||
statuses[i * function_defs.size() + j] = s;
|
||||
counter.DecrementCount();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < function_defs.size(); j++) {
|
||||
eager::RegisterFunctionRequest request;
|
||||
request.set_context_id(context_id);
|
||||
*request.mutable_function_def() = *function_defs[j];
|
||||
auto* response = new eager::RegisterFunctionResponse();
|
||||
int request_idx = i * function_defs.size() + j;
|
||||
eager_client->RegisterFunctionAsync(
|
||||
&request, response,
|
||||
[request_idx, &statuses, response, &counter](const Status& s) {
|
||||
statuses[request_idx] = s;
|
||||
auto* request = new eager::EnqueueRequest;
|
||||
request->set_context_id(context_id);
|
||||
eager::RegisterFunctionOp* register_function =
|
||||
request->add_queue()->mutable_register_function();
|
||||
*register_function->mutable_function_def() = *function_defs[j];
|
||||
auto* response = new eager::EnqueueResponse;
|
||||
eager_client->StreamingEnqueueAsync(
|
||||
request, response, [request, response](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Failed to register function remotely due to "
|
||||
<< s.error_message()
|
||||
<< "\nThis shouldn't happen, please file a bug to "
|
||||
"tensorflow team.";
|
||||
}
|
||||
delete request;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_requests; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -58,21 +58,30 @@ Status EagerClusterFunctionLibraryRuntime::Instantiate(
|
||||
const FunctionLibraryDefinition& func_lib_def =
|
||||
options.lib_def ? *options.lib_def : lib_def;
|
||||
|
||||
RegisterFunctionRequest request;
|
||||
EnqueueRequest* request = new EnqueueRequest;
|
||||
EnqueueResponse* response = new EnqueueResponse;
|
||||
|
||||
const uint64 context_id = ctx_->GetContextId();
|
||||
request.set_context_id(context_id);
|
||||
request->set_context_id(context_id);
|
||||
|
||||
RegisterFunctionOp* register_function =
|
||||
request->add_queue()->mutable_register_function();
|
||||
// TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to
|
||||
// support nested functions.
|
||||
*request.mutable_function_def() = *func_lib_def.Find(function_name);
|
||||
request.set_is_component_function(true);
|
||||
*register_function->mutable_function_def() =
|
||||
*func_lib_def.Find(function_name);
|
||||
register_function->set_is_component_function(true);
|
||||
|
||||
Status status;
|
||||
Notification done;
|
||||
RegisterFunctionResponse response;
|
||||
eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) {
|
||||
status = s;
|
||||
done.Notify();
|
||||
});
|
||||
// TODO(yujingzhang): make this call async.
|
||||
eager_client->StreamingEnqueueAsync(
|
||||
request, response, [request, response, &status, &done](const Status& s) {
|
||||
status = s;
|
||||
delete request;
|
||||
delete response;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
|
@ -39,7 +39,6 @@ class EagerClient {
|
||||
CLIENT_METHOD(WaitQueueDone);
|
||||
CLIENT_METHOD(KeepAlive);
|
||||
CLIENT_METHOD(CloseContext);
|
||||
CLIENT_METHOD(RegisterFunction);
|
||||
|
||||
#undef CLIENT_METHOD
|
||||
|
||||
|
@ -386,8 +386,10 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
|
||||
auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
|
||||
context, std::move(handle_to_decref));
|
||||
s = context->Context()->Executor().AddOrExecute(std::move(node));
|
||||
} else {
|
||||
} else if (item.has_send_tensor()) {
|
||||
s = SendTensor(item.send_tensor(), context->Context());
|
||||
} else {
|
||||
s = RegisterFunction(item.register_function(), context->Context());
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
@ -449,16 +451,12 @@ Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
|
||||
}
|
||||
|
||||
Status EagerServiceImpl::RegisterFunction(
|
||||
const RegisterFunctionRequest* request,
|
||||
RegisterFunctionResponse* response) {
|
||||
ServerContext* context = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
|
||||
core::ScopedUnref context_unref(context);
|
||||
|
||||
const RegisterFunctionOp& register_function, EagerContext* eager_context) {
|
||||
// If the function is a component of a multi-device function, we only need to
|
||||
// register it locally.
|
||||
return context->Context()->AddFunctionDef(request->function_def(),
|
||||
request->is_component_function());
|
||||
return eager_context->AddFunctionDef(
|
||||
register_function.function_def(),
|
||||
register_function.is_component_function());
|
||||
}
|
||||
|
||||
Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
|
||||
|
@ -102,9 +102,6 @@ class EagerServiceImpl {
|
||||
Status CloseContext(const CloseContextRequest* request,
|
||||
CloseContextResponse* response);
|
||||
|
||||
Status RegisterFunction(const RegisterFunctionRequest* request,
|
||||
RegisterFunctionResponse* response);
|
||||
|
||||
protected:
|
||||
// This is the server-side execution context. All state regarding execution of
|
||||
// a client's ops is held in this server-side context (all generated tensors,
|
||||
@ -208,6 +205,8 @@ class EagerServiceImpl {
|
||||
QueueResponse* queue_response);
|
||||
Status SendTensor(const SendTensorOp& send_tensor,
|
||||
EagerContext* eager_context);
|
||||
Status RegisterFunction(const RegisterFunctionOp& register_function,
|
||||
EagerContext* eager_context);
|
||||
const WorkerEnv* const env_; // Not owned.
|
||||
|
||||
mutex contexts_mu_;
|
||||
|
@ -84,13 +84,12 @@ class FakeEagerClient : public EagerClient {
|
||||
CLIENT_METHOD(WaitQueueDone);
|
||||
CLIENT_METHOD(KeepAlive);
|
||||
CLIENT_METHOD(CloseContext);
|
||||
CLIENT_METHOD(RegisterFunction);
|
||||
#undef CLIENT_METHOD
|
||||
|
||||
void StreamingEnqueueAsync(const EnqueueRequest* request,
|
||||
EnqueueResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented(""));
|
||||
done(impl_->Enqueue(request, response));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -312,13 +311,14 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
|
||||
|
||||
RegisterFunctionRequest register_function_request;
|
||||
register_function_request.set_context_id(context_id);
|
||||
*register_function_request.mutable_function_def() = MatMulFunction();
|
||||
RegisterFunctionResponse register_function_response;
|
||||
EnqueueRequest enqueue_request;
|
||||
enqueue_request.set_context_id(context_id);
|
||||
RegisterFunctionOp* register_function =
|
||||
enqueue_request.add_queue()->mutable_register_function();
|
||||
*register_function->mutable_function_def() = MatMulFunction();
|
||||
EnqueueResponse enqueue_response;
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.RegisterFunction(
|
||||
&register_function_request, &register_function_response));
|
||||
TF_ASSERT_OK(eager_service_impl.Enqueue(&enqueue_request, &enqueue_response));
|
||||
|
||||
EnqueueRequest remote_enqueue_request;
|
||||
remote_enqueue_request.set_context_id(context_id);
|
||||
|
@ -82,7 +82,6 @@ class GrpcEagerClient : public EagerClient {
|
||||
CLIENT_METHOD(Enqueue);
|
||||
CLIENT_METHOD(WaitQueueDone);
|
||||
CLIENT_METHOD(KeepAlive);
|
||||
CLIENT_METHOD(RegisterFunction);
|
||||
|
||||
#undef CLIENT_METHOD
|
||||
|
||||
|
@ -54,7 +54,6 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
|
||||
ENQUEUE_REQUEST(WaitQueueDone);
|
||||
ENQUEUE_REQUEST(KeepAlive);
|
||||
ENQUEUE_REQUEST(CloseContext);
|
||||
ENQUEUE_REQUEST(RegisterFunction);
|
||||
#undef ENQUEUE_REQUEST
|
||||
|
||||
// Request a StreamingEnqueue call.
|
||||
|
@ -70,7 +70,6 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
|
||||
HANDLER(WaitQueueDone);
|
||||
HANDLER(KeepAlive);
|
||||
HANDLER(CloseContext);
|
||||
HANDLER(RegisterFunction);
|
||||
#undef HANDLER
|
||||
|
||||
// Called when a new request has been received as part of a StreamingEnqueue
|
||||
|
@ -39,6 +39,8 @@ message QueueItem {
|
||||
RemoteTensorHandle handle_to_decref = 1;
|
||||
Operation operation = 2;
|
||||
SendTensorOp send_tensor = 3;
|
||||
// Takes a FunctionDef and makes it enqueueable on the remote worker.
|
||||
RegisterFunctionOp register_function = 4;
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,18 +141,14 @@ message CloseContextRequest {
|
||||
|
||||
message CloseContextResponse {}
|
||||
|
||||
message RegisterFunctionRequest {
|
||||
fixed64 context_id = 1;
|
||||
|
||||
FunctionDef function_def = 2;
|
||||
message RegisterFunctionOp {
|
||||
FunctionDef function_def = 1;
|
||||
|
||||
// If true, it means that function_def is produced by graph partition during
|
||||
// multi-device function instantiation.
|
||||
bool is_component_function = 3;
|
||||
bool is_component_function = 2;
|
||||
}
|
||||
|
||||
message RegisterFunctionResponse {}
|
||||
|
||||
message SendTensorOp {
|
||||
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
||||
// situation when directly sending tensors, we include an "artificial" op ID
|
||||
@ -218,8 +216,4 @@ service EagerService {
|
||||
// Closes the context. No calls to other methods using the existing context ID
|
||||
// are valid after this.
|
||||
rpc CloseContext(CloseContextRequest) returns (CloseContextResponse);
|
||||
|
||||
// Takes a FunctionDef and makes it enqueueable on the remote worker.
|
||||
rpc RegisterFunction(RegisterFunctionRequest)
|
||||
returns (RegisterFunctionResponse);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user