Implement ClusterSpec Propagation in TF Master

ClusterSpec propagation is a capability upgrade for TensorFlow that should make
it much easier to (1) build distributed TensorFlow clusters, and (2) handle
node failures. The ClusterSpec propagation capability allows TensorFlow workers
to be booted independently of each other, and with no knowledge about others.
The client can then construct a ClusterDef (ClusterSpec), and then send it
to the TF master at session creation. The master in turn then propagates the
ClusterDef along to all of the workers.
Change: 155159972
This commit is contained in:
Brennan Saeta 2017-05-04 19:43:48 -08:00 committed by TensorFlower Gardener
parent efa08d80a5
commit f28935a7d2
63 changed files with 1396 additions and 594 deletions

View File

@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceType& jit_device_name, const DeviceType& jit_device_name,
perftools::gputools::Platform* platform, perftools::gputools::Platform* platform,
Allocator* xla_allocator) Allocator* xla_allocator)
: LocalDevice(options, attrs, xla_allocator), : LocalDevice(options, attrs),
device_ordinal_(device_ordinal), device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name), jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator), xla_allocator_(xla_allocator),

View File

@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
options, options,
Device::BuildDeviceAttributes( Device::BuildDeviceAttributes(
"", type, Bytes(256 << 20), DeviceLocality(), "", type, Bytes(256 << 20), DeviceLocality(),
strings::StrCat("device: XLA compilation device ", type.type())), strings::StrCat("device: XLA compilation device ", type.type()))),
cpu_allocator()),
allocator_(new XlaCompilationAllocator()) {} allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {} XlaCompilationDevice::~XlaCompilationDevice() {}

View File

@ -118,6 +118,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/types.proto" "tensorflow/core/framework/types.proto"
"tensorflow/core/framework/versions.proto" "tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/lib/core/error_codes.proto"
"tensorflow/core/protobuf/cluster.proto"
"tensorflow/core/protobuf/config.proto" "tensorflow/core/protobuf/config.proto"
"tensorflow/core/protobuf/debug.proto" "tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/rewriter_config.proto" "tensorflow/core/protobuf/rewriter_config.proto"

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc tensorflow/core/protobuf/debug.pb.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h tensorflow/core/protobuf/rewriter_config.pb.h

View File

@ -1,6 +1,7 @@
tensorflow/core/util/saved_tensor_slice.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/rewriter_config.proto tensorflow/core/protobuf/rewriter_config.proto

View File

@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto", "framework/versions.proto",
"lib/core/error_codes.proto", "lib/core/error_codes.proto",
"protobuf/config.proto", "protobuf/config.proto",
"protobuf/cluster.proto",
"protobuf/debug.proto", "protobuf/debug.proto",
"protobuf/queue_runner.proto", "protobuf/queue_runner.proto",
"protobuf/rewriter_config.proto", "protobuf/rewriter_config.proto",

View File

@ -23,8 +23,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
Device::Device(Env* env, const DeviceAttributes& device_attributes, Device::Device(Env* env, const DeviceAttributes& device_attributes)
Allocator* device_allocator)
: DeviceBase(env), device_attributes_(device_attributes) { : DeviceBase(env), device_attributes_(device_attributes) {
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
<< "Invalid device name: " << name(); << "Invalid device name: " << name();

View File

@ -53,8 +53,7 @@ namespace tensorflow {
class Device : public DeviceBase { class Device : public DeviceBase {
public: public:
Device(Env* env, const DeviceAttributes& device_attributes, Device(Env* env, const DeviceAttributes& device_attributes);
Allocator* device_allocator);
~Device() override; ~Device() override;
// Full name of this device (see top comment). // Full name of this device (see top comment).

View File

@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
for (Device* d : devices) { for (Device* d : devices) {
devices_.push_back(d); devices_.push_back(d);
// Register under both the full name and the local name. // Register under the (1) full name, (2) canonical name, and (3) local name.
string full_name = d->name(); string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d; device_map_[CopyToBackingStore(full_name)] = d;
DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
if (parsed_name.has_job && parsed_name.has_replica &&
parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
string canonical_name = DeviceNameUtils::FullName(
parsed_name.job, parsed_name.replica, parsed_name.task,
parsed_name.type, parsed_name.id);
device_map_[CopyToBackingStore(canonical_name)] = d;
}
string lname = DeviceNameUtils::LocalName(d->name()); string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d; device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++; device_type_counts_[d->device_type()]++;
@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
} }
DeviceMgr::~DeviceMgr() { DeviceMgr::~DeviceMgr() {
for (auto p : devices_) delete p; // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
for (Device* p : devices_) delete p;
} }
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s; Status s;
auto iter = device_map_.find(name); auto iter = device_map_.find(name);
if (iter == device_map_.end()) { if (iter == device_map_.end()) {
std::vector<StringPiece> device_names;
for (auto&& itr : device_map_) {
device_names.push_back(itr.first);
}
LOG(WARNING) << "Unknown device: " << name
<< " all devices: " << str_util::Join(device_names, ", ");
return errors::InvalidArgument(name, " unknown device."); return errors::InvalidArgument(name, " unknown device.");
} }
*device = iter->second; *device = iter->second;

View File

@ -36,6 +36,7 @@ class DeviceMgr {
public: public:
// Takes ownership of each device in 'devices'. // Takes ownership of each device in 'devices'.
// TODO(zhifengc): Other initialization information. // TODO(zhifengc): Other initialization information.
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices); explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr(); ~DeviceMgr();
@ -61,6 +62,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const; int NumDeviceType(const string& type) const;
private: private:
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec; typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_; DeviceVec devices_;

View File

@ -39,7 +39,10 @@ class DeviceSet {
// Set the device designated as the "client". This device // Set the device designated as the "client". This device
// must also be registered via AddDevice(). // must also be registered via AddDevice().
void set_client_device(Device* device) { client_device_ = device; } void set_client_device(Device* device) {
DCHECK(client_device_ == nullptr);
client_device_ = device;
}
// Returns a pointer to the device designated as the "client". // Returns a pointer to the device designated as the "client".
Device* client_device() const { return client_device_; } Device* client_device() const { return client_device_; }

View File

@ -27,8 +27,7 @@ namespace {
static Device* Dev(const char* type, const char* name) { static Device* Dev(const char* type, const char* name) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
: Device(nullptr, attr, nullptr) {}
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
}; };

View File

@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
int gpu_id, const string& physical_device_desc, int gpu_id, const string& physical_device_desc,
Allocator* gpu_allocator, Allocator* cpu_allocator, Allocator* gpu_allocator, Allocator* cpu_allocator,
bool sync_every_op, int32 max_streams) bool sync_every_op, int32 max_streams)
: LocalDevice(options, : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit, memory_limit, locality,
locality, physical_device_desc), physical_device_desc)),
gpu_allocator),
gpu_allocator_(gpu_allocator), gpu_allocator_(gpu_allocator),
cpu_allocator_(cpu_allocator), cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id), gpu_id_(gpu_id),

View File

@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
}; };
LocalDevice::LocalDevice(const SessionOptions& options, LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes, const DeviceAttributes& attributes)
Allocator* device_allocator) : Device(options.env, attributes), owned_tp_info_(nullptr) {
: Device(options.env, attributes, device_allocator),
owned_tp_info_(nullptr) {
// If we're running on the CPU, log warnings if we're not compiled using the // If we're running on the CPU, log warnings if we're not compiled using the
// best flags for performance. // best flags for performance.
port::WarnAboutUnusedCPUFeatures(); port::WarnAboutUnusedCPUFeatures();

View File

@ -33,8 +33,8 @@ struct SessionOptions;
// GPUDevice into more 'process-wide' abstractions. // GPUDevice into more 'process-wide' abstractions.
class LocalDevice : public Device { class LocalDevice : public Device {
public: public:
LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, LocalDevice(const SessionOptions& options,
Allocator* device_allocator); const DeviceAttributes& attributes);
~LocalDevice() override; ~LocalDevice() override;
private: private:

View File

@ -0,0 +1,54 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h"
namespace tensorflow {
// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
Device* underlying,
bool owns_underlying) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
underlying->parsed_name();
CHECK(underlying_parsed_name.has_type);
CHECK(underlying_parsed_name.has_id);
parsed_name.type = underlying_parsed_name.type;
parsed_name.id = underlying_parsed_name.id;
string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
parsed_name.task, parsed_name.type,
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
return new RenamedDevice(underlying, attributes, owns_underlying);
}
RenamedDevice::RenamedDevice(Device* underlying,
const DeviceAttributes& attributes,
bool owns_underlying)
: Device(underlying->env(), attributes),
underlying_(underlying),
owns_underlying_(owns_underlying) {}
RenamedDevice::~RenamedDevice() {
if (owns_underlying_) {
delete underlying_;
}
}
} // namespace tensorflow

View File

@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Wraps a device with a new name, delegating work to the wrapped device.
//
// This class is used to wrap local devices when using clusterspec propagation
// where the name of a particular device may change in the context of a given
// session.
class RenamedDevice : public Device {
public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
bool owns_underlying);
~RenamedDevice() override;
// Below are virtual methods defined on DeviceBase
bool RequiresRecordingAccessedTensors() const override {
return underlying_->RequiresRecordingAccessedTensors();
}
const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
return underlying_->tensorflow_cpu_worker_threads();
}
const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
return underlying_->tensorflow_gpu_device_info();
}
Allocator* GetAllocator(AllocatorAttributes attr) override {
return underlying_->GetAllocator(attr);
}
Allocator* GetStepAllocator(AllocatorAttributes attr,
ResourceMgr* step_resource_manager) override {
return underlying_->GetStepAllocator(attr, step_resource_manager);
}
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice* eigen_sycl_device() const override {
return underlying_->eigen_sycl_device();
}
#endif
PerOpGpuDevice* MakeGpuDevice() override {
return underlying_->MakeGpuDevice();
}
void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
DeviceContext* dc, Allocator* allocator) override {
underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override {
return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
}
// Below are virtual methods defined on Device
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
underlying_->Compute(op_kernel, context);
}
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override {
underlying_->ComputeAsync(op_kernel, context, std::move(done));
}
void ConsumeListOfAccessedTensors(
DeviceContext* context, const TensorReferenceVector& tensors) override {
underlying_->ConsumeListOfAccessedTensors(context, tensors);
}
Status Sync() override { return underlying_->Sync(); }
Status MaybeRewriteGraph(const FunctionDefLibrary& library,
std::unique_ptr<Graph>* graph) override {
return underlying_->MaybeRewriteGraph(library, graph);
}
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override {
return underlying_->FillContextMap(graph, device_context_map);
}
private:
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
bool owns_underlying);
Device* const underlying_;
const bool owns_underlying_;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_

View File

@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
class FakeDevice : public Device { class FakeDevice : public Device {
private: private:
explicit FakeDevice(const DeviceAttributes& device_attributes) explicit FakeDevice(const DeviceAttributes& device_attributes)
: Device(nullptr, device_attributes, nullptr) {} : Device(nullptr, device_attributes) {}
public: public:
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

View File

@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
const string& name, Bytes memory_limit, const string& name, Bytes memory_limit,
const DeviceLocality& locality, const DeviceLocality& locality,
Allocator* allocator) Allocator* allocator)
: LocalDevice(options, : LocalDevice(options, Device::BuildDeviceAttributes(
Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit, name, DEVICE_CPU, memory_limit, locality)),
locality),
allocator),
allocator_(allocator) {} allocator_(allocator) {}
ThreadPoolDevice::~ThreadPoolDevice() {} ThreadPoolDevice::~ThreadPoolDevice() {}

View File

