Instead of creating EagerClientCache by a separate factory method, this change add GetEagerClientCache in WorkerCacheInterface to allow it create EagerClientCache. With this change, we don't need to keep channel_cache in grpc_server_lib anymore since all instance that needs channel_cache will be created by WorkerCacheInterface.

PiperOrigin-RevId: 254863249
This commit is contained in:
Xiao Yu 2019-06-24 16:41:25 -07:00 committed by TensorFlower Gardener
parent 1d29b5c344
commit 37d3b56522
11 changed files with 40 additions and 13 deletions

View File

@ -236,10 +236,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
*base_request.add_cluster_device_attributes() = da;
}
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;

View File

@ -269,6 +269,7 @@ cc_library(
":worker_interface",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/eager:eager_client",
],
)

View File

@ -160,6 +160,7 @@ cc_library(
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_cache_partial",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)

View File

@ -322,13 +322,13 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
channel_cache_.reset(
std::shared_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
"/task:", options.task_index);
const string host_port = channel_cache_->TranslateTask(name_prefix);
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
auto colon_index = host_port.find_last_of(':');
@ -343,7 +343,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
" differs from expected port ", bound_port_);
}
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_,
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache,
worker_impl(), name_prefix);
return Status::OK();
}

View File

@ -95,8 +95,6 @@ class GrpcServer : public ServerInterface {
WorkerEnv* worker_env() { return &worker_env_; }
MasterEnv* master_env() { return &master_env_; }
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
protected:
virtual Status GetPort(int* port) const;
Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
@ -124,9 +122,6 @@ class GrpcServer : public ServerInterface {
const ServerDef& server_def() const { return server_def_; }
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
void set_channel_cache(GrpcChannelCache* channel_cache) {
channel_cache_.reset(channel_cache);
}
private:
// The overall server configuration.
@ -156,7 +151,6 @@ class GrpcServer : public ServerInterface {
std::unique_ptr<Master> master_impl_;
AsyncServiceInterface* master_service_ = nullptr;
std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
std::shared_ptr<GrpcChannelCache> channel_cache_;
// Implementation of a TensorFlow worker, and RPC polling thread.
WorkerEnv worker_env_;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
@ -90,6 +91,12 @@ class GrpcWorkerCache : public WorkerCachePartial {
}
}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
return Status::OK();
}
void SetLogging(bool v) override { logger_.SetLogging(v); }
void ClearLogs() override { logger_.ClearLogs(); }

View File

@ -55,6 +55,10 @@ class DummyWorkerCache : public WorkerCacheInterface {
WorkerInterface* GetOrCreateWorker(const string& target) override {
return nullptr;
}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
return errors::Unimplemented("Unimplemented.");
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return false;

View File

@ -162,6 +162,11 @@ class TestWorkerCache : public WorkerCacheInterface {
void ReleaseWorker(const string& target, WorkerInterface* worker) override {}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
return errors::Unimplemented("Unimplemented.");
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
auto it = localities_.find(device);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h" // for DeviceLocality
#include "tensorflow/core/lib/core/status.h"
@ -69,6 +70,10 @@ class WorkerCacheInterface {
DeviceLocality* locality,
StatusCallback done) = 0;
// Build and return a EagerClientCache object wrapping that channel.
virtual Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;
// Start/stop logging activity.
virtual void SetLogging(bool active) {}

View File

@ -54,6 +54,11 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
return wrapped_->ReleaseWorker(target, worker);
}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
return wrapped_->GetEagerClientCache(eager_client_cache);
}
// Set *locality with the DeviceLocality of the specified remote device
// within its local environment. Returns true if *locality
// was set, using only locally cached data. Returns false

View File

@ -61,6 +61,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
return state.worker;
}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
return wrapped_->GetEagerClientCache(eager_client_cache);
}
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
}