Introduce EagerClusterFunctionLibraryRuntime.

- Allow inputs on remote devices.
- Run remote functions through the eager service instead of the worker service.

This CL has no behavior change since EagerClusterFunctionLibraryRuntime is not in use.

PiperOrigin-RevId: 268770768
parent 6015bf1d1d
commit 731984bfd0
Changed files:
  tensorflow/core/common_runtime/eager
  tensorflow/core/distributed_runtime/eager
    BUILD
    cluster_function_library_runtime.cc
    cluster_function_library_runtime.h
    eager_service_impl.cc
    eager_service_impl_test.cc
  tensorflow/core/framework
  tensorflow/core/protobuf
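For orientation, here is a minimal caller-side sketch of the API this commit introduces. It closely mirrors the ClusterFLRTest added to eager_service_impl_test.cc below; the helper name RunComponentFunctionSketch, its parameters, and the TF_CHECK_OK error handling are illustrative assumptions rather than code from the commit, and `ctx` is assumed to be an EagerContext that already knows how to reach the target worker.

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"

// Hypothetical helper: instantiate a component function on a remote worker
// through the eager service, then run it with an input that already lives on
// a remote device.
void RunComponentFunctionSketch(tensorflow::EagerContext* ctx,
                                const tensorflow::FunctionDef& fdef,
                                const tensorflow::string& target_device,
                                tensorflow::eager::RemoteTensorHandle* input,
                                tensorflow::int64 op_id) {
  using tensorflow::FunctionLibraryRuntime;
  auto flr = absl::make_unique<
      tensorflow::eager::EagerClusterFunctionLibraryRuntime>(ctx, nullptr);

  // Instantiate registers fdef on the remote worker via a
  // RegisterFunctionRequest with is_component_function = true.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = target_device;
  options.is_multi_device_function = true;
  options.input_devices.push_back(target_device);
  tensorflow::FunctionLibraryDefinition lib_def{tensorflow::OpRegistry::Global(),
                                                {}};
  TF_CHECK_OK(lib_def.AddFunctionDef(fdef));
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(flr->Instantiate(fdef.signature().name(), lib_def,
                               tensorflow::AttrSlice(&fdef.attr()), options,
                               &handle));

  // Run enqueues the function on the remote eager service. `input` may
  // reference a tensor on the remote device, and `op_id` is the id assigned
  // by the parent multi-device function.
  tensorflow::Notification done;
  tensorflow::Status status;
  flr->Run(FunctionLibraryRuntime::Options(), handle, op_id, {input},
           [&](const tensorflow::Status& s) {
             status = s;
             done.Notify();
           });
  done.WaitForNotification();
  TF_CHECK_OK(status);
}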
tensorflow/core/common_runtime/eager/context.cc

@@ -413,7 +413,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
   return Status::OK();
 }
 
-Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
+Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
+                                    const bool add_to_local_only) {
   bool is_first_ref = false;
   {
     mutex_lock l(cache_mu_);
@@ -432,7 +433,9 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
   }
   if (is_first_ref) {
     TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
-    return MaybeRegisterFunctionRemotely(fdef);
+    if (!add_to_local_only) {
+      return MaybeRegisterFunctionRemotely(fdef);
+    }
   }
   return Status::OK();
 }

tensorflow/core/common_runtime/eager/context.h

@@ -185,7 +185,8 @@ class EagerContext : public core::RefCounted {
 
   EagerExecutor& Executor();
 
-  Status AddFunctionDef(const FunctionDef& fdef);
+  Status AddFunctionDef(const FunctionDef& fdef,
+                        const bool add_to_local_only = false);
 
   Status RemoveFunction(const string& func);
 
tensorflow/core/distributed_runtime/eager/BUILD

@@ -21,6 +21,29 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "cluster_function_library_runtime",
+    srcs = [
+        "cluster_function_library_runtime.cc",
+    ],
+    hdrs = [
+        "cluster_function_library_runtime.h",
+    ],
+    deps = [
+        ":eager_client",
+        ":remote_execute_node",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:eager_operation",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "destroy_tensor_handle_node",
     hdrs = ["destroy_tensor_handle_node.h"],
@@ -64,6 +87,7 @@ cc_library(
         "eager_service_impl.h",
     ],
     deps = [
+        ":cluster_function_library_runtime",
        ":remote_mgr",
        ":remote_tensor_handle",
        "//tensorflow:grpc++",
@@ -95,6 +119,7 @@ tf_cc_test(
     name = "eager_service_impl_test",
     srcs = ["eager_service_impl_test.cc"],
     deps = [
+        ":cluster_function_library_runtime",
        ":eager_service_impl",
        ":remote_mgr",
        "//tensorflow/c:c_api",
@@ -111,6 +136,7 @@ tf_cc_test(
        "//tensorflow/core/distributed_runtime:test_utils",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc (new file)

@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
+
+#include <map>
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status EagerClusterFunctionLibraryRuntime::Instantiate(
+    const string& function_name, const FunctionLibraryDefinition& lib_def,
+    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
+    FunctionLibraryRuntime::LocalHandle* handle) {
+  const tensorflow::AttrTypeMap* attr_types;
+  bool is_function = false;
+  TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(function_name.c_str(),
+                                                  &attr_types, &is_function));
+  if (!is_function) {
+    return errors::Internal(function_name, " is not a function.");
+  }
+  auto op = absl::make_unique<EagerOperation>(ctx_, function_name.c_str(),
+                                              is_function, attr_types);
+  TF_RETURN_IF_ERROR(op->SetDeviceName(options.target.c_str()));
+
+  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
+          << " (this: " << this << ")";
+  eager::EagerClient* eager_client = nullptr;
+  Device* device;
+  TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(options.target.c_str(), &device));
+  TF_RETURN_IF_ERROR(ctx_->GetClient(device, &eager_client));
+
+  if (eager_client == nullptr) {
+    return errors::InvalidArgument("Could not find eager client for target: ",
+                                   options.target);
+  }
+
+  const FunctionLibraryDefinition& func_lib_def =
+      options.lib_def ? *options.lib_def : lib_def;
+
+  RegisterFunctionRequest request;
+  const uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  // TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to
+  // support nested functions.
+  *request.mutable_function_def() = *func_lib_def.Find(function_name);
+  request.set_is_component_function(true);
+
+  Status status;
+  Notification done;
+  RegisterFunctionResponse response;
+  eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) {
+    status = s;
+    done.Notify();
+  });
+  done.WaitForNotification();
+  TF_RETURN_IF_ERROR(status);
+
+  mutex_lock l(mu_);
+  *handle = function_data_.size();
+  function_data_.emplace_back(options.target, context_id, eager_client,
+                              std::move(op));
+  return Status::OK();
+}
+
+void EagerClusterFunctionLibraryRuntime::Run(
+    const FunctionLibraryRuntime::Options& opts,
+    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
+    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
+  done(errors::Unimplemented("Not implemented"));
+}
+
+void EagerClusterFunctionLibraryRuntime::Run(
+    const FunctionLibraryRuntime::Options& opts,
+    FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
+    absl::Span<eager::RemoteTensorHandle* const> args,
+    FunctionLibraryRuntime::DoneCallback done) {
+  FunctionData* function_data = nullptr;
+  {
+    mutex_lock l(mu_);
+    DCHECK_LE(handle, function_data_.size());
+    function_data = &function_data_[handle];
+  }
+
+  EagerClient* eager_client = function_data->eager_client;
+  if (eager_client == nullptr) {
+    done(errors::Internal("Could not find eager client"));
+    return;
+  }
+
+  Device* device;
+  Status s = ctx_->FindDeviceFromName(function_data->target.c_str(), &device);
+  if (!s.ok()) {
+    done(errors::Internal("Failed to get device"));
+    return;
+  }
+
+  EagerOperation* op = function_data->op.get();
+
+  eager::EnqueueRequest* request = new eager::EnqueueRequest;
+  request->set_context_id(function_data->context_id);
+  eager::Operation* remote_op = request->add_queue()->mutable_operation();
+  for (size_t i = 0; i < args.size(); ++i) {
+    remote_op->add_inputs()->Swap(args[i]);
+  }
+  // TODO(yujingzhang): add step_id to eager::Operation to make sure that all
+  // component functions use the same step id.
+  // The remote component function should use the same op_id as its parent
+  // multi-device function's in order to get the global unique op_id generated
+  // by the master context.
+  remote_op->set_id(op_id);
+  remote_op->set_name(op->Name());
+  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
+  remote_op->set_device(function_data->target);
+
+  for (auto handle : op->Inputs()) {
+    handle->Ref();
+  }
+
+  // TODO(yujingzhang): Use RemoteExecuteNode once we enable async execution.
+  EnqueueResponse* response = new EnqueueResponse;
+  eager_client->EnqueueAsync(request, response,
+                             [op, request, response, done](const Status& s) {
+                               for (auto handle : op->Inputs()) {
+                                 handle->Unref();
+                               }
+                               done(s);
+                               delete request;
+                               delete response;
+                             });
+}
+
+void EagerClusterFunctionLibraryRuntime::CleanUp(
+    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
+    FunctionLibraryRuntime::DoneCallback done) {
+  done(Status::OK());
+}
+
+}  // namespace eager
+}  // namespace tensorflow
tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h (new file)

@@ -0,0 +1,87 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
+
+namespace tensorflow {
+
+struct WorkerSession;
+
+namespace eager {
+
+// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
+// functions across processes by making RPCs through the eager service.
+class EagerClusterFunctionLibraryRuntime
+    : public DistributedFunctionLibraryRuntime {
+ public:
+  EagerClusterFunctionLibraryRuntime(EagerContext* ctx,
+                                     DeviceMgr* remote_device_mgr)
+      : ctx_(ctx), remote_device_mgr_(remote_device_mgr) {}
+
+  ~EagerClusterFunctionLibraryRuntime() override {}
+
+  Status Instantiate(const string& function_name,
+                     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+                     const FunctionLibraryRuntime::InstantiateOptions& options,
+                     FunctionLibraryRuntime::LocalHandle* handle) override;
+
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle,
+           gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+           FunctionLibraryRuntime::DoneCallback done) override;
+
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
+           absl::Span<eager::RemoteTensorHandle* const> args,
+           FunctionLibraryRuntime::DoneCallback done) override;
+
+  void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
+               FunctionLibraryRuntime::DoneCallback done) override;
+
+  DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
+
+ private:
+  EagerContext* ctx_;
+  DeviceMgr* remote_device_mgr_;  // not owned.
+
+  struct FunctionData {
+    const string target;
+    const uint64 context_id;
+    EagerClient* eager_client = nullptr;
+    std::unique_ptr<EagerOperation> op;
+
+    FunctionData(const string& target, const uint64 context_id,
+                 EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
+        : target(target),
+          context_id(context_id),
+          eager_client(eager_client),
+          op(std::move(op)) {}
+  };
+
+  mutable mutex mu_;
+  std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
+};
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
tensorflow/core/distributed_runtime/eager/eager_service_impl.cc

@@ -367,7 +367,10 @@ Status EagerServiceImpl::RegisterFunction(
   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
   core::ScopedUnref context_unref(context);
 
-  return context->Context()->AddFunctionDef(request->function_def());
+  // If the function is a component of a multi-device function, we only need to
+  // register it locally.
+  return context->Context()->AddFunctionDef(request->function_def(),
+                                            request->is_component_function());
 }
 
 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
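To make the new control flow concrete, here is a hedged illustration (not code from this commit; the helper name MakeRegisterRequest is hypothetical) of how the is_component_function bit is meant to be set by a client such as EagerClusterFunctionLibraryRuntime::Instantiate, and then forwarded by RegisterFunction above as add_to_local_only so the worker skips MaybeRegisterFunctionRemotely.

#include "tensorflow/core/protobuf/eager_service.pb.h"

// Illustrative only: builds the request a client would send to register a
// partitioned component function on one remote worker.
tensorflow::eager::RegisterFunctionRequest MakeRegisterRequest(
    tensorflow::uint64 context_id, const tensorflow::FunctionDef& fdef) {
  tensorflow::eager::RegisterFunctionRequest request;
  request.set_context_id(context_id);
  *request.mutable_function_def() = fdef;
  // Marks fdef as a partition of a multi-device function, so the receiving
  // worker calls AddFunctionDef(fdef, /*add_to_local_only=*/true) and does not
  // try to re-register the function on other remote workers.
  request.set_is_component_function(true);
  return request;
}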
tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc

@@ -17,8 +17,10 @@ limitations under the License.
 
 #include <string.h>
 
+#include "absl/types/span.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
@@ -42,6 +44,13 @@ namespace {
 class TestEagerServiceImpl : public EagerServiceImpl {
  public:
  explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
+  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
+    ServerContext* context = nullptr;
+    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
+    core::ScopedUnref context_unref(context);
+    *ctx = context->Context();
+    return Status::OK();
+  }
   Status GetTensorHandle(const uint64 context_id,
                          const RemoteTensorHandleInternal& remote_handle,
                          tensorflow::TensorHandle** handle) {
@@ -54,10 +63,48 @@ class TestEagerServiceImpl : public EagerServiceImpl {
   }
 };
 
-class DummyEagerClientCache : public EagerClientCache {
-  Status GetClient(const string& target, EagerClient** client) override {
-    return errors::Unimplemented("");
+class FakeEagerClient : public EagerClient {
+ public:
+  FakeEagerClient() {}
+  ~FakeEagerClient() override {}
+
+  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }
+
+#define CLIENT_METHOD(method)                                         \
+  void method##Async(const method##Request* request,                  \
+                     method##Response* response, StatusCallback done) \
+      override {                                                      \
+    done(impl_->method(request, response));                           \
+  }
+
+  CLIENT_METHOD(CreateContext);
+  CLIENT_METHOD(Enqueue);
+  CLIENT_METHOD(WaitQueueDone);
+  CLIENT_METHOD(KeepAlive);
+  CLIENT_METHOD(CloseContext);
+  CLIENT_METHOD(RegisterFunction);
+#undef CLIENT_METHOD
+
+  void StreamingEnqueueAsync(const EnqueueRequest* request,
+                             EnqueueResponse* response,
+                             StatusCallback done) override {
+    done(errors::Unimplemented(""));
+  }
+
+ private:
+  TestEagerServiceImpl* impl_;
+};
+
+class DummyEagerClientCache : public EagerClientCache {
+ public:
+  DummyEagerClientCache() : client_(new FakeEagerClient) {}
+  Status GetClient(const string& target, EagerClient** client) override {
+    *client = client_.get();
+    return Status::OK();
   }
+
+ private:
+  std::unique_ptr<EagerClient> client_;
 };
 
 class FakeCache : public TestWorkerCache {
@@ -66,6 +113,10 @@ class FakeCache : public TestWorkerCache {
     eager_client_cache->reset(new DummyEagerClientCache);
     return Status::OK();
   }
+
+  void ListWorkers(std::vector<string>* workers) const override {
+    workers->push_back("/job:localhost/replica:0/task:0");
+  }
 };
 
 class EagerServiceImplTest : public ::testing::Test {
@@ -311,6 +362,110 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
                                                &close_context_response));
 }
 
+// Test executes a function through EagerClusterFunctionLibraryRuntime.
+TEST_F(EagerServiceImplTest, ClusterFLRTest) {
+  TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+  uint64 context_id = random::New64();
+
+  CreateContextRequest request;
+  request.mutable_server_def()->set_job_name("localhost");
+  request.mutable_server_def()->set_task_index(0);
+  request.set_context_id(context_id);
+  CreateContextResponse response;
+  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+  const string target_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+
+  // Make the fake EagerClient use the local eager_service_impl.
+  EagerContext* ctx = nullptr;
+  TF_ASSERT_OK(eager_service_impl.GetEagerContext(context_id, &ctx));
+  Device* device;
+  TF_ASSERT_OK(ctx->FindDeviceFromName(target_device.c_str(), &device));
+  EagerClient* client;
+  TF_ASSERT_OK(ctx->GetClient(device, &client));
+  FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client);
+  fake_client->SetServiceImpl(&eager_service_impl);
+
+  auto eager_cluster_flr =
+      absl::make_unique<EagerClusterFunctionLibraryRuntime>(ctx, nullptr);
+  tensorflow::FunctionDef fdef = MatMulFunction();
+
+  // Create the remote input for MatMulFunction.
+  EnqueueRequest remote_enqueue_request;
+  remote_enqueue_request.set_context_id(context_id);
+  EnqueueResponse remote_enqueue_response;
+  std::unordered_map<string, AttrValue> const_attrs;
+  AttrValue val;
+  val.set_type(tensorflow::DataType::DT_FLOAT);
+  const_attrs.insert({"dtype", val});
+  val.Clear();
+  SetTensorProto(val.mutable_tensor());
+  const_attrs.insert({"value", val});
+  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, target_device,
+                               &remote_enqueue_request);
+  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
+                                          &remote_enqueue_response));
+
+  // Instantiate MatMulFunction.
+  FunctionLibraryRuntime::InstantiateOptions options;
+  options.target = target_device;
+  options.is_multi_device_function = true;
+  options.input_devices.push_back(target_device);
+  FunctionLibraryRuntime::Handle handle;
+  FunctionLibraryDefinition func_lib_def{OpRegistry::Global(), {}};
+  TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));
+  TF_ASSERT_OK(eager_cluster_flr->Instantiate(
+      fdef.signature().name(), func_lib_def, AttrSlice(&fdef.attr()), options,
+      &handle));
+
+  // Run MatMulFunction.
+  FunctionLibraryRuntime::Options opts;
+  const int64 step_id = opts.step_id;
+  Notification done;
+  Status status;
+  RemoteTensorHandle input;
+  input.set_op_id(1);
+  input.set_output_num(0);
+  input.set_op_device(target_device);
+  input.set_device(target_device);
+  eager_cluster_flr->Run(opts, handle, 2, {&input},
+                         [&status, &done](const Status& s) {
+                           status = s;
+                           done.Notify();
+                         });
+  done.WaitForNotification();
+  TF_ASSERT_OK(status);
+
+  const tensorflow::Tensor* t = nullptr;
+  tensorflow::TensorHandle* tensor_handle;
+  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
+  TF_ASSERT_OK(tensor_handle->Tensor(&t));
+  auto actual = t->flat<float>();
+  EXPECT_EQ(4, actual.size());
+  EXPECT_EQ(7, actual(0));
+  EXPECT_EQ(10, actual(1));
+  EXPECT_EQ(15, actual(2));
+  EXPECT_EQ(22, actual(3));
+
+  Status cleanup_status;
+  bool callback_is_called = false;
+  eager_cluster_flr->CleanUp(
+      step_id, handle, [&cleanup_status, &callback_is_called](const Status& s) {
+        callback_is_called = true;
+        cleanup_status.Update(s);
+      });
+  EXPECT_TRUE(callback_is_called);
+  TF_ASSERT_OK(cleanup_status);
+
+  CloseContextRequest close_context_request;
+  close_context_request.set_context_id(context_id);
+  CloseContextResponse close_context_response;
+  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
+                                               &close_context_response));
+}
+
 // Test creates a context and attempts to send a tensor (using the RPC), and
 // then use the tensor.
 TEST_F(EagerServiceImplTest, SendTensorTest) {
tensorflow/core/framework/function.h

@@ -18,6 +18,11 @@ limitations under the License.
 
 #include <vector>
 
+// clang-format off
+// Required for IS_MOBILE_PLATFORM
+#include "tensorflow/core/platform/platform.h"
+// clang-format on
+
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -34,6 +39,9 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/protobuf/config.pb.h"
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
+#endif  // IS_MOBILE_PLATFORM
 
 namespace tensorflow {
 
@@ -816,6 +824,18 @@ class DistributedFunctionLibraryRuntime {
                        FunctionLibraryRuntime::LocalHandle handle,
                        gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                        FunctionLibraryRuntime::DoneCallback done) = 0;
 
+#if !defined(IS_MOBILE_PLATFORM)
+  // TODO(yujingzhang): Support outputting tensors on remote devices.
+  virtual void Run(const FunctionLibraryRuntime::Options& opts,
+                   FunctionLibraryRuntime::LocalHandle handle,
+                   const int64 op_id,
+                   absl::Span<eager::RemoteTensorHandle* const> args,
+                   FunctionLibraryRuntime::DoneCallback done) {
+    done(errors::Unimplemented("Unimplemented."));
+  }
+#endif  // IS_MOBILE_PLATFORM
+
   virtual void CleanUp(uint64 step_id,
                        FunctionLibraryRuntime::LocalHandle handle,
                        FunctionLibraryRuntime::DoneCallback done) = 0;
 
tensorflow/core/protobuf/eager_service.proto

@@ -122,6 +122,10 @@ message RegisterFunctionRequest {
   fixed64 context_id = 1;
 
   FunctionDef function_def = 2;
+
+  // If true, it means that function_def is produced by graph partition during
+  // multi-device function instantiation.
+  bool is_component_function = 3;
 }
 
 message RegisterFunctionResponse {}