@ -77,7 +77,6 @@ cc_library(
], ],
deps = [ deps = [
":graph_mgr", ":graph_mgr",
":rendezvous_mgr_interface",
":worker_cache", ":worker_cache",
"//tensorflow/core:master_proto_cc", "//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -92,9 +91,9 @@ cc_library(
deps = [ deps = [
":graph_mgr", ":graph_mgr",
":worker_session", ":worker_session",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
], ],
) )
@ -237,6 +236,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc", "//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
], ],
) )

View File

@ -35,9 +35,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env, BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
const string& worker_name) : worker_env_(worker_env) {}
: worker_env_(worker_env), worker_name_(worker_name) {}
BaseRendezvousMgr::~BaseRendezvousMgr() { BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) { for (auto& p : table_) {
@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
} }
} }
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id); return FindOrCreate(step_id);
} }
@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_); mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id); Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) { if (iter == table_.end()) {
auto rr = Create(step_id, worker_env_, worker_name_); auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first; iter = table_.insert({step_id, rr}).first;
} }
iter->second->Ref(); iter->second->Ref();
@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
} }
} }
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
const string& worker_name,
int64 step_id,
bool tolerate_dup_recv) bool tolerate_dup_recv)
: env_(env), : env_(env),
worker_name_(worker_name),
step_id_(step_id), step_id_(step_id),
local_(NewLocalRendezvous(tolerate_dup_recv)) {} local_(NewLocalRendezvous(tolerate_dup_recv)),
session_(nullptr) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() { BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty()); CHECK(active_.empty());
@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
return device_name.starts_with(worker_name); return device_name.starts_with(worker_name);
} }
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
CHECK_NE(session, nullptr) << "session must not be null!";
std::vector<DeferredCall> deferred_calls;
{
mutex_lock l(mu_);
if (session_ != nullptr) {
if (session_->worker_name == session->worker_name) {
LOG(INFO) << "Skipping rendezvous re-initialization.";
return Status::OK();
}
Status s = errors::Internal(
"Double init! Worker names would have changed from: ",
session_->worker_name, " -> ", session->worker_name);
LOG(WARNING) << s;
return s;
}
session_ = session;
std::swap(deferred_calls, deferred_calls_);
}
for (DeferredCall& call : deferred_calls) {
RecvLocalAsyncInternal(call.parsed, std::move(call.done));
}
return Status::OK();
}
WorkerSession* BaseRemoteRendezvous::session() {
mutex_lock l(mu_);
return session_;
}
bool BaseRemoteRendezvous::is_initialized() {
mutex_lock l(mu_);
return is_initialized_locked();
}
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args, const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) { const Tensor& val, const bool is_dead) {
@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!status_.ok()) return status_; if (!status_.ok()) return status_;
} DCHECK(is_initialized_locked());
if (!IsLocalDevice(worker_name_, parsed.src_device)) { if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", return errors::InvalidArgument(
parsed.FullKey(), " @ ", worker_name_); "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name);
}
} }
// Buffers "val" and "device_context" in local_. // Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead); return local_->Send(parsed, args, val, is_dead);
@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) { bool is_src) {
// Cache session pointer to avoid repeatedly taking & releasing the lock
// (e.g. calling session())
WorkerSession* sess = nullptr;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!status_.ok()) return status_; if (!status_.ok()) return status_;
if (!is_initialized_locked()) {
return errors::Internal("ValidateDevices called before initialization.");
}
sess = session_;
} }
if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) { if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", worker_name_); parsed.FullKey(), " @ ", sess->worker_name);
} }
if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) { if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ", return errors::InvalidArgument("Invalid rendezvous key (dst): ",
parsed.FullKey(), " @ ", worker_name_); parsed.FullKey(), " @ ", sess->worker_name);
} }
return Status::OK(); return Status::OK();
} }
@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args, const Rendezvous::Args& recv_args,
DoneCallback done) { DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
Status s = ValidateDevices(parsed, false /*!is_src*/); Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) { if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false); done(s, Args(), recv_args, Tensor(), false);
@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) { DoneCallback done) {
{
mutex_lock l(mu_);
if (!is_initialized_locked()) {
// RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
// remote worker) before the RunStep (or PartialRunStep) RPC from the
// master arrives. RecvLocalAsync thus buffers the arguments until after
// the RemoteRendezvous is Initialize()'d, when it completes the
// rendezvous logic. At some point after Initialize() is called, a Tensor
// is produced locally that will then be sent in response to the incoming
// RPC.
DeferredCall call(parsed, std::move(done));
deferred_calls_.push_back(call);
return;
}
}
RecvLocalAsyncInternal(parsed, std::move(done));
}
void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */); Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) { if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false); done(s, Args(), Args(), Tensor(), false);
@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
active_.erase(call); active_.erase(call);
} }
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
DoneCallback done)
: parsed(parsed), done(std::move(done)) {}
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -59,15 +59,17 @@ class BaseRecvTensorCall;
// RendezvousMgr must have keys generated by Rendezvous::CreateKey(). // RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface { class BaseRendezvousMgr : public RendezvousMgrInterface {
public: public:
explicit BaseRendezvousMgr(const WorkerEnv* worker_env, explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
const string& worker_name);
~BaseRendezvousMgr() override; ~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the // Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the // "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance. // returned Rendezvous instance.
Rendezvous* Find(int64 step_id) override; //
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
RemoteRendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs // Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs. // "done" when the tensor for "key" is produced or an error occurs.
@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
protected: protected:
virtual BaseRemoteRendezvous* Create(int64 step_id, virtual BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env, const WorkerEnv* worker_env) = 0;
const string& worker_name) = 0;
private: private:
// Maps step_id to rendezvous. // Maps step_id to rendezvous.
@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Not owned. // Not owned.
const WorkerEnv* const worker_env_; const WorkerEnv* const worker_env_;
const string worker_name_;
mutex mu_; mutex mu_;
Table table_ GUARDED_BY(mu_); Table table_ GUARDED_BY(mu_);
@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Buffering of Tensor values is delegated to a "local" Rendezvous // Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds // obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers. // functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public Rendezvous { class BaseRemoteRendezvous : public RemoteRendezvous {
public: public:
BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name, BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
int64 step_id, bool tolerate_dup_recv); bool tolerate_dup_recv);
// Upgrades the BaseRemoteRendezvous to full initialization.
Status Initialize(WorkerSession* session) override;
// Forwards to local_, where the Tensor "val" will be buffered and // Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored. // any waiting callback stored.
@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
// Removes "call" from active_ if "call" is in active_. // Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call); void DeregisterCall(BaseRecvTensorCall* call);
WorkerSession* session();
bool is_initialized();
~BaseRemoteRendezvous() override; ~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned. const WorkerEnv* const env_; // Not owned.
const string worker_name_;
const int64 step_id_; const int64 step_id_;
private: private:
@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
// Status given by StartAbort() if any. // Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_); Status status_ GUARDED_BY(mu_);
WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
// Data structures to handle calls when partially initialized.
struct DeferredCall {
const ParsedKey parsed;
DoneCallback done;
DeferredCall(const ParsedKey& parsed, DoneCallback done);
};
std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls. // Active outstanding RecvTensor calls.
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_); gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return session_ != nullptr;
}
// If "is_src" is true, checks that the rendezvous key "parsed"'s // If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the // source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process. // rendezvous key "parsed"'s destination is in this process.
@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
const Rendezvous::Args& out_args, const Tensor& in, const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done); Tensor* out, StatusCallback done);
// Must be called only if fully initialized.
void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
}; };

View File

