Introduce EagerClusterFunctionLibraryRuntime.

- Allow inputs on remote devices.
- Run remote functions through eager service instead of worker service.

This CL has no behavior change, since EagerClusterFunctionLibraryRuntime is not yet in use.

PiperOrigin-RevId: 268770768
Yujing Zhang 2019-09-12 14:47:10 -07:00 committed by TensorFlower Gardener
parent 6015bf1d1d
commit 731984bfd0
9 changed files with 466 additions and 7 deletions
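
For orientation, the sketch below is condensed from the ClusterFLRTest added in this change (context creation, the fake eager client, and error handling are omitted; the target device string is illustrative). It shows how the new runtime is meant to be driven: Instantiate registers the component function on the target worker through the eager service, and Run enqueues it there with RemoteTensorHandle inputs.

// Hedged sketch, condensed from ClusterFLRTest below. Assumes an EagerContext*
// ctx whose eager client for `target` is already connected.
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"

namespace tensorflow {

Status RunComponentFunctionSketch(EagerContext* ctx, const FunctionDef& fdef) {
  const string target = "/job:localhost/replica:0/task:0/device:CPU:0";
  eager::EagerClusterFunctionLibraryRuntime cluster_flr(
      ctx, /*remote_device_mgr=*/nullptr);

  // Instantiate registers fdef on the target worker via RegisterFunctionAsync.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = target;
  options.is_multi_device_function = true;
  options.input_devices.push_back(target);
  FunctionLibraryDefinition lib_def(OpRegistry::Global(), {});
  TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(fdef));
  FunctionLibraryRuntime::LocalHandle handle;
  TF_RETURN_IF_ERROR(cluster_flr.Instantiate(
      fdef.signature().name(), lib_def, AttrSlice(&fdef.attr()), options,
      &handle));

  // Run enqueues the function; the input refers to a tensor that already
  // lives in the remote context (op 1, output 0), and op_id 2 names the
  // function's outputs there.
  eager::RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(target);
  input.set_device(target);
  FunctionLibraryRuntime::Options opts;
  Notification done;
  Status status;
  cluster_flr.Run(opts, handle, /*op_id=*/2, {&input},
                  [&status, &done](const Status& s) {
                    status = s;
                    done.Notify();
                  });
  done.WaitForNotification();
  return status;
}

}  // namespace tensorflow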

View File

@@ -413,7 +413,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
return Status::OK();
}
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
const bool add_to_local_only) {
bool is_first_ref = false;
{
mutex_lock l(cache_mu_);
@@ -432,7 +433,9 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
}
if (is_first_ref) {
TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
return MaybeRegisterFunctionRemotely(fdef);
if (!add_to_local_only) {
return MaybeRegisterFunctionRemotely(fdef);
}
}
return Status::OK();
}
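
In short, the new add_to_local_only flag lets the eager service skip the remote broadcast for functions that were already partitioned by the master. A hedged sketch of the two call patterns (the helper and its name are hypothetical, inside namespace tensorflow):

// Hedged sketch: hypothetical helper showing the two registration paths.
Status RegisterFunctionOnWorker(EagerContext* ctx, const FunctionDef& fdef,
                                bool is_component_function) {
  // Component functions come from the master's graph partitioning and only
  // need to live in this worker's local library; everything else is also
  // broadcast to remote workers via MaybeRegisterFunctionRemotely.
  return ctx->AddFunctionDef(fdef,
                             /*add_to_local_only=*/is_component_function);
}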

View File

@@ -185,7 +185,8 @@ class EagerContext : public core::RefCounted {
EagerExecutor& Executor();
Status AddFunctionDef(const FunctionDef& fdef);
Status AddFunctionDef(const FunctionDef& fdef,
const bool add_to_local_only = false);
Status RemoveFunction(const string& func);

View File

@@ -21,6 +21,29 @@ cc_library(
],
)
cc_library(
name = "cluster_function_library_runtime",
srcs = [
"cluster_function_library_runtime.cc",
],
hdrs = [
"cluster_function_library_runtime.h",
],
deps = [
":eager_client",
":remote_execute_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "destroy_tensor_handle_node",
hdrs = ["destroy_tensor_handle_node.h"],
@@ -64,6 +87,7 @@ cc_library(
"eager_service_impl.h",
],
deps = [
":cluster_function_library_runtime",
":remote_mgr",
":remote_tensor_handle",
"//tensorflow:grpc++",
@@ -95,6 +119,7 @@ tf_cc_test(
name = "eager_service_impl_test",
srcs = ["eager_service_impl_test.cc"],
deps = [
":cluster_function_library_runtime",
":eager_service_impl",
":remote_mgr",
"//tensorflow/c:c_api",
@@ -111,6 +136,7 @@ tf_cc_test(
"//tensorflow/core/distributed_runtime:test_utils",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"@com_google_absl//absl/types:span",
],
)

View File

@@ -0,0 +1,160 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include <map>
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
Status EagerClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle) {
const tensorflow::AttrTypeMap* attr_types;
bool is_function = false;
TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(function_name.c_str(),
&attr_types, &is_function));
if (!is_function) {
return errors::Internal(function_name, " is not a function.");
}
auto op = absl::make_unique<EagerOperation>(ctx_, function_name.c_str(),
is_function, attr_types);
TF_RETURN_IF_ERROR(op->SetDeviceName(options.target.c_str()));
VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
<< " (this: " << this << ")";
eager::EagerClient* eager_client = nullptr;
Device* device;
TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(options.target.c_str(), &device));
TF_RETURN_IF_ERROR(ctx_->GetClient(device, &eager_client));
if (eager_client == nullptr) {
return errors::InvalidArgument("Could not find eager client for target: ",
options.target);
}
const FunctionLibraryDefinition& func_lib_def =
options.lib_def ? *options.lib_def : lib_def;
RegisterFunctionRequest request;
const uint64 context_id = ctx_->GetContextId();
request.set_context_id(context_id);
// TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to
// support nested functions.
*request.mutable_function_def() = *func_lib_def.Find(function_name);
request.set_is_component_function(true);
Status status;
Notification done;
RegisterFunctionResponse response;
eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) {
status = s;
done.Notify();
});
done.WaitForNotification();
TF_RETURN_IF_ERROR(status);
mutex_lock l(mu_);
*handle = function_data_.size();
function_data_.emplace_back(options.target, context_id, eager_client,
std::move(op));
return Status::OK();
}
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
done(errors::Unimplemented("Not implemented"));
}
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
absl::Span<eager::RemoteTensorHandle* const> args,
FunctionLibraryRuntime::DoneCallback done) {
FunctionData* function_data = nullptr;
{
mutex_lock l(mu_);
DCHECK_LE(handle, function_data_.size());
function_data = &function_data_[handle];
}
EagerClient* eager_client = function_data->eager_client;
if (eager_client == nullptr) {
done(errors::Internal("Could not find eager client"));
return;
}
Device* device;
Status s = ctx_->FindDeviceFromName(function_data->target.c_str(), &device);
if (!s.ok()) {
done(errors::Internal("Failed to get device"));
return;
}
EagerOperation* op = function_data->op.get();
eager::EnqueueRequest* request = new eager::EnqueueRequest;
request->set_context_id(function_data->context_id);
eager::Operation* remote_op = request->add_queue()->mutable_operation();
for (size_t i = 0; i < args.size(); ++i) {
remote_op->add_inputs()->Swap(args[i]);
}
// TODO(yujingzhang): add step_id to eager::Operation to make sure that all
// component functions use the same step id.
// The remote component function should use the same op_id as its parent
// multi-device function's in order to get the globally unique op_id generated
// by the master context.
remote_op->set_id(op_id);
remote_op->set_name(op->Name());
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(function_data->target);
for (auto handle : op->Inputs()) {
handle->Ref();
}
// TODO(yujingzhang): Use RemoteExecuteNode once we enable async execution.
EnqueueResponse* response = new EnqueueResponse;
eager_client->EnqueueAsync(request, response,
[op, request, response, done](const Status& s) {
for (auto handle : op->Inputs()) {
handle->Unref();
}
done(s);
delete request;
delete response;
});
}
void EagerClusterFunctionLibraryRuntime::CleanUp(
uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) {
done(Status::OK());
}
} // namespace eager
} // namespace tensorflow
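
One consequence worth calling out: because Run() stamps the caller-supplied op_id onto the enqueued Operation, output k of the component function becomes addressable in the remote context as the pair (op_id, k). A hedged sketch (hypothetical helper, inside namespace tensorflow) of how a later remote operation could refer to such an output; the ClusterFLRTest below relies on the same naming when it fetches (2, 0) after running with op_id 2.

// Hedged sketch: hypothetical helper for naming a component function's output
// in the remote context so it can feed a subsequent remote operation.
eager::RemoteTensorHandle ComponentFunctionOutput(int64 op_id, int32 output_num,
                                                  const string& device) {
  eager::RemoteTensorHandle handle;
  handle.set_op_id(op_id);  // the op_id passed to Run()
  handle.set_output_num(output_num);
  handle.set_op_device(device);
  handle.set_device(device);
  return handle;
}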

View File

@@ -0,0 +1,87 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
namespace tensorflow {
struct WorkerSession;
namespace eager {
// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
// functions across processes by making RPCs through eager service.
class EagerClusterFunctionLibraryRuntime
: public DistributedFunctionLibraryRuntime {
public:
EagerClusterFunctionLibraryRuntime(EagerContext* ctx,
DeviceMgr* remote_device_mgr)
: ctx_(ctx), remote_device_mgr_(remote_device_mgr) {}
~EagerClusterFunctionLibraryRuntime() override{};
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle) override;
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
absl::Span<eager::RemoteTensorHandle* const> args,
FunctionLibraryRuntime::DoneCallback done) override;
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) override;
DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
private:
EagerContext* ctx_;
DeviceMgr* remote_device_mgr_; // not owned.
struct FunctionData {
const string target;
const uint64 context_id;
EagerClient* eager_client = nullptr;
std::unique_ptr<EagerOperation> op;
FunctionData(const string& target, const uint64 context_id,
EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
: target(target),
context_id(context_id),
eager_client(eager_client),
op(std::move(op)) {}
};
mutable mutex mu_;
std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
};
} // namespace eager
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_

View File

@@ -367,7 +367,10 @@ Status EagerServiceImpl::RegisterFunction(
TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
core::ScopedUnref context_unref(context);
return context->Context()->AddFunctionDef(request->function_def());
// If the function is a component of a multi-device function, we only need to
// register it locally.
return context->Context()->AddFunctionDef(request->function_def(),
request->is_component_function());
}
Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,

View File

@@ -17,8 +17,10 @@ limitations under the License.
#include <string.h>
#include "absl/types/span.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
@@ -42,6 +44,13 @@ namespace {
class TestEagerServiceImpl : public EagerServiceImpl {
public:
explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
ServerContext* context = nullptr;
TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
core::ScopedUnref context_unref(context);
*ctx = context->Context();
return Status::OK();
}
Status GetTensorHandle(const uint64 context_id,
const RemoteTensorHandleInternal& remote_handle,
tensorflow::TensorHandle** handle) {
@@ -54,10 +63,48 @@ class TestEagerServiceImpl : public EagerServiceImpl {
}
};
class DummyEagerClientCache : public EagerClientCache {
Status GetClient(const string& target, EagerClient** client) override {
return errors::Unimplemented("");
class FakeEagerClient : public EagerClient {
public:
FakeEagerClient() {}
~FakeEagerClient() override {}
void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }
#define CLIENT_METHOD(method) \
void method##Async(const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(impl_->method(request, response)); \
}
CLIENT_METHOD(CreateContext);
CLIENT_METHOD(Enqueue);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
#undef CLIENT_METHOD
void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
done(errors::Unimplemented(""));
}
private:
TestEagerServiceImpl* impl_;
};
class DummyEagerClientCache : public EagerClientCache {
public:
DummyEagerClientCache() : client_(new FakeEagerClient) {}
Status GetClient(const string& target, EagerClient** client) override {
*client = client_.get();
return Status::OK();
}
private:
std::unique_ptr<EagerClient> client_;
};
class FakeCache : public TestWorkerCache {
@@ -66,6 +113,10 @@ class FakeCache : public TestWorkerCache {
eager_client_cache->reset(new DummyEagerClientCache);
return Status::OK();
}
void ListWorkers(std::vector<string>* workers) const override {
workers->push_back("/job:localhost/replica:0/task:0");
}
};
class EagerServiceImplTest : public ::testing::Test {
@@ -311,6 +362,110 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
&close_context_response));
}
// Test executes a function through EagerClusterFunctionLibraryRuntime.
TEST_F(EagerServiceImplTest, ClusterFLRTest) {
TestEagerServiceImpl eager_service_impl(&worker_env_);
uint64 context_id = random::New64();
CreateContextRequest request;
request.mutable_server_def()->set_job_name("localhost");
request.mutable_server_def()->set_task_index(0);
request.set_context_id(context_id);
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
const string target_device = "/job:localhost/replica:0/task:0/device:CPU:0";
// Make the fake EagerClient use the local eager_service_impl.
EagerContext* ctx = nullptr;
TF_ASSERT_OK(eager_service_impl.GetEagerContext(context_id, &ctx));
Device* device;
TF_ASSERT_OK(ctx->FindDeviceFromName(target_device.c_str(), &device));
EagerClient* client;
TF_ASSERT_OK(ctx->GetClient(device, &client));
FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client);
fake_client->SetServiceImpl(&eager_service_impl);
auto eager_cluster_flr =
absl::make_unique<EagerClusterFunctionLibraryRuntime>(ctx, nullptr);
tensorflow::FunctionDef fdef = MatMulFunction();
// Create the remote input for MatMulFunction.
EnqueueRequest remote_enqueue_request;
remote_enqueue_request.set_context_id(context_id);
EnqueueResponse remote_enqueue_response;
std::unordered_map<string, AttrValue> const_attrs;
AttrValue val;
val.set_type(tensorflow::DataType::DT_FLOAT);
const_attrs.insert({"dtype", val});
val.Clear();
SetTensorProto(val.mutable_tensor());
const_attrs.insert({"value", val});
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, target_device,
&remote_enqueue_request);
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
// Instantiate MatMulFunction.
FunctionLibraryRuntime::InstantiateOptions options;
options.target = target_device;
options.is_multi_device_function = true;
options.input_devices.push_back(target_device);
FunctionLibraryRuntime::Handle handle;
FunctionLibraryDefinition func_lib_def{OpRegistry::Global(), {}};
TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));
TF_ASSERT_OK(eager_cluster_flr->Instantiate(
fdef.signature().name(), func_lib_def, AttrSlice(&fdef.attr()), options,
&handle));
// Run MatMulFunction.
FunctionLibraryRuntime::Options opts;
const int64 step_id = opts.step_id;
Notification done;
Status status;
RemoteTensorHandle input;
input.set_op_id(1);
input.set_output_num(0);
input.set_op_device(target_device);
input.set_device(target_device);
eager_cluster_flr->Run(opts, handle, 2, {&input},
[&status, &done](const Status& s) {
status = s;
done.Notify();
});
done.WaitForNotification();
TF_ASSERT_OK(status);
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
auto actual = t->flat<float>();
EXPECT_EQ(4, actual.size());
EXPECT_EQ(7, actual(0));
EXPECT_EQ(10, actual(1));
EXPECT_EQ(15, actual(2));
EXPECT_EQ(22, actual(3));
Status cleanup_status;
bool callback_is_called = false;
eager_cluster_flr->CleanUp(
step_id, handle, [&cleanup_status, &callback_is_called](const Status& s) {
callback_is_called = true;
cleanup_status.Update(s);
});
EXPECT_TRUE(callback_is_called);
TF_ASSERT_OK(cleanup_status);
CloseContextRequest close_context_request;
close_context_request.set_context_id(context_id);
CloseContextResponse close_context_response;
TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
&close_context_response));
}
// Test creates a context and attempts to send a tensor (using the RPC), and
// then uses the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {

View File

@@ -18,6 +18,11 @@ limitations under the License.
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -34,6 +39,9 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif // IS_MOBILE_PLATFORM
namespace tensorflow {
@@ -816,6 +824,18 @@ class DistributedFunctionLibraryRuntime {
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) = 0;
#if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Support outputting tensors on remote devices.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
const int64 op_id,
absl::Span<eager::RemoteTensorHandle* const> args,
FunctionLibraryRuntime::DoneCallback done) {
done(errors::Unimplemented("Unimplemented."));
}
#endif // IS_MOBILE_PLATFORM
virtual void CleanUp(uint64 step_id,
FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) = 0;
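
Because the new overload is a virtual method with a default Unimplemented body rather than a pure virtual, existing DistributedFunctionLibraryRuntime implementations compile unchanged; only EagerClusterFunctionLibraryRuntime overrides it. A hedged sketch of a minimal, hypothetical subclass that relies on the default:

// Hedged sketch: a hypothetical minimal implementation. It overrides only the
// pure-virtual methods; a call to the RemoteTensorHandle-based Run() falls
// through to the default body above and reports Unimplemented.
class NoopDistributedFunctionLibraryRuntime
    : public DistributedFunctionLibraryRuntime {
 public:
  Status Instantiate(const string& function_name,
                     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::LocalHandle* handle) override {
    return errors::Unimplemented("Instantiate");
  }
  void Run(const FunctionLibraryRuntime::Options& opts,
           FunctionLibraryRuntime::LocalHandle handle,
           gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
           FunctionLibraryRuntime::DoneCallback done) override {
    done(errors::Unimplemented("Run"));
  }
  void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
               FunctionLibraryRuntime::DoneCallback done) override {
    done(Status::OK());
  }
  DeviceMgr* remote_device_mgr() const override { return nullptr; }
};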

View File

@@ -122,6 +122,10 @@ message RegisterFunctionRequest {
fixed64 context_id = 1;
FunctionDef function_def = 2;
// If true, it means that function_def was produced by graph partitioning during
// multi-device function instantiation.
bool is_component_function = 3;
}
message RegisterFunctionResponse {}
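
On the client side, the new field is set by EagerClusterFunctionLibraryRuntime::Instantiate when it registers the partitioned body on the target worker. A hedged sketch (hypothetical helper, inside namespace tensorflow) of the request it builds:

// Hedged sketch: hypothetical helper mirroring what Instantiate() sends.
eager::RegisterFunctionRequest MakeComponentRegisterRequest(
    uint64 context_id, const FunctionDef& fdef) {
  eager::RegisterFunctionRequest request;
  request.set_context_id(context_id);
  *request.mutable_function_def() = fdef;  // the partitioned component body
  // Tells the receiving worker to add the function to its local library only
  // (EagerContext::AddFunctionDef with add_to_local_only = true).
  request.set_is_component_function(true);
  return request;
}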