Instead of creating EagerClientCache through a separate factory method, this change adds GetEagerClientCache to WorkerCacheInterface, allowing it to create the EagerClientCache. With this change, we no longer need to keep channel_cache in grpc_server_lib, since every instance that needs channel_cache will be created by WorkerCacheInterface.
PiperOrigin-RevId: 254863249
This commit is contained in:
parent
1d29b5c344
commit
37d3b56522
@ -236,10 +236,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
|
||||
grpc_server->channel_cache();
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
|
||||
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// Initialize remote eager workers.
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||
|
@ -269,6 +269,7 @@ cc_library(
|
||||
":worker_interface",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -160,6 +160,7 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_partial",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -322,13 +322,13 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
GrpcChannelSpec channel_spec;
|
||||
TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
|
||||
|
||||
channel_cache_.reset(
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache(
|
||||
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
|
||||
|
||||
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
|
||||
"/task:", options.task_index);
|
||||
|
||||
const string host_port = channel_cache_->TranslateTask(name_prefix);
|
||||
const string host_port = channel_cache->TranslateTask(name_prefix);
|
||||
int requested_port;
|
||||
|
||||
auto colon_index = host_port.find_last_of(':');
|
||||
@ -343,7 +343,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
" differs from expected port ", bound_port_);
|
||||
}
|
||||
|
||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_,
|
||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache,
|
||||
worker_impl(), name_prefix);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -95,8 +95,6 @@ class GrpcServer : public ServerInterface {
|
||||
WorkerEnv* worker_env() { return &worker_env_; }
|
||||
MasterEnv* master_env() { return &master_env_; }
|
||||
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
|
||||
|
||||
protected:
|
||||
virtual Status GetPort(int* port) const;
|
||||
Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
|
||||
@ -124,9 +122,6 @@ class GrpcServer : public ServerInterface {
|
||||
const ServerDef& server_def() const { return server_def_; }
|
||||
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
|
||||
|
||||
void set_channel_cache(GrpcChannelCache* channel_cache) {
|
||||
channel_cache_.reset(channel_cache);
|
||||
}
|
||||
|
||||
private:
|
||||
// The overall server configuration.
|
||||
@ -156,7 +151,6 @@ class GrpcServer : public ServerInterface {
|
||||
std::unique_ptr<Master> master_impl_;
|
||||
AsyncServiceInterface* master_service_ = nullptr;
|
||||
std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache_;
|
||||
|
||||
// Implementation of a TensorFlow worker, and RPC polling thread.
|
||||
WorkerEnv worker_env_;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
|
||||
@ -90,6 +91,12 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
||||
}
|
||||
}
|
||||
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void SetLogging(bool v) override { logger_.SetLogging(v); }
|
||||
|
||||
void ClearLogs() override { logger_.ClearLogs(); }
|
||||
|
@ -55,6 +55,10 @@ class DummyWorkerCache : public WorkerCacheInterface {
|
||||
WorkerInterface* GetOrCreateWorker(const string& target) override {
|
||||
return nullptr;
|
||||
}
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
return errors::Unimplemented("Unimplemented.");
|
||||
}
|
||||
bool GetDeviceLocalityNonBlocking(const string& device,
|
||||
DeviceLocality* locality) override {
|
||||
return false;
|
||||
|
@ -162,6 +162,11 @@ class TestWorkerCache : public WorkerCacheInterface {
|
||||
|
||||
void ReleaseWorker(const string& target, WorkerInterface* worker) override {}
|
||||
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
return errors::Unimplemented("Unimplemented.");
|
||||
}
|
||||
|
||||
bool GetDeviceLocalityNonBlocking(const string& device,
|
||||
DeviceLocality* locality) override {
|
||||
auto it = localities_.find(device);
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h" // for DeviceLocality
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -69,6 +70,10 @@ class WorkerCacheInterface {
|
||||
DeviceLocality* locality,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
// Build and return an EagerClientCache object wrapping that channel.
|
||||
virtual Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;
|
||||
|
||||
// Start/stop logging activity.
|
||||
virtual void SetLogging(bool active) {}
|
||||
|
||||
|
@ -54,6 +54,11 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
|
||||
return wrapped_->ReleaseWorker(target, worker);
|
||||
}
|
||||
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
return wrapped_->GetEagerClientCache(eager_client_cache);
|
||||
}
|
||||
|
||||
// Set *locality with the DeviceLocality of the specified remote device
|
||||
// within its local environment. Returns true if *locality
|
||||
// was set, using only locally cached data. Returns false
|
||||
|
@ -61,6 +61,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
|
||||
return state.worker;
|
||||
}
|
||||
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
return wrapped_->GetEagerClientCache(eager_client_cache);
|
||||
}
|
||||
|
||||
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
|
||||
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user