Merge the register-function logic into EagerService.Enqueue, so that RegisterFunction can be executed asynchronously via StreamingEnqueue.

PiperOrigin-RevId: 272528172
This commit is contained in:
Xiao Yu 2019-10-02 15:04:44 -07:00 committed by TensorFlower Gardener
parent fe15ce0d73
commit 7bdc261c65
13 changed files with 80 additions and 101 deletions

View File

@ -44,7 +44,6 @@ EAGER_CLIENT_METHOD(Enqueue);
EAGER_CLIENT_METHOD(WaitQueueDone);
EAGER_CLIENT_METHOD(KeepAlive);
EAGER_CLIENT_METHOD(CloseContext);
EAGER_CLIENT_METHOD(RegisterFunction);
#undef EAGER_CLIENT_METHOD
#define WORKER_CLIENT_METHOD(method) \

View File

@ -69,10 +69,6 @@ class XrtGrpcEagerClient {
void CloseContextAsync(const eager::CloseContextRequest* request,
eager::CloseContextResponse* response,
StatusCallback done, CallOptions* call_opts = nullptr);
void RegisterFunctionAsync(const eager::RegisterFunctionRequest* request,
eager::RegisterFunctionResponse* response,
StatusCallback done,
CallOptions* call_opts = nullptr);
// The following two methods are actually from the WorkerService API, not
// EagerService, but are necessary for using remote Eager, and we include them

View File

@ -381,14 +381,15 @@ std::shared_ptr<XrtRecvTensorFuture> XrtTfContext::RecvTensor(
}
Status XrtTfContext::RegisterFunction(const FunctionDef& def) {
eager::RegisterFunctionRequest request;
eager::EnqueueRequest request;
request.set_context_id(context_id_);
*request.mutable_function_def() = def;
auto* register_function = request.add_queue()->mutable_register_function();
*register_function->mutable_function_def() = def;
eager::RegisterFunctionResponse response;
eager::EnqueueResponse response;
Status status;
absl::Notification done;
eager_client_->RegisterFunctionAsync(&request, &response, [&](Status s) {
eager_client_->EnqueueAsync(&request, &response, [&](Status s) {
status = s;
done.Notify();
});

View File

@ -21,6 +21,7 @@ limitations under the License.
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
@ -398,39 +399,31 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
// Only client context can register function on remote worker context.
if (remote_device_manager_ == nullptr) return Status::OK();
#if !defined(IS_MOBILE_PLATFORM)
BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(GetContextId());
eager::RegisterFunctionRequest request;
request.set_context_id(GetContextId());
*request.mutable_function_def() = fdef;
StripDefaultAttributes(*OpRegistry::Global(),
request.mutable_function_def()->mutable_node_def());
std::vector<eager::RegisterFunctionResponse> responses(
remote_contexts_.size());
std::vector<Status> statuses(remote_contexts_.size());
eager::RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() = fdef;
StripDefaultAttributes(
*OpRegistry::Global(),
register_function->mutable_function_def()->mutable_node_def());
int i = 0;
for (const auto& target : remote_contexts_) {
eager::EagerClient* eager_client;
statuses[i] = remote_eager_workers_->GetClient(target, &eager_client);
if (!statuses[i].ok()) {
blocking_counter.DecrementCount();
continue;
}
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
eager_client->RegisterFunctionAsync(
&request, &responses[i],
[i, &statuses, &blocking_counter](const Status& status) {
statuses[i] = status;
blocking_counter.DecrementCount();
eager::EnqueueResponse* response = new eager::EnqueueResponse();
eager_client->StreamingEnqueueAsync(
request.get(), response, [request, response](const Status& status) {
if (!status.ok()) {
LOG(ERROR) << "Failed to register function remotely due to "
<< status.error_message()
<< "\nThis shouldn't happen, please file a bug to "
"tensorflow team.";
}
delete response;
});
i++;
}
blocking_counter.Wait();
for (int i = 0; i < remote_contexts_.size(); i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
@ -441,40 +434,34 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<string>& remote_workers) {
#if !defined(IS_MOBILE_PLATFORM)
// Register multiple functions on selected remote workers.
int num_requests = function_defs.size() * remote_workers.size();
BlockingCounter counter(num_requests);
std::vector<Status> statuses(num_requests);
uint64 context_id = GetContextId();
for (int i = 0; i < remote_workers.size(); i++) {
eager::EagerClient* eager_client;
Status s =
remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
if (!s.ok()) {
for (int j = 0; j < function_defs.size(); j++) {
statuses[i * function_defs.size() + j] = s;
counter.DecrementCount();
}
continue;
}
for (int j = 0; j < function_defs.size(); j++) {
eager::RegisterFunctionRequest request;
request.set_context_id(context_id);
*request.mutable_function_def() = *function_defs[j];
auto* response = new eager::RegisterFunctionResponse();
int request_idx = i * function_defs.size() + j;
eager_client->RegisterFunctionAsync(
&request, response,
[request_idx, &statuses, response, &counter](const Status& s) {
statuses[request_idx] = s;
auto* request = new eager::EnqueueRequest;
request->set_context_id(context_id);
eager::RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() = *function_defs[j];
auto* response = new eager::EnqueueResponse;
eager_client->StreamingEnqueueAsync(
request, response, [request, response](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Failed to register function remotely due to "
<< s.error_message()
<< "\nThis shouldn't happen, please file a bug to "
"tensorflow team.";
}
delete request;
delete response;
counter.DecrementCount();
});
}
}
counter.Wait();
for (int i = 0; i < num_requests; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
}

View File

@ -58,21 +58,30 @@ Status EagerClusterFunctionLibraryRuntime::Instantiate(
const FunctionLibraryDefinition& func_lib_def =
options.lib_def ? *options.lib_def : lib_def;
RegisterFunctionRequest request;
EnqueueRequest* request = new EnqueueRequest;
EnqueueResponse* response = new EnqueueResponse;
const uint64 context_id = ctx_->GetContextId();
request.set_context_id(context_id);
request->set_context_id(context_id);
RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
// TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to
// support nested functions.
*request.mutable_function_def() = *func_lib_def.Find(function_name);
request.set_is_component_function(true);
*register_function->mutable_function_def() =
*func_lib_def.Find(function_name);
register_function->set_is_component_function(true);
Status status;
Notification done;
RegisterFunctionResponse response;
eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) {
status = s;
done.Notify();
});
// TODO(yujingzhang): make this call async.
eager_client->StreamingEnqueueAsync(
request, response, [request, response, &status, &done](const Status& s) {
status = s;
delete request;
delete response;
done.Notify();
});
done.WaitForNotification();
TF_RETURN_IF_ERROR(status);

View File

@ -39,7 +39,6 @@ class EagerClient {
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
#undef CLIENT_METHOD

View File

@ -386,8 +386,10 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
context, std::move(handle_to_decref));
s = context->Context()->Executor().AddOrExecute(std::move(node));
} else {
} else if (item.has_send_tensor()) {
s = SendTensor(item.send_tensor(), context->Context());
} else {
s = RegisterFunction(item.register_function(), context->Context());
}
if (!s.ok()) {
@ -449,16 +451,12 @@ Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
}
Status EagerServiceImpl::RegisterFunction(
const RegisterFunctionRequest* request,
RegisterFunctionResponse* response) {
ServerContext* context = nullptr;
TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
core::ScopedUnref context_unref(context);
const RegisterFunctionOp& register_function, EagerContext* eager_context) {
// If the function is a component of a multi-device function, we only need to
// register it locally.
return context->Context()->AddFunctionDef(request->function_def(),
request->is_component_function());
return eager_context->AddFunctionDef(
register_function.function_def(),
register_function.is_component_function());
}
Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,

View File

@ -102,9 +102,6 @@ class EagerServiceImpl {
Status CloseContext(const CloseContextRequest* request,
CloseContextResponse* response);
Status RegisterFunction(const RegisterFunctionRequest* request,
RegisterFunctionResponse* response);
protected:
// This is the server-side execution context. All state regarding execution of
// a client's ops is held in this server-side context (all generated tensors,
@ -208,6 +205,8 @@ class EagerServiceImpl {
QueueResponse* queue_response);
Status SendTensor(const SendTensorOp& send_tensor,
EagerContext* eager_context);
Status RegisterFunction(const RegisterFunctionOp& register_function,
EagerContext* eager_context);
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;

View File

@ -84,13 +84,12 @@ class FakeEagerClient : public EagerClient {
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
#undef CLIENT_METHOD
void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
done(errors::Unimplemented(""));
done(impl_->Enqueue(request, response));
}
private:
@ -312,13 +311,14 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
RegisterFunctionRequest register_function_request;
register_function_request.set_context_id(context_id);
*register_function_request.mutable_function_def() = MatMulFunction();
RegisterFunctionResponse register_function_response;
EnqueueRequest enqueue_request;
enqueue_request.set_context_id(context_id);
RegisterFunctionOp* register_function =
enqueue_request.add_queue()->mutable_register_function();
*register_function->mutable_function_def() = MatMulFunction();
EnqueueResponse enqueue_response;
TF_ASSERT_OK(eager_service_impl.RegisterFunction(
&register_function_request, &register_function_response));
TF_ASSERT_OK(eager_service_impl.Enqueue(&enqueue_request, &enqueue_response));
EnqueueRequest remote_enqueue_request;
remote_enqueue_request.set_context_id(context_id);

View File

@ -82,7 +82,6 @@ class GrpcEagerClient : public EagerClient {
CLIENT_METHOD(Enqueue);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(RegisterFunction);
#undef CLIENT_METHOD

View File

@ -54,7 +54,6 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
ENQUEUE_REQUEST(WaitQueueDone);
ENQUEUE_REQUEST(KeepAlive);
ENQUEUE_REQUEST(CloseContext);
ENQUEUE_REQUEST(RegisterFunction);
#undef ENQUEUE_REQUEST
// Request a StreamingEnqueue call.

View File

@ -70,7 +70,6 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
HANDLER(WaitQueueDone);
HANDLER(KeepAlive);
HANDLER(CloseContext);
HANDLER(RegisterFunction);
#undef HANDLER
// Called when a new request has been received as part of a StreamingEnqueue

View File

@ -39,6 +39,8 @@ message QueueItem {
RemoteTensorHandle handle_to_decref = 1;
Operation operation = 2;
SendTensorOp send_tensor = 3;
// Takes a FunctionDef and makes it enqueueable on the remote worker.
RegisterFunctionOp register_function = 4;
}
}
@ -139,18 +141,14 @@ message CloseContextRequest {
message CloseContextResponse {}
message RegisterFunctionRequest {
fixed64 context_id = 1;
FunctionDef function_def = 2;
message RegisterFunctionOp {
FunctionDef function_def = 1;
// If true, it means that function_def is produced by graph partition during
// multi-device function instantiation.
bool is_component_function = 3;
bool is_component_function = 2;
}
message RegisterFunctionResponse {}
message SendTensorOp {
// All remote tensors are identified by <Op ID, Output num>. To mimic this
// situation when directly sending tensors, we include an "artificial" op ID
@ -218,8 +216,4 @@ service EagerService {
// Closes the context. No calls to other methods using the existing context ID
// are valid after this.
rpc CloseContext(CloseContextRequest) returns (CloseContextResponse);
// Takes a FunctionDef and makes it enqueueable on the remote worker.
rpc RegisterFunction(RegisterFunctionRequest)
returns (RegisterFunctionResponse);
}