@ -46,10 +46,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
GraphMgr::GraphMgr(const WorkerEnv* worker_env, GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
RendezvousMgrInterface* rendezvous_mgr) : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
// The default value of sync_on_finish will be flipped soon and this // The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well. // environment variable will be removed as well.
Status status = Status status =
@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
}; };
popts.get_incarnation = [this](const string& name) -> int64 { popts.get_incarnation = [this](const string& name) -> int64 {
Device* device = nullptr; Device* device = nullptr;
Status s = worker_env_->device_mgr->LookupDevice(name, &device); Status s = device_mgr_->LookupDevice(name, &device);
if (s.ok()) { if (s.ok()) {
return device->attributes().incarnation(); return device->attributes().incarnation();
} else { } else {
@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
ExecutionUnit* unit = &(item->units.back()); ExecutionUnit* unit = &(item->units.back());
// Find the device. // Find the device.
Status s = Status s = device_mgr_->LookupDevice(device_name, &unit->device);
worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
if (!s.ok()) { if (!s.ok()) {
// Remove the empty unit from the item as the item destructor wants all // Remove the empty unit from the item as the item destructor wants all
// units to have valid devices. // units to have valid devices.
@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
// Function library runtime. // Function library runtime.
unit->lib = NewFunctionLibraryRuntime( unit->lib = NewFunctionLibraryRuntime(
worker_env_->device_mgr, worker_env_->env, unit->device, device_mgr_, worker_env_->env, unit->device,
subgraph->versions().producer(), item->lib_def, subgraph->versions().producer(), item->lib_def,
graph_options.optimizer_options()); graph_options.optimizer_options());
@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
} }
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in); Status s = SendInputsToRendezvous(rendezvous, in);
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
} }
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out); Status s = RecvOutputsFromRendezvous(rendezvous, out);
rendezvous->Unref(); rendezvous->Unref();
return s; return s;
@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) { StatusCallback done) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out, RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) { [done, rendezvous](const Status s) {
rendezvous->Unref(); rendezvous->Unref();
@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
} }
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, WorkerSession* session,
const ExecutorOpts& /*opts*/,
StepStatsCollector* collector, StepStatsCollector* collector,
CostGraphDef* cost_graph, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager, CancellationManager* cancellation_manager,
@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return; return;
} }
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
// Sends values specified by the caller. // Sends values specified by the caller.
Status s = SendInputsToRendezvous(rendezvous, in); if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in);
}
if (!s.ok()) { if (!s.ok()) {
done(s); done(s);
item->Unref(); item->Unref();
@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
StatusCallback done) { StatusCallback done) {
const int num_units = item->units.size(); const int num_units = item->units.size();
CHECK_GE(num_units, 1); CHECK_GE(num_units, 1);
ScopedStepContainer* step_container = ScopedStepContainer* step_container = new ScopedStepContainer(
new ScopedStepContainer(step_id, [this](const string& name) { step_id,
worker_env_->device_mgr->ClearContainers({name}); [this](const string& name) { device_mgr_->ClearContainers({name}); });
});
// NOTE: Transfer one ref of rendezvous and item. // NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier = ExecutorBarrier* barrier =
new ExecutorBarrier(num_units, rendezvous, new ExecutorBarrier(num_units, rendezvous,

View File

@ -37,6 +37,8 @@ namespace tensorflow {
class ExecutorOpts; class ExecutorOpts;
class StepStatsCollector; class StepStatsCollector;
class RendezvousMgrInterface; class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;
// GraphMgr keeps track of a set of graphs that are registered with a // GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle // TensorFlow worker. Each registered graph is identified by a handle
@ -62,8 +64,7 @@ class RendezvousMgrInterface;
// EXPECT_EQ(out["c"], Tensor({4, 6})); // EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr { class GraphMgr {
public: public:
explicit GraphMgr(const WorkerEnv* worker_env, explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
RendezvousMgrInterface* rendezvous_mgr);
~GraphMgr(); ~GraphMgr();
// Registers a graph. Fills in "handle" // Registers a graph. Fills in "handle"
@ -78,8 +79,8 @@ class GraphMgr {
typedef std::map<string, Tensor> NamedTensors; typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback; typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id, void ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, StepStatsCollector* collector, WorkerSession* session, const ExecutorOpts& opts,
CostGraphDef* cost_graph, StepStatsCollector* collector, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager, CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done); const NamedTensors& in, StatusCallback done);
@ -131,7 +132,7 @@ class GraphMgr {
}; };
const WorkerEnv* worker_env_; // Not owned. const WorkerEnv* worker_env_; // Not owned.
RendezvousMgrInterface* rendezvous_mgr_; // Not owned. DeviceMgr* device_mgr_;
CostModelManager cost_model_manager_; CostModelManager cost_model_manager_;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache.h"
@ -48,12 +49,17 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
namespace {
const char* const kGrpcProtocol = "grpc://";
} // namespace
Master::Master(MasterEnv* env, double session_gc_seconds) Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env), : env_(env),
last_1000_steps_(1000), last_1000_steps_(1000),
@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) { CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() { SchedClosure([this, req, resp, done]() {
Status status; Status status;
WorkerCacheFactoryOptions worker_cache_factory_options;
string grpc_protocol("grpc");
worker_cache_factory_options.protocol = &grpc_protocol;
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
status = ValidateExternalGraphDefSyntax(req->graph_def()); status = ValidateExternalGraphDefSyntax(req->graph_def());
if (!status.ok()) return; if (!status.ok()) return;
// Ping all the workers and build the list of devices that the
// session will use. // The following 4 variables are set differently, depending on whether this
// session uses a client-provided clusterspec or not.
WorkerCacheInterface* worker_cache = nullptr;
// Note: worker_cache_ptr will be null except if this session is using a
// client-supplied ClusterDef (ClusterSpec propagation).
std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
std::unique_ptr<DeviceSet> device_set;
// TODO(saeta): Convert to std::make_unique when available. // TODO(saeta): Convert to std::make_unique when available.
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices( std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
new std::vector<std::unique_ptr<Device>>()); new std::vector<std::unique_ptr<Device>>());
status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
env_, env_->worker_cache, if (req->config().has_cluster_def()) {
remote_devices.get()); worker_cache_factory_options.cluster_def = &req->config().cluster_def();
if (!status.ok()) return;
// Set the server_def's job_name and task_index fields.
string normalized_string;
string grpc_protocol(kGrpcProtocol);
if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
0) {
normalized_string =
req->target().substr(grpc_protocol.length(), string::npos);
} else {
normalized_string = req->target();
}
for (auto&& job : req->config().cluster_def().job()) {
for (auto&& task : job.tasks()) {
if (task.second == normalized_string) {
if (worker_cache_factory_options.job_name != nullptr) {
status = errors::InvalidArgument(
"Found multiple matching tasks that correspond to "
"to the master. Master target: '",
req->target(), "'. ClusterDef: ",
req->config().cluster_def().ShortDebugString());
LOG(ERROR) << status;
return;
}
if (env_->local_devices[0]->parsed_name().job == job.name() &&
env_->local_devices[0]->parsed_name().task == task.first) {
// TODO(b/37868888): Remove this limitation when resolved
status = errors::InvalidArgument(
"The ClusterSpec names the job and task index to be the same "
"names that were provided when the server booted. This is "
"currently not allowed. Job: ",
job.name(), ", task index: ", task.first);
return;
}
worker_cache_factory_options.job_name = &job.name();
worker_cache_factory_options.task_index = task.first;
}
}
}
// Create the worker cache from the computed server_def.
status = env_->worker_cache_factory(worker_cache_factory_options,
&worker_cache);
if (!status.ok()) return;
worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
DeviceNameUtils::ParsedName name = d->parsed_name();
if (name.job == *worker_cache_factory_options.job_name &&
name.task == worker_cache_factory_options.task_index &&
name.type == "CPU") {
device_set->set_client_device(d.get());
}
}
} else {
worker_cache = env_->worker_cache;
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env_->local_devices) {
device_set->AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
device_set->set_client_device(d);
}
num_local_devices++;
}
}
CHECK(device_set->client_device());
SessionOptions options; SessionOptions options;
options.config = req->config(); options.config = req->config();
MasterSession* session =
env_->master_session_factory(options, env_, std::move(remote_devices)); MasterSession* session = env_->master_session_factory(
options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
std::move(device_set));
GraphDef* gdef = GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def(); const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
status = session->Create(gdef);
status = session->Create(gdef, worker_cache_factory_options);
if (!status.ok()) { if (!status.ok()) {
session->Close().IgnoreError(); session->Close().IgnoreError();
session->Unref(); session->Unref();

View File

@ -19,17 +19,41 @@ limitations under the License.
#include <functional> #include <functional>
#include <vector> #include <vector>
#include "tensorflow/core/distributed_runtime/master_session.h" #include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
class Device; class Device;
class DeviceSet;
class Env; class Env;
class MasterSession; class MasterSession;
class OpRegistryInterface; class OpRegistryInterface;
class WorkerCacheInterface; class WorkerCacheInterface;
// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
const ClusterDef* cluster_def = nullptr;
const string* job_name = nullptr;
int task_index;
const string* protocol = nullptr;
WorkerCacheFactoryOptions() {}
// Construct from a ServerDef proto.
//
// Note: server_def must outlive WorkerCacheFactoryOptions!
WorkerCacheFactoryOptions(const ServerDef& server_def) {
if (server_def.has_cluster() && !server_def.job_name().empty()) {
cluster_def = &server_def.cluster();
job_name = &server_def.job_name();
task_index = server_def.task_index();
protocol = &server_def.protocol();
}
}
};
// The master environment class, which holds a bag of pointers to // The master environment class, which holds a bag of pointers to
// per-master state. // per-master state.
// //
@ -57,8 +81,14 @@ struct MasterEnv {
// `MasterEnv*` is retained by the caller. // `MasterEnv*` is retained by the caller.
std::function<MasterSession*( std::function<MasterSession*(
SessionOptions, MasterEnv*, SessionOptions, MasterEnv*,
std::unique_ptr<std::vector<std::unique_ptr<Device>>>)> std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
std::unique_ptr<WorkerCacheInterface>,
std::unique_ptr<DeviceSet> device_set)>
master_session_factory; master_session_factory;
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
}; };
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -528,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_is_partial(is_partial_); c->req->set_is_partial(is_partial_);
c->req->set_is_last_partial_run(is_last_partial_run); c->req->set_is_last_partial_run(is_last_partial_run);
} }
c->req->set_session_handle(session_handle_);
c->req->set_graph_handle(part.graph_handle); c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id); c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts; *c->req->mutable_exec_opts() = exec_opts;
@ -871,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
// The graph handle may be empty if we failed during partition registration. // The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) { if (!part.graph_handle.empty()) {
Call* c = new Call; Call* c = new Call;
c->req.set_session_handle(session_handle_);
c->req.set_graph_handle(part.graph_handle); c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this` // NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called. // could be deleted before the callback is called.
@ -973,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
MasterSession::MasterSession( MasterSession::MasterSession(
const SessionOptions& opt, const MasterEnv* env, const SessionOptions& opt, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory) StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt), : session_opts_(opt),
env_(env), env_(env),
handle_(strings::FpToString(random::New64())), handle_(strings::FpToString(random::New64())),
remote_devs_(std::move(remote_devs)), remote_devs_(std::move(remote_devs)),
worker_cache_(std::move(worker_cache)),
devices_(std::move(device_set)),
stats_publisher_factory_(std::move(stats_publisher_factory)), stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0), graph_version_(0),
run_graphs_(5), run_graphs_(5),
partial_run_graphs_(5) { partial_run_graphs_(5) {
UpdateLastAccessTime(); UpdateLastAccessTime();
CHECK(devices_) << "device_set was null!";
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_->size(); << " #remote " << remote_devs_->size();
for (auto&& d : *remote_devs_) {
devices_.AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env->local_devices) {
devices_.AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
devices_.set_client_device(d);
}
num_local_devices++;
}
LOG(INFO) << "Start master session " << handle_ LOG(INFO) << "Start master session " << handle_
<< " with config: " << std::endl << " with config: " << std::endl
<< session_opts_.config.DebugString(); << session_opts_.config.DebugString();
@ -1012,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros()); last_access_time_usec_.store(Env::Default()->NowMicros());
} }
Status MasterSession::Create(GraphDef* graph_def) { Status MasterSession::Create(GraphDef* graph_def,
const WorkerCacheFactoryOptions& options) {
if (session_opts_.config.graph_options().place_pruned_graph()) { if (session_opts_.config.graph_options().place_pruned_graph()) {
// TODO(b/29900832): Fix this or remove the option. // TODO(b/29900832): Fix this or remove the option.
LOG(WARNING) << "Distributed session does not support the " LOG(WARNING) << "Distributed session does not support the "
@ -1020,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
} }
SimpleGraphExecutionStateOptions options; SimpleGraphExecutionStateOptions execution_options;
options.device_set = &devices_; execution_options.device_set = devices_.get();
options.session_options = &session_opts_; execution_options.session_options = &session_opts_;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
graph_def, options, &execution_state_)); graph_def, execution_options, &execution_state_));
}
if (options.cluster_def != nullptr) {
return CreateWorkerSessions(options);
} }
return Status::OK(); return Status::OK();
} }
Status MasterSession::CreateWorkerSessions(
const WorkerCacheFactoryOptions& options) {
CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
<< "dynamic cluster membership.";
std::vector<string> worker_names;
worker_cache_->ListWorkers(&worker_names);
struct WorkerGroup {
// The worker name. (Not owned.)
const string* name;
// The worker referenced by name. (Not owned.)
WorkerInterface* worker = nullptr;
// Request and responses used for a given worker.
CreateWorkerSessionRequest request;
CreateWorkerSessionResponse response;
Status status = Status::OK();
};
BlockingCounter done(worker_names.size());
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
auto cleanup = gtl::MakeCleanup([this, &workers] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
}
}
});
Status status = Status::OK();
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
*workers[i].request.mutable_server_def()->mutable_cluster() =
*options.cluster_def;
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
DeviceNameUtils::ParsedName name;
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
status = errors::Internal("Could not parse name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
if (!name.has_job || !name.has_task) {
status = errors::Internal("Incomplete worker name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
workers[i].request.mutable_server_def()->set_job_name(name.job);
workers[i].request.mutable_server_def()->set_task_index(name.task);
}
for (size_t i = 0; i < worker_names.size(); ++i) {
auto cb = [i, &workers, &done](const Status& s) {
workers[i].status = s;
done.DecrementCount();
};
workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
&workers[i].response, cb);
}
done.Wait();
for (size_t i = 0; i < workers.size(); ++i) {
status.Update(workers[i].status);
}
return status;
}
Status MasterSession::Extend(const ExtendSessionRequest* req, Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) { ExtendSessionResponse* resp) {
UpdateLastAccessTime(); UpdateLastAccessTime();
@ -1060,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK(); return Status::OK();
} }
WorkerCacheInterface* MasterSession::get_worker_cache() const {
if (worker_cache_) {
return worker_cache_.get();
}
return env_->worker_cache;
}
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) { ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts); const uint64 hash = HashBuildGraphOptions(opts);
@ -1083,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
<< "\n"; << "\n";
std::unique_ptr<SimpleClientGraph> client_graph; std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph( auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_, handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial, stats_publisher_factory_, execution_state_.get(), is_partial,
env_->worker_cache); worker_cache);
iter = m->insert({hash, entry}).first; iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph"; VLOG(1) << "Preparing to execute new graph";
} }
@ -1162,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
return errors::FailedPrecondition("Session is closed."); return errors::FailedPrecondition("Session is closed.");
} }
++num_running_; ++num_running_;
// Note: all code paths must eventually call MarkRunCompletion()
// in order to appropriate decrement the num_running_ counter.
} }
Status status; Status status;
if (!req.partial_run_handle().empty()) { if (!req.partial_run_handle().empty()) {
@ -1169,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
} else { } else {
status = DoRunWithLocalExecution(opts, req, resp); status = DoRunWithLocalExecution(opts, req, resp);
} }
{
mutex_lock l(mu_);
--num_running_;
if (num_running_ == 0) {
num_running_is_zero_.notify_all();
}
}
return status; return status;
} }
// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
mutex_lock l(mu_);
--num_running_;
if (num_running_ == 0) {
num_running_is_zero_.notify_all();
}
}
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if haven't done so. // Registers subgraphs if haven't done so.
PartitionOptions popts; PartitionOptions popts;
@ -1188,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
return strings::StrCat(prefix, "_S", next_node_id_++); return strings::StrCat(prefix, "_S", next_node_id_++);
}; };
popts.get_incarnation = [this](const string& name) -> int64 { popts.get_incarnation = [this](const string& name) -> int64 {
Device* d = devices_.FindDeviceByName(name); Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) { if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation; return PartitionOptions::kIllegalIncarnation;
} else { } else {
@ -1223,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
Status MasterSession::DoPartialRun(CallOptions* opts, Status MasterSession::DoPartialRun(CallOptions* opts,
const RunStepRequestWrapper& req, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) { MutableRunStepResponseWrapper* resp) {
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
const string& prun_handle = req.partial_run_handle(); const string& prun_handle = req.partial_run_handle();
RunState* run_state = nullptr; RunState* run_state = nullptr;
{ {
@ -1321,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
rcg->Ref(); rcg->Ref();
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
req.options(), resp->mutable_metadata()); req.options(), resp->mutable_metadata());
cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync( rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) { run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s; LOG(ERROR) << "Cleanup partition error: " << s;
} }
rcg->Unref(); rcg->Unref();
MarkRunCompletion();
}); });
mutex_lock l(mu_); mutex_lock l(mu_);
partial_runs_.erase(prun_handle); partial_runs_.erase(prun_handle);
@ -1368,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
Status MasterSession::DoRunWithLocalExecution( Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req, CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) { MutableRunStepResponseWrapper* resp) {
VLOG(2) << "DoRunWithLocalExecution " VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
<< "req: " << req.DebugString();
PerStepState pss; PerStepState pss;
pss.start_micros = Env::Default()->NowMicros(); pss.start_micros = Env::Default()->NowMicros();
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
// Prepare. // Prepare.
BuildGraphOptions bgopts; BuildGraphOptions bgopts;
@ -1438,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
} }
} }
rcg->Ref(); rcg->Ref();
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s; LOG(ERROR) << "Cleanup partition error: " << s;
} }
rcg->Unref(); rcg->Unref();
MarkRunCompletion();
}); });
return s; return s;
} }

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h" #include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/master.pb.h"
@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
MasterSession( MasterSession(
const SessionOptions& options, const MasterEnv* env, const SessionOptions& options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory); StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(), // Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close(). // Run(), or Close().
// //
// After this method returns, `def` will no longer be valid. // After this method returns, `def` will no longer be valid.
Status Create(GraphDef* def); Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
// Returns the session handle. // Returns the session handle.
const string& handle() const { return handle_; } const string& handle() const { return handle_; }
@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_; std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
// The optional session-specific worker cluster.
// TODO(saeta): Convert to std::optional when available.
std::unique_ptr<WorkerCacheInterface> worker_cache_;
// Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
WorkerCacheInterface* get_worker_cache() const;
// The device set used by this session. // The device set used by this session.
DeviceSet devices_; std::unique_ptr<DeviceSet> devices_;
StatsPublisherFactory stats_publisher_factory_; StatsPublisherFactory stats_publisher_factory_;
@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
// Private dtor. The client must call Close(). // Private dtor. The client must call Close().
virtual ~MasterSession(); virtual ~MasterSession();
// Creates sessions on all workers.
//
// If this session is operating using the new ClusterSpec propagation behavior
// call this method in order to propagate the cluster membership to all
// workers.
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
Status StartStep(const BuildGraphOptions& opts, int64* count, Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial); ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
MutableRunStepResponseWrapper* resp); MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp); MutableRunStepResponseWrapper* resp);
void MarkRunCompletion();
void UpdateLastAccessTime(); void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

View File

@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
const string& InMemoryRunGraphRequest::session_handle() const {
return session_handle_;
}
void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
session_handle_ = handle;
}
const string& InMemoryRunGraphRequest::graph_handle() const { const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_; return graph_handle_;
} }
@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) { if (!proto_version_) {
proto_version_.reset(new RunGraphRequest); proto_version_.reset(new RunGraphRequest);
proto_version_->set_session_handle(session_handle());
proto_version_->set_graph_handle(graph_handle()); proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id()); proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts(); *proto_version_->mutable_exec_opts() = exec_opts();
@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
return *proto_version_; return *proto_version_;
} }
const string& MutableProtoRunGraphRequest::session_handle() const {
return request_.session_handle();
}
void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
request_.set_session_handle(handle);
}
const string& MutableProtoRunGraphRequest::graph_handle() const { const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle(); return request_.graph_handle();
} }
@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
: request_(request) {} : request_(request) {}
const string& ProtoRunGraphRequest::session_handle() const {
return request_->session_handle();
}
const string& ProtoRunGraphRequest::graph_handle() const { const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle(); return request_->graph_handle();
} }

View File

@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
public: public:
virtual ~RunGraphRequestWrapper() {} virtual ~RunGraphRequestWrapper() {}
// The session handle used to register the graph. If empty, a single global
// namespace is used.
virtual const string& session_handle() const = 0;
// REQUIRED: graph_handle must be returned by a RegisterGraph call // REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService. // to the same WorkerService.
virtual const string& graph_handle() const = 0; virtual const string& graph_handle() const = 0;
@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
// See `RunGraphRequestWrapper` above for a description of the fields. // See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public: public:
virtual void set_session_handle(const string& handle) = 0;
virtual void set_graph_handle(const string& handle) = 0; virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0; virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0; virtual ExecutorOpts* mutable_exec_opts() = 0;
@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
public: public:
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;
@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override; const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods. // MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override; void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override; void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override; ExecutorOpts* mutable_exec_opts() override;
@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_last_partial_run(bool is_last_partial_run) override; void set_is_last_partial_run(bool is_last_partial_run) override;
private: private:
string session_handle_;
string graph_handle_; string graph_handle_;
int64 step_id_; int64 step_id_;
ExecutorOpts exec_opts_; ExecutorOpts exec_opts_;
@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public: public:
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;
@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override; const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods. // MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override; void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override; void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override; ExecutorOpts* mutable_exec_opts() override;
@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
ProtoRunGraphRequest(const RunGraphRequest* request); ProtoRunGraphRequest(const RunGraphRequest* request);
// RunGraphRequestWrapper methods. // RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override; const string& graph_handle() const override;
int64 step_id() const override; int64 step_id() const override;
const ExecutorOpts& exec_opts() const override; const ExecutorOpts& exec_opts() const override;

View File

@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/protobuf/worker.pb.h"
@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
class RemoteDevice : public Device { class RemoteDevice : public Device {
public: public:
RemoteDevice(Env* env, const DeviceAttributes& da) RemoteDevice(Env* env, const DeviceAttributes& da)
: Device(env, da, nullptr), : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
local_dev_name_(GetLocalDeviceName(da.name())) {}
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
GetStatusResponse resp; GetStatusResponse resp;
}; };
Call* call = new Call; Call* call = new Call;
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { auto cb = [env, worker_cache, worker_name, done, wi,
call](const Status& status) {
Status s = status;
std::vector<Device*> remote_devices; std::vector<Device*> remote_devices;
auto cleanup = gtl::MakeCleanup(
[&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
worker_cache->ReleaseWorker(worker_name, wi);
done(s, &remote_devices);
delete call;
});
if (s.ok()) { if (s.ok()) {
DeviceNameUtils::ParsedName worker_name_parsed;
if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
!worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
!worker_name_parsed.has_task) {
s = errors::InvalidArgument("Could not parse worker name: ",
worker_name);
LOG(WARNING) << s;
return;
}
remote_devices.reserve(call->resp.device_attributes_size()); remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) { for (const DeviceAttributes& da : call->resp.device_attributes()) {
auto d = new RemoteDevice(env, da); DeviceNameUtils::ParsedName device_name_parsed;
remote_devices.push_back(d); CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
<< "Device attribute name '" << da.name() << "' could not be "
<< "parsed. Device Attribute: " << da.DebugString();
// Preserve the exact name, if possible.
// TODO(b/37868888): Simplify when legacy device name formats removed.
if (device_name_parsed.job == worker_name_parsed.job &&
device_name_parsed.replica == worker_name_parsed.replica &&
device_name_parsed.task == worker_name_parsed.task) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
} else {
DeviceAttributes da_rewritten = da;
da_rewritten.set_name(DeviceNameUtils::FullName(
worker_name_parsed.job, worker_name_parsed.replica,
worker_name_parsed.task, device_name_parsed.type,
device_name_parsed.id));
auto d = new RemoteDevice(env, da_rewritten);
remote_devices.push_back(d);
}
} }
} }
worker_cache->ReleaseWorker(worker_name, wi);
done(s, &remote_devices);
delete call;
}; };
wi->GetStatusAsync(&call->req, &call->resp, cb); wi->GetStatusAsync(&call->req, &call->resp, cb);
} }

View File

@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
struct WorkerSession;
// RemoteRendezvous follow a 2-part initialization. First the objects are
// constructed. Eventually, they will be initialized. Clients of the
// RendezvousMgrInterface must guarantee to call Initialize on the returned
// RemoteRendezvous eventually.
//
// Partially initialized RemoteRendezvous must respect the Rendezvous interface
// (i.e. Send() must never block), however implementations are not expected to
// actually perform the underlying operations until after the RemoteRendezvous
// has been Initialize'd.
class RemoteRendezvous : public Rendezvous {
public:
// Fully construct the RemoteRendezvous.
virtual Status Initialize(WorkerSession* session) = 0;
};
// RendezvousMgr keeps track of a set of local rendezvous instances. // RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr // All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id" // until the tensor is received. Each global unique "step_id"
@ -51,7 +68,10 @@ class RendezvousMgrInterface {
// Returns Rendezvous supporting send and recv among workers in the // Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the // "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance. // returned Rendezvous instance.
virtual Rendezvous* Find(int64 step_id) = 0; //
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
virtual RemoteRendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs // Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs. // "done" when the tensor for "key" is produced or an error occurs.

View File

@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
}; };
// static utility function // static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr( RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
const WorkerEnv* env, const string& worker_name, return new RpcRendezvousMgr(env);
WorkerCacheInterface* worker_cache) {
return new RpcRendezvousMgr(env, worker_name, worker_cache);
} }
} // namespace } // namespace
@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
// TODO(mrry): Refactor the *Env classes so that it is less fiddly // TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them. // to destroy them.
// Shut down all outstanding rendezvous.
delete worker_env_.rendezvous_mgr;
// We must delete graph_mgr before device_mgr, due to shared // We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will // ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful // free all stateless OpKernels, and pass over borrowed stateful
@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
// OpSegments.) // OpSegments.)
if (worker_env_.session_mgr != nullptr) { if (worker_env_.session_mgr != nullptr) {
delete worker_env_.session_mgr; // Deletes graph_mgr's. delete worker_env_.session_mgr; // Deletes graph_mgr's.
} else {
// Note: session_mgr's legacy_session_ deletes device_mgr now.
delete worker_env_.device_mgr;
} }
delete worker_env_.device_mgr;
// Do not delete (as these are not owned by the server): // Do not delete (as these are not owned by the server):
// - master_env_.env // - master_env_.env
@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool // - worker_env_.compute_pool
} }
Status GrpcServer::Init(ServiceInitFunction service_func, Status GrpcServer::Init(
RendezvousMgrCreationFunction rendevous_mgr_func) { ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_EQ(state_, NEW); CHECK_EQ(state_, NEW);
master_env_.env = env_; master_env_.env = env_;
@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
"/task:", server_def_.task_index()); "/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices)); &master_env_.local_devices));
worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices); worker_env_.local_devices = master_env_.local_devices;
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_);
string unused; string unused;
string default_worker_name; string default_worker_name;
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
} }
WorkerCacheInterface* worker_cache; WorkerCacheInterface* worker_cache;
TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache)); WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
TF_RETURN_IF_ERROR(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache); CHECK_NE(nullptr, worker_cache);
// Set up worker environment. // Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
rendevous_mgr_func == nullptr ?
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr( worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache), std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr),
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(server_def, worker_cache); WorkerCacheFactoryOptions options(server_def);
return WorkerCacheFactory(options, worker_cache);
}); });
worker_env_.compute_pool = ComputePool(sess_opts); worker_env_.compute_pool = ComputePool(sess_opts);
@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
master_env_.master_session_factory = master_env_.master_session_factory =
[config]( [config](
SessionOptions options, const MasterEnv* env, SessionOptions options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) { std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set) {
options.config.MergeFrom(config); options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs), return new MasterSession(options, env, std::move(remote_devs),
std::move(worker_cache), std::move(device_set),
CreateNoOpStatsPublisher); CreateNoOpStatsPublisher);
}; };
master_env_.worker_cache_factory =
[this](const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(options, worker_cache);
};
// Provide direct access to the master from in-process clients. // Provide direct access to the master from in-process clients.
LocalMaster::Register(target(), master_impl_.get(), LocalMaster::Register(target(), master_impl_.get(),
@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK(); return Status::OK();
} }
Status GrpcServer::Init() { Status GrpcServer::Init() { return Init(nullptr, nullptr); }
return Init(nullptr, nullptr);
}
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) { GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) { for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports; std::map<int, string> host_ports;
for (const auto& task : job.tasks()) { for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first]; string& host_port = host_ports[task.first];
@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
task.first, "\": ", host_port, " and ", task.first, "\": ", host_port, " and ",
task.second); task.second);
} }
if (job.name() == server_def.job_name() && if (job.name() == *options.job_name && task.first == options.task_index) {
task.first == server_def.task_index()) {
host_port = strings::StrCat("localhost:", bound_port_); host_port = strings::StrCat("localhost:", bound_port_);
} else { } else {
host_port = task.second; host_port = task.second;
@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
return Status::OK(); return Status::OK();
} }
Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) { WorkerCacheInterface** worker_cache) {
string name_prefix = if (options.job_name == nullptr || options.job_name->empty()) {
strings::StrCat("/job:", server_def.job_name(), "/replica:0", Status s = errors::InvalidArgument(
"/task:", server_def.task_index()); "The master (current machine) is not included in the provided "
"cluster_def. ",
options.cluster_def->DebugString());
LOG(WARNING) << s;
return s;
}
GrpcChannelSpec channel_spec; GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
std::unique_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
"/task:", options.task_index);
std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix); const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port; int requested_port;
@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
return ::grpc::InsecureServerCredentials(); return ::grpc::InsecureServerCredentials();
} }
ChannelCreationFunction GrpcServer::GetChannelCreationFunction( ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
const ServerDef& server_def) const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being // We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function // returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);

View File

@ -37,9 +37,7 @@ class GrpcWorker;
class Master; class Master;
// function that creates a RendezvousMgr. // function that creates a RendezvousMgr.
typedef std::function<RendezvousMgrInterface*( typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
const WorkerEnv*, const std::string& worker_name,
WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction; RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to // function that registers a service to the server. The service needs to
@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
protected: protected:
Status Init(ServiceInitFunction service_func, Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func); const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init(); Status Init();
@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const; const ServerDef& server_def) const;
virtual ChannelCreationFunction GetChannelCreationFunction( virtual ChannelCreationFunction GetChannelCreationFunction() const;
const ServerDef& server_def) const;
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session. // Creates a WorkerCacheInterface for a session.
Status WorkerCacheFactory(const ServerDef& server_def, Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache); WorkerCacheInterface** worker_cache);
// Parses a ServerDef into a GrpcChannelSpec. // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
Status ParseChannelSpec(const ServerDef& server_def, Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec); GrpcChannelSpec* channel_spec);
// Returns the port to which this server is bound. // Returns the port to which this server is bound.

View File

@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
/* static */ /* static */
Status GrpcSession::Create(const SessionOptions& options, Status GrpcSession::Create(const SessionOptions& options,
std::unique_ptr<GrpcSession>* out_session) { std::unique_ptr<GrpcSession>* out_session) {
std::unique_ptr<GrpcSession> ret(new GrpcSession(options)); std::unique_ptr<GrpcSession> session(new GrpcSession(options));
std::unique_ptr<MasterInterface> master; std::unique_ptr<MasterInterface> master;
// For testing, we enable the client to disable the use of the local // For testing, we enable the client to disable the use of the local
// master registry, so that the RPC stack is exercised. // master registry, so that the RPC stack is exercised.
@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
options.target.substr(kSchemePrefixLength), &master_channel)); options.target.substr(kSchemePrefixLength), &master_channel));
master.reset(NewGrpcMaster(master_channel)); master.reset(NewGrpcMaster(master_channel));
} }
ret->SetRemoteMaster(std::move(master)); session->SetRemoteMaster(std::move(master));
*out_session = std::move(ret); *out_session = std::move(session);
return Status::OK(); return Status::OK();
} }
@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
CreateSessionRequest req; CreateSessionRequest req;
*req.mutable_config() = options_.config; *req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph; *req.mutable_graph_def() = graph;
req.set_target(options_.target);
ReEncodeConsts(req.mutable_graph_def()); ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp; CreateSessionResponse resp;
Status s = master_->CreateSession(call_options, &req, &resp); Status s = master_->CreateSession(call_options, &req, &resp);

View File

@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// completes, and we may decide to bound some of the request // completes, and we may decide to bound some of the request
// types. // types.
ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CreateWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false);
@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(GetStatus, false);
} }
void CreateWorkerSessionHandler(
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
call) {
Schedule([this, call]() {
Status s = worker_->CreateWorkerSession(&call->request, &call->response);
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(CreateWorkerSession, false);
}
void CleanupAllHandler( void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) { WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() { Schedule([this, call]() {
@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response, ::grpc::ByteBuffer* response,
StatusCallback done) { StatusCallback done) {
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
const string& key = request->rendezvous_key(); const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed; Rendezvous::ParsedKey parsed;
@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
// of execution of the callback lambda body below, an RPC // of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous. // cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
session->rendezvous_mgr->RecvLocalAsync( env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed, step_id, parsed,
[opts, response, done, src_dev](const Status& status, [opts, response, done, src_dev](const Status& status,
const Rendezvous::Args& send_args, const Rendezvous::Args& send_args,

View File

@ -38,9 +38,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous { class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public: public:
RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name, RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
WorkerCacheInterface* cache, int64 step_id) : BaseRemoteRendezvous(env, step_id, false) {}
: BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
protected: protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private: private:
~RpcRemoteRendezvous() override {} ~RpcRemoteRendezvous() override {}
WorkerCacheInterface* const cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
}; };
@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
return call_freelist; return call_freelist;
} }
// A private cache that wraps worker_cache and allows reuse of
// WorkerInterface objects.
class WorkerFreeListCache : public WorkerCacheInterface {
public:
explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
~WorkerFreeListCache() {
for (auto p : workers_) {
wrapped_->ReleaseWorker(p.first, p.second.worker);
}
}
void ListWorkers(std::vector<string>* workers) const override {
wrapped_->ListWorkers(workers);
}
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
if (p != workers_.end()) {
return p->second.worker;
}
WorkerState state;
state.worker = wrapped_->CreateWorker(target);
if (state.worker != nullptr) {
workers_.insert(std::make_pair(target, state));
}
return state.worker;
}
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {
wrapped_->GetDeviceLocalityAsync(device, locality, done);
}
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
void ClearLogs() override { wrapped_->ClearLogs(); }
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
return wrapped_->RetrieveLogs(step_id, ss);
}
private:
WorkerCacheInterface* wrapped_;
// Information kept per created WorkerInterface.
struct WorkerState {
WorkerInterface* worker;
// TODO(jeff,sanjay): Add reference count if we support eviction.
};
// TODO(jeff,sanjay): Eviction when the map becomes too big.
mutex mu_;
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
};
void RpcRemoteRendezvous::RecvFromRemoteAsync( void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) { DoneCallback done) {
CHECK(is_initialized());
Status s; Status s;
// Prepare a RecvTensor call that can handle being aborted. // Prepare a RecvTensor call that can handle being aborted.
@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device, s = errors::Internal(parsed.src_device,
" is invalid remote source device."); " is invalid remote source device.");
} }
WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); WorkerSession* sess = session();
WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) { if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_); s = errors::Internal("No worker known as ", call->src_worker_);
} }
Device* dst_device; Device* dst_device;
if (s.ok()) { if (s.ok()) {
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
} }
if (!s.ok()) { if (!s.ok()) {
get_call_freelist()->Release(call, cache_); if (rwi != nullptr) {
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
}
get_call_freelist()->Release(call, sess->worker_cache.get());
done(s, Args(), recv_args, Tensor{}, false); done(s, Args(), recv_args, Tensor{}, false);
return; return;
} }
@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// current status should be bad. // current status should be bad.
Status s = call->status(); Status s = call->status();
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
cache_->ReleaseWorker(call->src_worker_, call->wi_); session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
call->wi_ = nullptr; call->wi_ = nullptr;
get_call_freelist()->Release(call, cache_); get_call_freelist()->Release(call, session()->worker_cache.get());
Unref(); Unref();
}); });
} }
} // namespace } // namespace
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env, RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
const string& worker_name, : BaseRendezvousMgr(env) {}
WorkerCacheInterface* worker_cache)
: BaseRendezvousMgr(env, worker_name),
cache_(new WorkerFreeListCache(worker_cache)) {}
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env, const WorkerEnv* worker_env) {
const string& worker_name) { return new RpcRemoteRendezvous(worker_env, step_id);
return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
step_id);
} }
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
namespace tensorflow { namespace tensorflow {
class DeviceMgr;
// RendezvousMgr keeps track of a set of local rendezvous instances. // RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr // All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id" // until the tensor is received. Each global unique "step_id"
@ -44,17 +44,12 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey. // RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr { class RpcRendezvousMgr : public BaseRendezvousMgr {
public: public:
explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name, explicit RpcRendezvousMgr(const WorkerEnv* env);
WorkerCacheInterface* worker_cache);
protected: protected:
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
const string& session_name) override;
private: private:
// Private cache_ that allows us to reuse WorkerInterface objects.
std::unique_ptr<WorkerCacheInterface> cache_;
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
}; };

View File

@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache), : cache_(new DummyWorkerCache),
worker_session_("/job:mnist/replica:1/task:2", worker_session_("/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_), std::unique_ptr<WorkerCacheInterface>(cache_),
std::unique_ptr<RendezvousMgrInterface>(), std::unique_ptr<DeviceMgr>(),
std::unique_ptr<GraphMgr>()), std::unique_ptr<GraphMgr>()),
rmgr_(&env, worker_session_.worker_name, cache_) { rmgr_(&env) {
env.env = Env::Default(); env.env = Env::Default();
} }
@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
"/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort(). { // Explicit Abort().
const int64 step_id = 123; const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
SchedClosure([this, rendez]() { SchedClosure([this, rendez]() {
env.env->SleepForMicroseconds(100 * 1000); env.env->SleepForMicroseconds(100 * 1000);
@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING); Tensor val(DT_STRING);
bool val_dead = false; bool val_dead = false;
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
} }
{ // Cleanup causes Abort(). { // Cleanup causes Abort().
const int64 step_id = 321; const int64 step_id = 321;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
SchedClosure([this, step_id]() { SchedClosure([this, step_id]() {
env.env->SleepForMicroseconds(100 * 1000); env.env->SleepForMicroseconds(100 * 1000);
@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING); Tensor val(DT_STRING);
bool val_dead = false; bool val_dead = false;
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
} }
} }
@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
const int64 step_id = 123; const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
"/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ {
Rendezvous* rendez = rmgr_.Find(step_id); RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez); core::ScopedUnref unref(rendez);
Rendezvous::Args args; Rendezvous::Args args;
args.device_context = dc; args.device_context = dc;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
} }
{ {

View File

@ -17,8 +17,9 @@ limitations under the License.
#include <utility> #include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow { namespace tensorflow {
@ -26,23 +27,12 @@ namespace tensorflow {
SessionMgr::SessionMgr( SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name, WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache, std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: SessionMgr(
worker_env, default_worker_name, std::move(default_worker_cache),
default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory) WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env), : worker_env_(worker_env),
legacy_session_( legacy_session_(default_worker_name, std::move(default_worker_cache),
default_worker_name, std::move(default_worker_cache), std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr), std::unique_ptr<GraphMgr>(
std::unique_ptr<GraphMgr>( new GraphMgr(worker_env, worker_env->device_mgr))),
new GraphMgr(worker_env, default_rendezvous_mgr))),
worker_cache_factory_(std::move(worker_cache_factory)) {} worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
Status SessionMgr::CreateSession(const string& session, Status SessionMgr::CreateSession(const string& session,
const ServerDef& server_def) { const ServerDef& server_def) {
mutex_lock l(mu_); mutex_lock l(mu_);
if (session.empty()) {
return errors::InvalidArgument("Session must be non-empty.");
}
const string worker_name = WorkerNameFromServerDef(server_def); const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr; WorkerCacheInterface* worker_cache = nullptr;
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr( std::vector<Device*> renamed_devices;
new RpcRendezvousMgr(worker_env_, worker_name, worker_cache)); for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(
RenamedDevice::NewRenamedDevice(worker_name, d, false));
}
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
std::unique_ptr<GraphMgr> graph_mgr( std::unique_ptr<GraphMgr> graph_mgr(
new GraphMgr(worker_env_, rendezvous_mgr.get())); new GraphMgr(worker_env_, device_mgr.get()));
std::unique_ptr<WorkerSession> worker_session(new WorkerSession( std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache), worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr), std::move(graph_mgr))); std::move(device_mgr), std::move(graph_mgr)));
sessions_.insert(std::make_pair(session, std::move(worker_session))); sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK(); return Status::OK();
@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
if (it != sessions_.end()) { if (it != sessions_.end()) {
sessions_.erase(it); sessions_.erase(it);
} }
std::set<string> graph_handles;
for (auto graph_handle_it = sessions_by_graph_handle_.begin();
graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
if (graph_handle_it->second == session) {
graph_handles.insert(graph_handle_it->first);
graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
if (graph_handle_it == sessions_by_graph_handle_.end()) break;
}
}
for (auto step_id_it = graphs_by_step_id_.begin();
step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
step_id_it = graphs_by_step_id_.erase(step_id_it);
if (step_id_it == graphs_by_step_id_.end()) break;
}
}
return Status::OK(); return Status::OK();
} }
@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; } WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
const string& graph_handle) {
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it == sessions_by_graph_handle_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForSessionUnlocked(it->second);
}
}
WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
const string& graph_handle) {
mutex_lock l(mu_);
return WorkerSessionForGraphHandleUnlocked(graph_handle);
}
WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it == graphs_by_step_id_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForGraphHandleUnlocked(it->second);
}
}
void SessionMgr::AssociateGraphWithSession(const string& session,
const string& graph_handle) {
mutex_lock l(mu_);
sessions_by_graph_handle_[graph_handle] = session;
}
void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
mutex_lock l(mu_);
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it != sessions_by_graph_handle_.end()) {
sessions_by_graph_handle_.erase(it);
}
}
void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id) {
mutex_lock l(mu_);
graphs_by_step_id_[step_id] = graph_handle;
}
void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it != graphs_by_step_id_.end()) {
graphs_by_step_id_.erase(it);
}
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -30,6 +30,8 @@ struct WorkerEnv;
// SessionMgr keeps track of information related to a given session. // SessionMgr keeps track of information related to a given session.
// //
// SessionMgr runs on the workers.
//
// SessionMgr is threadsafe. // SessionMgr is threadsafe.
class SessionMgr { class SessionMgr {
public: public:
@ -39,7 +41,6 @@ class SessionMgr {
explicit SessionMgr( explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name, WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache, std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory); WorkerCacheFactory worker_cache_factory);
~SessionMgr() {} ~SessionMgr() {}
@ -50,49 +51,36 @@ class SessionMgr {
WorkerSession* WorkerSessionForSession(const string& session); WorkerSession* WorkerSessionForSession(const string& session);
WorkerSession* LegacySession(); WorkerSession* LegacySession();
// Locates the worker session for a given graph handle
WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
void AssociateGraphWithSession(const string& session,
const string& graph_handle);
void DisassociateGraphFromSession(const string& graph_handle);
// Locates a worker session for a given step id
WorkerSession* WorkerSessionForStepId(const int64 step_id);
void AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id);
void DisassociateStepIdFromGraph(const int64 step_id);
Status DeleteSession(const string& session); Status DeleteSession(const string& session);
static string WorkerNameFromServerDef(const ServerDef& server_def); static string WorkerNameFromServerDef(const ServerDef& server_def);
private: private:
// Private constructor to work around std::unique_ptr ownership issues.
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
const WorkerEnv* const worker_env_; // Not owned. const WorkerEnv* const worker_env_; // Not owned.
// A note about destruction:
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
// OpKernels, which are also held in their respective devices'
// OpSegments.)
//
// legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
// that sessions_'s WorkerSessions are deleted (which do not own the
// underlying devices, but instead own RenamedDevices) before
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
// device_mgr is deleted after WorkerSession's graph_mgr.
WorkerSession legacy_session_; WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_; const WorkerCacheFactory worker_cache_factory_;
WorkerSession* WorkerSessionForSessionUnlocked(const string& session) WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
EXCLUSIVE_LOCKS_REQUIRED(mu_); EXCLUSIVE_LOCKS_REQUIRED(mu_);
WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_; mutex mu_;
// A map from session identifier to internal session structure. // A map from session identifier to internal session structure.
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_); std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
// A map from graph handles to the session that they belong to.
std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
// A map from globally-unique step id's to the corresponding graph handles.
std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest() SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0", : mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), std::unique_ptr<WorkerCacheInterface>(),
std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
&env_, "/job:mnist/replica:0/task:0", nullptr)),
factory_), factory_),
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {} legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }
TEST_F(SessionMgrTest, AssociateGraphWithSession) { TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def; ServerDef server_def;
string session_handle = "test_session_handle"; string session_handle = "";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; EXPECT_EQ(mgr_.LegacySession(), session);
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
} }
TEST_F(SessionMgrTest, AssociateStepWithGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(session, step_session);
ASSERT_EQ(graph_session, step_session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, WorkerNameFromServerDef) { TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def; ServerDef server_def;
server_def.set_job_name("worker"); server_def.set_job_name("worker");

View File

@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
Status s = session->graph_mgr->Register( Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(), request->session_handle(), request->graph_def(), request->graph_options(),
request->debug_options(), response->mutable_graph_handle()); request->debug_options(), response->mutable_graph_handle());
if (s.ok()) {
env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
response->graph_handle());
}
done(s); done(s);
} }
@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response, DeregisterGraphResponse* response,
StatusCallback done) { StatusCallback done) {
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Deregister(request->graph_handle()); Status s = session->graph_mgr->Deregister(request->graph_handle());
env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
done(s); done(s);
} }
@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
} }
void Worker::AbortStep(int64 step_id) { void Worker::AbortStep(int64 step_id) {
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root // Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this // cause may return first back to the client instead of this
@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id); TRACEPRINTF("RunGraph: %lld", step_id);
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); env_->session_mgr->WorkerSessionForSession(request->session_handle());
env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
GraphMgr::NamedTensors in; GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out); Status s = PrepareRunGraph(request, &in, out);
@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
} }
CostGraphDef* cost_graph = response->mutable_cost_graph(); CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync( session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, request->exec_opts(), collector, request->graph_handle(), step_id, session, request->exec_opts(),
cost_graph, cm, in, collector, cost_graph, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts, [this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) { done](Status s) {
if (s.ok()) { if (s.ok()) {
@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const string& graph_handle = request->graph_handle(); const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id); TRACEPRINTF("PartialRunGraph: %lld", step_id);
WorkerSession* session = WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(graph_handle); env_->session_mgr->WorkerSessionForSession(request->session_handle());
env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
GraphMgr::NamedTensors in; GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out); Status s = PrepareRunGraph(request, &in, out);
@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
[cm]() { cm->StartCancel(); }); [cm]() { cm->StartCancel(); });
} }
session->graph_mgr->ExecuteAsync( session->graph_mgr->ExecuteAsync(
graph_handle, step_id, request->exec_opts(), nullptr /* collector */, graph_handle, step_id, session, request->exec_opts(),
nullptr /* cost_graph */, cm, in, nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, token, graph_handle, step_id, cm](Status s) { [this, token, graph_handle, step_id, cm](Status s) {
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response, CleanupGraphResponse* response,
StatusCallback done) { StatusCallback done) {
const int64 step_id = request->step_id(); const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); env_->rendezvous_mgr->Cleanup(step_id);
session->rendezvous_mgr->Cleanup(step_id);
done(Status::OK()); done(Status::OK());
} }
@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) { Device** src_dev) {
// Figures out which device the tensor is hosted on. // Figures out which device the tensor is hosted on.
TF_RETURN_IF_ERROR( string local_name = DeviceNameUtils::LocalName(parsed.src_device);
env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
// Does the device have the right incarnation number we expect? // Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#include <vector>
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
@ -24,8 +25,10 @@ namespace thread {
class ThreadPool; class ThreadPool;
} // namespace thread } // namespace thread
class Device;
class DeviceMgr; class DeviceMgr;
class Env; class Env;
class RendezvousMgrInterface;
class SessionMgr; class SessionMgr;
// The worker environment class, which holds a bag of pointers to // The worker environment class, which holds a bag of pointers to
@ -38,10 +41,18 @@ struct WorkerEnv {
// session_mgr encapsulates state for each session. // session_mgr encapsulates state for each session.
SessionMgr* session_mgr = nullptr; SessionMgr* session_mgr = nullptr;
// The local devices of this worker. Devices are owned by the device_mgr.
//
// REQUIRES: !local_devices.empty().
std::vector<Device*> local_devices;
// device_mgr manages local devices (cpu and gpu). The WorkerService // device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices. // is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr; DeviceMgr* device_mgr = nullptr;
// A set of rendezvous keyed by step ids.
RendezvousMgrInterface* rendezvous_mgr = nullptr;
// A pool of threads for scheduling compute work. // A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr; thread::ThreadPool* compute_pool = nullptr;
}; };

View File

@ -113,6 +113,11 @@ class WorkerInterface {
return CallAndWait(&ME::GetStatusAsync, request, response); return CallAndWait(&ME::GetStatusAsync, request, response);
} }
Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response) {
return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
}
Status RegisterGraph(const RegisterGraphRequest* request, Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) { RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response); return CallAndWait(&ME::RegisterGraphAsync, request, response);

View File

@ -17,14 +17,84 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
WorkerSession::WorkerSession( namespace {
const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache, // A private cache that wraps worker_cache and allows reuse of
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr, // WorkerInterface objects.
std::unique_ptr<GraphMgr> graph_mgr) class WorkerFreeListCache : public WorkerCacheInterface {
public:
explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
: wrapped_(std::move(w)) {}
~WorkerFreeListCache() final {
for (auto p : workers_) {
wrapped_->ReleaseWorker(p.first, p.second.worker);
}
}
void ListWorkers(std::vector<string>* workers) const override {
wrapped_->ListWorkers(workers);
}
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
if (p != workers_.end()) {
return p->second.worker;
}
WorkerState state;
state.worker = wrapped_->CreateWorker(target);
if (state.worker != nullptr) {
workers_.insert(std::make_pair(target, state));
}
return state.worker;
}
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {
wrapped_->GetDeviceLocalityAsync(device, locality, done);
}
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
void ClearLogs() override { wrapped_->ClearLogs(); }
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
return wrapped_->RetrieveLogs(step_id, ss);
}
private:
std::unique_ptr<WorkerCacheInterface> wrapped_;
// Information kept per created WorkerInterface.
struct WorkerState {
WorkerInterface* worker;
// TODO(jeff,sanjay): Add reference count if we support eviction.
};
// TODO(jeff,sanjay): Eviction when the map becomes too big.
mutex mu_;
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
};
} // namespace
WorkerSession::WorkerSession(const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<GraphMgr> graph_mgr)
: worker_name(worker_name), : worker_name(worker_name),
worker_cache(std::move(worker_cache)), worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
rendezvous_mgr(std::move(rendezvous_mgr)), device_mgr(std::move(device_mgr)),
graph_mgr(std::move(graph_mgr)) {} graph_mgr(std::move(graph_mgr)) {}
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,14 +18,13 @@ limitations under the License.
#include <string> #include <string>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache.h"
namespace tensorflow { namespace tensorflow {
class GraphMgr; class GraphMgr;
class RendezvousMgrInterface;
class WorkerCacheInterface; class WorkerCacheInterface;
// WorkerSession encapsulates all of the state relating to a given session. // WorkerSession encapsulates all of the state relating to a given session.
@ -36,17 +35,20 @@ struct WorkerSession {
// Object from which WorkerInterface instances can be obtained. // Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache; const std::unique_ptr<WorkerCacheInterface> worker_cache;
// A set of rendezvous keyed by step ids. // Collection of local devices. These devices are typically RenamedDevices
const std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr; // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr
// == worker_env_.device_mgr, which holds the true devices.
const std::unique_ptr<DeviceMgr> device_mgr;
// graph_mgr keeps track of the registered graphs of this session. // graph_mgr keeps track of the registered graphs of this session.
// //
// Note: graph_mgr must be deleted before rendezvous_mgr! // Note: graph_mgr must be deleted before rendezvous_mgr!
// Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr; const std::unique_ptr<GraphMgr> graph_mgr;
WorkerSession(const string& worker_name, WorkerSession(const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr, std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<GraphMgr> graph_mgr); std::unique_ptr<GraphMgr> graph_mgr);
}; };

View File

@ -115,7 +115,7 @@ class DeviceBase {
cpu_worker_threads_ = t; cpu_worker_threads_ = t;
} }
const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
CHECK(cpu_worker_threads_ != nullptr); CHECK(cpu_worker_threads_ != nullptr);
return cpu_worker_threads_; return cpu_worker_threads_;
} }
@ -140,7 +140,7 @@ class DeviceBase {
gpu_device_info_ = g; gpu_device_info_ = g;
} }
const GpuDeviceInfo* tensorflow_gpu_device_info() const { virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const {
return gpu_device_info_; return gpu_device_info_;
} }
@ -170,13 +170,13 @@ class DeviceBase {
return GetAllocator(attr); return GetAllocator(attr);
} }
const Eigen::ThreadPoolDevice* eigen_cpu_device() { virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr); CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_; return eigen_cpu_device_;
} }
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice* eigen_sycl_device() const { virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr); CHECK(eigen_sycl_device_ != nullptr);
return eigen_sycl_device_; return eigen_sycl_device_;
} }

View File

@ -0,0 +1,82 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ClusterProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
// This file contains protos to be used when defining a TensorFlow
// cluster.
//
// EXAMPLES
// --------
//
// 1. A single-process cluster, containing "/job:local/task:0".
//
// Cluster:
// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
//
// Server:
// cluster { $CLUSTER } job_name: 'local' task_index: 0
//
// 2. A two-process cluster, containing "/job:local/task:{0,1}".
//
// Cluster:
// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
// tasks { key: 1 value: 'localhost:2223' } }
//
// Servers:
// cluster { $CLUSTER } job_name: 'local' task_index: 0
// cluster { $CLUSTER } job_name: 'local' task_index: 1
//
// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
// "/job:ps/task:{0,1}".
//
// Cluster:
// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
// tasks { key: 1 value: 'worker2:2222' }
// tasks { key: 2 value: 'worker3:2222' } }
// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
// tasks { key: 1 value: 'ps1:2222' } }
//
// Servers:
// cluster { $CLUSTER } job_name: 'worker' task_index: 0
// cluster { $CLUSTER } job_name: 'worker' task_index: 1
// cluster { $CLUSTER } job_name: 'worker' task_index: 2
// cluster { $CLUSTER } job_name: 'ps' task_index: 0
// cluster { $CLUSTER } job_name: 'ps' task_index: 1
// Defines a single job in a TensorFlow cluster.
message JobDef {
// The name of this job.
string name = 1;
// Mapping from task ID to "hostname:port" string.
//
// If the `name` field contains "worker", and the `tasks` map contains a
// mapping from 7 to "example.org:2222", then the device prefix
// "/job:worker/task:7" will be assigned to "example.org:2222".
map<int32, string> tasks = 2;
}
// Defines a TensorFlow cluster as a set of jobs.
message ClusterDef {
// The jobs that comprise the cluster.
repeated JobDef job = 1;
}

View File

@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto";
import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/debug.proto";
import "tensorflow/core/protobuf/cluster.proto";
import "tensorflow/core/protobuf/rewriter_config.proto"; import "tensorflow/core/protobuf/rewriter_config.proto";
message GPUOptions { message GPUOptions {
@ -259,6 +260,11 @@ message ConfigProto {
// Options that apply when this session uses the distributed runtime. // Options that apply when this session uses the distributed runtime.
RPCOptions rpc_options = 13; RPCOptions rpc_options = 13;
// Optional list of all workers to use in this session.
ClusterDef cluster_def = 14;
// Next: 15
}; };
// Options for a single Run() call. // Options for a single Run() call.

View File

@ -38,6 +38,9 @@ message CreateSessionRequest {
// Configuration options. // Configuration options.
ConfigProto config = 2; ConfigProto config = 2;
// The target string used from the client's perspective.
string target = 3;
} }
message CreateSessionResponse { message CreateSessionResponse {

View File

@ -16,6 +16,7 @@ limitations under the License.
syntax = "proto3"; syntax = "proto3";
import "tensorflow/core/protobuf/config.proto"; import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/cluster.proto";
package tensorflow; package tensorflow;
option cc_enable_arenas = true; option cc_enable_arenas = true;
@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos";
option java_multiple_files = true; option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime"; option java_package = "org.tensorflow.distruntime";
// This file contains protos to be used when defining a TensorFlow
// cluster, and a server within that cluster.
//
// EXAMPLES
// --------
//
// 1. A single-process cluster, containing "/job:local/task:0".
//
// Cluster:
// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
//
// Server:
// cluster { $CLUSTER } job_name: 'local' task_index: 0
//
// 2. A two-process cluster, containing "/job:local/task:{0,1}".
//
// Cluster:
// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
// tasks { key: 1 value: 'localhost:2223' } }
//
// Servers:
// cluster { $CLUSTER } job_name: 'local' task_index: 0
// cluster { $CLUSTER } job_name: 'local' task_index: 1
//
// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
// "/job:ps/task:{0,1}".
//
// Cluster:
// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
// tasks { key: 1 value: 'worker2:2222' }
// tasks { key: 2 value: 'worker3:2222' } }
// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
// tasks { key: 1 value: 'ps1:2222' } }
//
// Servers:
// cluster { $CLUSTER } job_name: 'worker' task_index: 0
// cluster { $CLUSTER } job_name: 'worker' task_index: 1
// cluster { $CLUSTER } job_name: 'worker' task_index: 2
// cluster { $CLUSTER } job_name: 'ps' task_index: 0
// cluster { $CLUSTER } job_name: 'ps' task_index: 1
// Defines a single job in a TensorFlow cluster.
message JobDef {
// The name of this job.
string name = 1;
// Mapping from task ID to "hostname:port" string.
//
// If the `name` field contains "worker", and the `tasks` map contains a
// mapping from 7 to "example.org:2222", then the device prefix
// "/job:worker/task:7" will be assigned to "example.org:2222".
//
// NOTE(mrry): Currently, only a dense task ID space starting at 0 is
// supported.
map<int32, string> tasks = 2;
}
// Defines a TensorFlow cluster as a set of jobs.
message ClusterDef {
// The jobs that comprise the cluster.
repeated JobDef job = 1;
}
// Defines the configuration of a single TensorFlow server. // Defines the configuration of a single TensorFlow server.
message ServerDef { message ServerDef {
// The cluster of which this server is a member. // The cluster of which this server is a member.

View File

@ -119,6 +119,10 @@ message RegisterGraphResponse {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
message DeregisterGraphRequest { message DeregisterGraphRequest {
// The session_handle used when registering the graph. If session_handle is
// empty, a single global namespace is used.
string session_handle = 2;
// REQUIRED: graph_handle must be returned by a RegisterGraph call // REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService. // to the same WorkerService.
string graph_handle = 1; string graph_handle = 1;
@ -167,6 +171,12 @@ message ExecutorOpts {
}; };
message RunGraphRequest { message RunGraphRequest {
// session_handle is the the master-generated unique id for this session.
// If session_handle is non-empty, it must be the same as used when
// registering the graph. If it is empty, a single global namespace is used to
// search for the graph_handle.
string session_handle = 8;
// REQUIRED: graph_handle must be returned by a RegisterGraph call // REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService. // to the same WorkerService.
string graph_handle = 1; string graph_handle = 1;
@ -193,6 +203,8 @@ message RunGraphRequest {
bool is_partial = 6; bool is_partial = 6;
// True if this is the last partial run request in a sequence of requests. // True if this is the last partial run request in a sequence of requests.
bool is_last_partial_run = 7; bool is_last_partial_run = 7;
// Next: 9
} }
message RunGraphResponse { message RunGraphResponse {

View File

@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import * from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.config_pb2 import * from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.protobuf.rewriter_config_pb2 import * from tensorflow.core.protobuf.rewriter_config_pb2 import *
from tensorflow.core.util.event_pb2 import * from tensorflow.core.util.event_pb2 import *
@ -131,6 +132,7 @@ _allowed_symbols = [
'AttrValue', 'AttrValue',
'AutoParallelOptions', 'AutoParallelOptions',
'ConfigProto', 'ConfigProto',
'ClusterDef',
'DeviceSpec', 'DeviceSpec',
'Event', 'Event',
'GPUOptions', 'GPUOptions',

View File

@ -29,6 +29,7 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes from tensorflow.python.framework import common_shapes
@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with CaptureStderr() as log: with CaptureStderr() as log:
sess.run(c) sess.run(c)
# Ensure that we did log device placement. # Ensure that we did log device placement.
self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log)) self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log))
def testLocalMasterSessionTimeout(self): def testLocalMasterSessionTimeout(self):
# Test that the timeout passed in a config to the session works correctly. # Test that the timeout passed in a config to the session works correctly.
@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase):
server = server_lib.Server.create_local_server() server = server_lib.Server.create_local_server()
self.runTestBuildGraphError(session.Session(server.target)) self.runTestBuildGraphError(session.Session(server.target))
def testClusterSpecPropagationSimple(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config)
output = sess.run(const)
self.assertEqual(17, output)
def testClusterSpecPropagationWorker2Placement(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config, graph=g)
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
output = sess.run(const, options=run_options, run_metadata=run_metadata)
self.assertEqual(17, output)
self.assertEqual(1,
len([
node_stats
for dev_stats in run_metadata.step_stats.dev_stats
for node_stats in dev_stats.node_stats
if '/job:worker/replica:0/task:1/device:CPU:0' ==
dev_stats.device and 'Const' == node_stats.node_name
]))
def testClusterSpecPropagationWorker1Placement(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config, graph=g)
output = sess.run(const)
self.assertEqual(17, output)
def testClusterSpecPropagationThreeServers2Graphs(self):
"""Boots 3 servers, creates 2 sessions, ensures appropriate operations.
We create 2 clusterspecs:
1. server2 as the master, server1 as a worker
2. server2 as the master, server3 as a worker
We ensure that variables on the workers are independent.
"""
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
server3 = server_lib.Server.create_local_server()
cluster_def1 = cluster_pb2.ClusterDef()
job1 = cluster_def1.job.add()
job1.name = 'worker1'
job1.tasks[0] = server2.target[len('grpc://'):]
job1.tasks[1] = server1.target[len('grpc://'):]
cluster_def2 = cluster_pb2.ClusterDef()
job2 = cluster_def2.job.add()
job2.name = 'worker2'
job2.tasks[0] = server2.target[len('grpc://'):]
job2.tasks[1] = server3.target[len('grpc://'):]
config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
with ops.Graph().as_default() as g1:
with ops.device('/job:worker1/task:1'):
var1 = variables.Variable(array_ops.zeros([2]), name='var1')
update_op1 = state_ops.assign_add(
var1, array_ops.ones([2]), name='var1_assign_add')
init1 = variables.global_variables_initializer()
with ops.Graph().as_default() as g2:
with ops.device('/job:worker2/task:1'):
var2 = variables.Variable(array_ops.zeros([2]), name='var2')
update_op2 = state_ops.assign_add(
var2, array_ops.ones([2]), name='var2_assign_add')
init2 = variables.global_variables_initializer()
sess1 = session.Session(server2.target, graph=g1, config=config1)
sess2 = session.Session(server2.target, graph=g2, config=config2)
init1.run(session=sess1)
init2.run(session=sess2)
expected_zeros = np.zeros([2])
expected_ones = np.ones([2])
self.assertAllEqual(expected_zeros, sess1.run(var1))
self.assertAllEqual(expected_zeros, sess2.run(var2))
self.assertAllEqual(expected_ones, sess1.run(update_op1))
self.assertAllEqual(expected_ones, sess1.run(var1))
self.assertAllEqual(expected_zeros, sess2.run(var2))
self.assertAllEqual(expected_ones, sess2.run(update_op2))
self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
self.assertAllEqual(expected_ones, sess2.run(var2))
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
def testClusterSpecPropagationThreeServers(self):
"""Boots 3 servers, creates 2 sessions, ensures appropriate operations.
We create 2 clusterspecs:
1. server2 as the master, server1 as a worker
2. server2 as the master, server3 as a worker
We ensure that variables on the workers are independent.
"""
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
server3 = server_lib.Server.create_local_server()
cluster_def1 = cluster_pb2.ClusterDef()
job1 = cluster_def1.job.add()
job1.name = 'worker'
job1.tasks[0] = server2.target[len('grpc://'):]
job1.tasks[1] = server1.target[len('grpc://'):]
cluster_def2 = cluster_pb2.ClusterDef()
job2 = cluster_def2.job.add()
job2.name = 'worker'
job2.tasks[0] = server2.target[len('grpc://'):]
job2.tasks[1] = server3.target[len('grpc://'):]
config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
with ops.device('/job:worker/task:1'):
var = variables.Variable(array_ops.zeros([2]), name='var')
feed = array_ops.placeholder(dtypes.float32, shape=(2))
update_op = var.assign_add(feed)
sess1 = session.Session(server2.target, config=config1)
sess2 = session.Session(server2.target, config=config2)
variables.global_variables_initializer().run(session=sess1)
variables.global_variables_initializer().run(session=sess2)
expected_zeros = np.zeros([2])
expected_ones = np.ones([2])
self.assertAllEqual(expected_zeros, sess1.run(var))
self.assertAllEqual(expected_zeros, sess2.run(var))
self.assertAllEqual(expected_ones,
sess1.run(update_op, feed_dict={feed: expected_ones}))
self.assertAllEqual(expected_ones, sess1.run(var))
self.assertAllEqual(expected_zeros, sess2.run(var))
self.assertAllEqual(expected_ones,
sess2.run(update_op, feed_dict={feed: expected_ones}))
self.assertAllEqual(expected_ones + expected_ones,
sess1.run(update_op, feed_dict={feed: expected_ones}))
self.assertAllEqual(expected_ones, sess2.run(var))
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))
def testClusterSpecPropagationThreeServersOneCluster(self):
"""Boots 3 servers, ensures appropriate communication across workers.
Additionally, in this cluster, we ensure the master is not the 0-th worker.
Note: this test only uses one session.
"""
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
server3 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server3.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
job.tasks[2] = server1.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
# Add ops to the devices in non-linear order.
with ops.device('/job:worker/task:1'):
feed1 = array_ops.placeholder(dtypes.float32, shape=(2))
const1 = constant_op.constant(2.0)
mul1 = const1 * feed1
with ops.device('/job:worker/task:2'):
feed2 = array_ops.placeholder(dtypes.float32, shape=(2))
const2 = constant_op.constant(2.0)
mul2 = const2 * feed2
with ops.device('/job:worker/task:0'):
feed0 = array_ops.placeholder(dtypes.float32, shape=(2))
const0 = constant_op.constant(2.0)
mul0 = const0 * feed0
sum_op = mul0 + mul1 + mul2
ones = np.ones([2])
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
# Run!
with session.Session(server1.target, config=config) as sess:
output = sess.run(
sum_op,
options=run_options,
run_metadata=run_metadata,
feed_dict={feed1: ones,
feed2: ones,
feed0: ones})
self.assertAllEqual(6 * ones, output)
self.assertEqual(
3,
len([
dev_stats.device
for dev_stats in run_metadata.step_stats.dev_stats
for node_stats in dev_stats.node_stats
if '/job:worker/replica:0/task:' in dev_stats.device and
node_stats.node_name.startswith('Const')
]), run_metadata)
def testClusterSpecPropagationPartialRun(self):
"""Test successful partial run with ClusterSpec propagation."""
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with ops.device('/job:worker/task:0'):
a = array_ops.placeholder(dtypes.float32, shape=[])
with ops.device('/job:worker/task:1'):
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
with ops.device('/job:worker/task:0'):
r2 = math_ops.multiply(r1, c)
with session.Session(server1.target, config=config) as sess:
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
res = sess.partial_run(h, r2, feed_dict={c: 3})
self.assertEqual(9, res)
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -276,14 +277,14 @@ class ClusterSpec(object):
"from integers to strings." % job_name) "from integers to strings." % job_name)
self._cluster_spec[job_name] = job_tasks self._cluster_spec[job_name] = job_tasks
self._make_cluster_def() self._make_cluster_def()
elif isinstance(cluster, tensorflow_server_pb2.ClusterDef): elif isinstance(cluster, cluster_pb2.ClusterDef):
self._cluster_def = cluster self._cluster_def = cluster
self._cluster_spec = {} self._cluster_spec = {}
for job_def in self._cluster_def.job: for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = { self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()} i: t for i, t in job_def.tasks.items()}
elif isinstance(cluster, ClusterSpec): elif isinstance(cluster, ClusterSpec):
self._cluster_def = tensorflow_server_pb2.ClusterDef() self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def()) self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {} self._cluster_spec = {}
for job_def in self._cluster_def.job: for job_def in self._cluster_def.job:
@ -440,7 +441,7 @@ class ClusterSpec(object):
TypeError: If `cluster_spec` is not a dictionary mapping strings to lists TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
of strings. of strings.
""" """
self._cluster_def = tensorflow_server_pb2.ClusterDef() self._cluster_def = cluster_pb2.ClusterDef()
# NOTE(mrry): Sort by job_name to produce deterministic protobufs. # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
for job_name, tasks in sorted(self._cluster_spec.items()): for job_name, tasks in sorted(self._cluster_spec.items()):

View File

@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
# Distributed computing support. # Distributed computing support.
from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef from tensorflow.core.protobuf.cluster_pb2 import JobDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.server_lib import Server from tensorflow.python.training.server_lib import Server
@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server
_allowed_symbols = [ _allowed_symbols = [
# TODO(cwhipkey): review these and move to contrib or expose through # TODO(cwhipkey): review these and move to contrib or expose through
# documentation. # documentation.
"generate_checkpoint_state_proto", # Used internally by saver. "generate_checkpoint_state_proto", # Used internally by saver.
"checkpoint_exists", # Only used in test? "checkpoint_exists", # Only used in test?
"get_checkpoint_mtimes", # Only used in test? "get_checkpoint_mtimes", # Only used in test?
# Legacy: remove. # Legacy: remove.
"do_quantize_training_on_graphdef", # At least use grah_def, not graphdef. "do_quantize_training_on_graphdef", # At least use grah_def, not graphdef.
# No uses within tensorflow. # No uses within tensorflow.
"queue_runner", # Use tf.train.start_queue_runner etc directly. "queue_runner", # Use tf.train.start_queue_runner etc directly.
# This is also imported internally. # This is also imported internally.
# TODO(drpng): document these. The reference in howtos/distributed does # TODO(drpng): document these. The reference in howtos/distributed does
# not link. # not link.
"SyncReplicasOptimizer", "SyncReplicasOptimizer",
# Protobufs: # Protobufs:
"BytesList", # from example_pb2. "BytesList", # from example_pb2.
"ClusterDef", "ClusterDef",
"Example", # from example_pb2 "Example", # from example_pb2
"Feature", # from example_pb2 "Feature", # from example_pb2
"Features", # from example_pb2 "Features", # from example_pb2
"FeatureList", # from example_pb2 "FeatureList", # from example_pb2
"FeatureLists", # from example_pb2 "FeatureLists", # from example_pb2
"FloatList", # from example_pb2. "FloatList", # from example_pb2.
"Int64List", # from example_pb2. "Int64List", # from example_pb2.
"JobDef", "JobDef",
"SaverDef", # From saver_pb2. "SaverDef", # From saver_pb2.
"SequenceExample", # from example_pb2. "SequenceExample", # from example_pb2.
"ServerDef", "ServerDef",
] ]
# Include extra modules for docstrings because: # Include extra modules for docstrings because:

View File

@ -6,6 +6,10 @@ tf_class {
name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER"
mtype: "<type \'int\'>" mtype: "<type \'int\'>"
} }
member {
name: "CLUSTER_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member { member {
name: "DESCRIPTOR" name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>" mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"

View File

@ -1,6 +1,6 @@
path: "tensorflow.train.ClusterDef" path: "tensorflow.train.ClusterDef"
tf_class { tf_class {
is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.ClusterDef\'>" is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.ClusterDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>" is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member { member {
name: "DESCRIPTOR" name: "DESCRIPTOR"

View File

@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef.TasksEntry" path: "tensorflow.train.JobDef.TasksEntry"
tf_class { tf_class {
is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.TasksEntry\'>" is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.TasksEntry\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>" is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member { member {
name: "DESCRIPTOR" name: "DESCRIPTOR"

View File

@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef" path: "tensorflow.train.JobDef"
tf_class { tf_class {
is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.JobDef\'>" is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.JobDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>" is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member { member {
name: "DESCRIPTOR" name: "DESCRIPTOR"