diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 93f487c36ca..5e336c5287b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d86e741b69e..362a1018955 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, options, Device::BuildDeviceAttributes( "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type())), - cpu_allocator()), + strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 6fd1ae08149..560e45fc135 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -118,6 +118,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/types.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" + "tensorflow/core/protobuf/cluster.proto" "tensorflow/core/protobuf/config.proto" "tensorflow/core/protobuf/debug.proto" "tensorflow/core/protobuf/rewriter_config.proto" diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt 
b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index c0969e6dee2..2f1fcb149e1 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/meta_graph.pb.cc +tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/debug.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 132b4775962..6087a45168d 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/meta_graph.pb.h +tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/rewriter_config.pb.h diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f1da05e4c6e..c39257ffa91 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -1,6 +1,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc +tensorflow/core/protobuf/cluster.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2a78ea61016..5eadf5d55b6 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ 
b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/meta_graph.proto +tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/rewriter_config.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 435618ace7a..9d0c6a6c3eb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [ "framework/versions.proto", "lib/core/error_codes.proto", "protobuf/config.proto", + "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index 78649afeb93..aa8a2d989bf 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -23,8 +23,7 @@ limitations under the License. namespace tensorflow { -Device::Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator) +Device::Device(Env* env, const DeviceAttributes& device_attributes) : DeviceBase(env), device_attributes_(device_attributes) { CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) << "Invalid device name: " << name(); diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 07c6bdd6831..c0e58f143e3 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -53,8 +53,7 @@ namespace tensorflow { class Device : public DeviceBase { public: - Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator); + Device(Env* env, const DeviceAttributes& device_attributes); ~Device() override; // Full name of this device (see top comment). 
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 7807656cb25..31f12d48337 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) for (Device* d : devices) { devices_.push_back(d); - // Register under both the full name and the local name. + // Register under the (1) full name, (2) canonical name, and (3) local name. string full_name = d->name(); device_map_[CopyToBackingStore(full_name)] = d; + DeviceNameUtils::ParsedName parsed_name = d->parsed_name(); + if (parsed_name.has_job && parsed_name.has_replica && + parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) { + string canonical_name = DeviceNameUtils::FullName( + parsed_name.job, parsed_name.replica, parsed_name.task, + parsed_name.type, parsed_name.id); + device_map_[CopyToBackingStore(canonical_name)] = d; + } string lname = DeviceNameUtils::LocalName(d->name()); device_map_[CopyToBackingStore(lname)] = d; device_type_counts_[d->device_type()]++; @@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) } DeviceMgr::~DeviceMgr() { - for (auto p : devices_) delete p; + // TODO(b/37437134): Remove destructor after converting to std::unique_ptr. 
+ for (Device* p : devices_) delete p; } StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { @@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const { Status s; auto iter = device_map_.find(name); if (iter == device_map_.end()) { + std::vector<StringPiece> device_names; + for (auto&& itr : device_map_) { + device_names.push_back(itr.first); + } + LOG(WARNING) << "Unknown device: " << name + << " all devices: " << str_util::Join(device_names, ", "); return errors::InvalidArgument(name, " unknown device."); } *device = iter->second; diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index bb1ed726408..d16681ac59d 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -36,6 +36,7 @@ class DeviceMgr { public: // Takes ownership of each device in 'devices'. // TODO(zhifengc): Other initialization information. + // TODO(b/37437134): Use std::unique_ptr's to track ownership. explicit DeviceMgr(const std::vector<Device*>& devices); ~DeviceMgr(); @@ -61,6 +62,7 @@ class DeviceMgr { int NumDeviceType(const string& type) const; private: + // TODO(b/37437134): Use std::unique_ptr's to track ownership. typedef gtl::InlinedVector<Device*, 8> DeviceVec; DeviceVec devices_; diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index b0540dfa95b..4cd56e583c0 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -39,7 +39,10 @@ class DeviceSet { // Set the device designated as the "client". This device // must also be registered via AddDevice(). - void set_client_device(Device* device) { client_device_ = device; } + void set_client_device(Device* device) { + DCHECK(client_device_ == nullptr); + client_device_ = device; + } // Returns a pointer to the device designated as the "client". 
Device* client_device() const { return client_device_; } diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index ff20ee94a7d..0507076c8c3 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -27,8 +27,7 @@ namespace { static Device* Dev(const char* type, const char* name) { class FakeDevice : public Device { public: - explicit FakeDevice(const DeviceAttributes& attr) - : Device(nullptr, attr, nullptr) {} + explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 0e2343cfe3f..02f70d835d5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name, int gpu_id, const string& physical_device_desc, Allocator* gpu_allocator, Allocator* cpu_allocator, bool sync_every_op, int32 max_streams) - : LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit, - locality, physical_device_desc), - gpu_allocator), + : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU, + memory_limit, locality, + physical_device_desc)), gpu_allocator_(gpu_allocator), cpu_allocator_(cpu_allocator), gpu_id_(gpu_id), diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index 0a6342ed736..3f7c9f68dba 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo { }; LocalDevice::LocalDevice(const SessionOptions& options, - const DeviceAttributes& 
attributes, - Allocator* device_allocator) - : Device(options.env, attributes, device_allocator), - owned_tp_info_(nullptr) { + const DeviceAttributes& attributes) + : Device(options.env, attributes), owned_tp_info_(nullptr) { // If we're running on the CPU, log warnings if we're not compiled using the // best flags for performance. port::WarnAboutUnusedCPUFeatures(); diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h index d1c27c62481..84a4f66db4a 100644 --- a/tensorflow/core/common_runtime/local_device.h +++ b/tensorflow/core/common_runtime/local_device.h @@ -33,8 +33,8 @@ struct SessionOptions; // GPUDevice into more 'process-wide' abstractions. class LocalDevice : public Device { public: - LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, - Allocator* device_allocator); + LocalDevice(const SessionOptions& options, + const DeviceAttributes& attributes); ~LocalDevice() override; private: diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc new file mode 100644 index 00000000000..fa9713735ed --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.cc @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/renamed_device.h" + +namespace tensorflow { + +// TODO(saeta): Convert to returning a std::unique_ptr? +/* static */ +Device* RenamedDevice::NewRenamedDevice(const string& new_base, + Device* underlying, + bool owns_underlying) { + DeviceNameUtils::ParsedName parsed_name; + CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name)); + DeviceNameUtils::ParsedName underlying_parsed_name = + underlying->parsed_name(); + CHECK(underlying_parsed_name.has_type); + CHECK(underlying_parsed_name.has_id); + parsed_name.type = underlying_parsed_name.type; + parsed_name.id = underlying_parsed_name.id; + string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica, + parsed_name.task, parsed_name.type, + parsed_name.id); + DeviceAttributes attributes(underlying->attributes()); + attributes.set_name(name); + return new RenamedDevice(underlying, attributes, owns_underlying); +} + +RenamedDevice::RenamedDevice(Device* underlying, + const DeviceAttributes& attributes, + bool owns_underlying) + : Device(underlying->env(), attributes), + underlying_(underlying), + owns_underlying_(owns_underlying) {} + +RenamedDevice::~RenamedDevice() { + if (owns_underlying_) { + delete underlying_; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h new file mode 100644 index 00000000000..0158e18cedc --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -0,0 +1,119 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Wraps a device with a new name, delegating work to the wrapped device. +// +// This class is used to wrap local devices when using clusterspec propagation +// where the name of a particular device may change in the context of a given +// session. 
+class RenamedDevice : public Device { + public: + static Device* NewRenamedDevice(const string& new_base, Device* underlying, + bool owns_underlying); + ~RenamedDevice() override; + + // Below are virtual methods defined on DeviceBase + bool RequiresRecordingAccessedTensors() const override { + return underlying_->RequiresRecordingAccessedTensors(); + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { + return underlying_->tensorflow_cpu_worker_threads(); + } + + const GpuDeviceInfo* tensorflow_gpu_device_info() const override { + return underlying_->tensorflow_gpu_device_info(); + } + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return underlying_->GetAllocator(attr); + } + + Allocator* GetStepAllocator(AllocatorAttributes attr, + ResourceMgr* step_resource_manager) override { + return underlying_->GetStepAllocator(attr, step_resource_manager); + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { + return underlying_->eigen_cpu_device(); + } + +#ifdef TENSORFLOW_USE_SYCL + const Eigen::SyclDevice* eigen_sycl_device() const override { + return underlying_->eigen_sycl_device(); + } +#endif + + PerOpGpuDevice* MakeGpuDevice() override { + return underlying_->MakeGpuDevice(); + } + + void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, Allocator* allocator) override { + underlying_->ReinitializeGpuDevice(context, device, dc, allocator); + } + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { + return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); + } + + // Below are virtual methods defined on Device + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + underlying_->Compute(op_kernel, context); + } + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override { + 
+ underlying_->ComputeAsync(op_kernel, context, std::move(done)); + } + + void ConsumeListOfAccessedTensors( + DeviceContext* context, const TensorReferenceVector& tensors) override { + underlying_->ConsumeListOfAccessedTensors(context, tensors); + } + + Status Sync() override { return underlying_->Sync(); } + + Status MaybeRewriteGraph(const FunctionDefLibrary& library, + std::unique_ptr<Graph>* graph) override { + return underlying_->MaybeRewriteGraph(library, graph); + } + + Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) override { + return underlying_->FillContextMap(graph, device_context_map); + } + + private: + RenamedDevice(Device* underlying, const DeviceAttributes& attributes, + bool owns_underlying); + Device* const underlying_; + const bool owns_underlying_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index bd84417b105..24f27af5f1a 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -66,7 +66,7 @@ class DummyOp : public OpKernel { class FakeDevice : public Device { private: explicit FakeDevice(const DeviceAttributes& device_attributes) - : Device(nullptr, device_attributes, nullptr) {} + : Device(nullptr, device_attributes) {} public: Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 60348e885f5..f5f8aab6946 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, const string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator) - : 
LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit, - locality), - allocator), + : LocalDevice(options, Device::BuildDeviceAttributes( + name, DEVICE_CPU, memory_limit, locality)), allocator_(allocator) {} ThreadPoolDevice::~ThreadPoolDevice() {} diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 0f5eb0cb320..d2a828f39f2 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -77,7 +77,6 @@ cc_library( ], deps = [ ":graph_mgr", - ":rendezvous_mgr_interface", ":worker_cache", "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", @@ -92,9 +91,9 @@ cc_library( deps = [ ":graph_mgr", ":worker_session", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", ], ) @@ -237,6 +236,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 5863727f19b..e68aea46ecd 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -35,9 +35,8 @@ limitations under the License. 
namespace tensorflow { -BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name) - : worker_env_(worker_env), worker_name_(worker_name) {} +BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env) + : worker_env_(worker_env) {} BaseRendezvousMgr::~BaseRendezvousMgr() { for (auto& p : table_) { @@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() { } } -Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { +RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) { return FindOrCreate(step_id); } @@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { mutex_lock l(mu_); Table::iterator iter = table_.find(step_id); if (iter == table_.end()) { - auto rr = Create(step_id, worker_env_, worker_name_); + auto rr = Create(step_id, worker_env_); iter = table_.insert({step_id, rr}).first; } iter->second->Ref(); @@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() { } } -BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, - const string& worker_name, - int64 step_id, +BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, bool tolerate_dup_recv) : env_(env), - worker_name_(worker_name), step_id_(step_id), - local_(NewLocalRendezvous(tolerate_dup_recv)) {} + local_(NewLocalRendezvous(tolerate_dup_recv)), + session_(nullptr) {} BaseRemoteRendezvous::~BaseRemoteRendezvous() { CHECK(active_.empty()); @@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name, return device_name.starts_with(worker_name); } +Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { + CHECK_NE(session, nullptr) << "session must not be null!"; + std::vector<DeferredCall> deferred_calls; + { + mutex_lock l(mu_); + if (session_ != nullptr) { + if (session_->worker_name == session->worker_name) { + LOG(INFO) << "Skipping rendezvous re-initialization."; + return Status::OK(); + } + Status s = errors::Internal( + "Double init! 
Worker names would have changed from: ", + session_->worker_name, " -> ", session->worker_name); + LOG(WARNING) << s; + return s; + } + session_ = session; + std::swap(deferred_calls, deferred_calls_); + } + for (DeferredCall& call : deferred_calls) { + RecvLocalAsyncInternal(call.parsed, std::move(call.done)); + } + return Status::OK(); +} + +WorkerSession* BaseRemoteRendezvous::session() { + mutex_lock l(mu_); + return session_; +} + +bool BaseRemoteRendezvous::is_initialized() { + mutex_lock l(mu_); + return is_initialized_locked(); +} + Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { @@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, { mutex_lock l(mu_); if (!status_.ok()) return status_; - } - if (!IsLocalDevice(worker_name_, parsed.src_device)) { - return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + DCHECK(is_initialized_locked()); + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { + return errors::InvalidArgument( + "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", + session_->worker_name); + } } // Buffers "val" and "device_context" in local_. return local_->Send(parsed, args, val, is_dead); @@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, bool is_src) { + // Cache session pointer to avoid repeatedly taking & releasing the lock + // (e.g. 
calling session()) + WorkerSession* sess = nullptr; { mutex_lock l(mu_); if (!status_.ok()) return status_; + if (!is_initialized_locked()) { + return errors::Internal("ValidateDevices called before initialization."); + } + sess = session_; } - if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) { + if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } - if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) { + if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { return errors::InvalidArgument("Invalid rendezvous key (dst): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } return Status::OK(); } @@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); + CHECK(is_initialized()) << "RecvAsync called when uninitialized."; Status s = ValidateDevices(parsed, false /*!is_src*/); if (!s.ok()) { done(s, Args(), recv_args, Tensor(), false); @@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { + { + mutex_lock l(mu_); + if (!is_initialized_locked()) { + // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a + // remote worker) before the RunStep (or PartialRunStep) RPC from the + // master arrives. RecvLocalAsync thus buffers the arguments until after + // the RemoteRendezvous is Initialize()'d, when it completes the + // rendezvous logic. At some point after Initialize() is called, a Tensor + // is produced locally that will then be sent in response to the incoming + // RPC. 
+ DeferredCall call(parsed, std::move(done)); + deferred_calls_.push_back(call); + return; + } + } + RecvLocalAsyncInternal(parsed, std::move(done)); +} + +void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, + DoneCallback done) { Status s = ValidateDevices(parsed, true /* is_src */); if (!s.ok()) { done(s, Args(), Args(), Tensor(), false); @@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) { active_.erase(call); } +BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed, + DoneCallback done) + : parsed(parsed), done(std::move(done)) {} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index 447a75913d6..b252f45fe96 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -59,15 +59,17 @@ class BaseRecvTensorCall; // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). class BaseRendezvousMgr : public RendezvousMgrInterface { public: - explicit BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name); + explicit BaseRendezvousMgr(const WorkerEnv* worker_env); ~BaseRendezvousMgr() override; // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - Rendezvous* Find(int64 step_id) override; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + RemoteRendezvous* Find(int64 step_id) override; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. 
@@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { protected: virtual BaseRemoteRendezvous* Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) = 0; + const WorkerEnv* worker_env) = 0; private: // Maps step_id to rendezvous. @@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Not owned. const WorkerEnv* const worker_env_; - const string worker_name_; mutex mu_; Table table_ GUARDED_BY(mu_); @@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Buffering of Tensor values is delegated to a "local" Rendezvous // obtained from NewLocalRendezvous(). This class just adds // functionality to coordinate with remote workers. -class BaseRemoteRendezvous : public Rendezvous { +class BaseRemoteRendezvous : public RemoteRendezvous { public: - BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - int64 step_id, bool tolerate_dup_recv); + BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, + bool tolerate_dup_recv); + + // Upgrades the BaseRemoteRendezvous to full initialization. + Status Initialize(WorkerSession* session) override; // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. @@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous { // Removes "call" from active_ if "call" is in active_. void DeregisterCall(BaseRecvTensorCall* call); + WorkerSession* session(); + + bool is_initialized(); + ~BaseRemoteRendezvous() override; const WorkerEnv* const env_; // Not owned. - const string worker_name_; const int64 step_id_; private: @@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous { // Status given by StartAbort() if any. Status status_ GUARDED_BY(mu_); + WorkerSession* session_ GUARDED_BY(mu_); // Not owned. + + // Data structures to handle calls when partially initialized. 
+ struct DeferredCall { + const ParsedKey parsed; + DoneCallback done; + + DeferredCall(const ParsedKey& parsed, DoneCallback done); + }; + std::vector deferred_calls_ GUARDED_BY(mu_); // Active outstanding RecvTensor calls. gtl::FlatSet active_ GUARDED_BY(mu_); + bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return session_ != nullptr; + } + // If "is_src" is true, checks that the rendezvous key "parsed"'s // source is in this process. If "is_src" is false, checks that the // rendezvous key "parsed"'s destination is in this process. @@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous { const Rendezvous::Args& out_args, const Tensor& in, Tensor* out, StatusCallback done); + // Must be called only if fully initialized. + void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); + TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); }; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index ce7ce372e85..5bde771e8de 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -46,10 +46,8 @@ limitations under the License. namespace tensorflow { -GraphMgr::GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr) - : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) { - CHECK(rendezvous_mgr) << "Rendezvous mgr was null"; +GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) + : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { // The default value of sync_on_finish will be flipped soon and this // environment variable will be removed as well. 
Status status = @@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, }; popts.get_incarnation = [this](const string& name) -> int64 { Device* device = nullptr; - Status s = worker_env_->device_mgr->LookupDevice(name, &device); + Status s = device_mgr_->LookupDevice(name, &device); if (s.ok()) { return device->attributes().incarnation(); } else { @@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, ExecutionUnit* unit = &(item->units.back()); // Find the device. - Status s = - worker_env_->device_mgr->LookupDevice(device_name, &unit->device); + Status s = device_mgr_->LookupDevice(device_name, &unit->device); if (!s.ok()) { // Remove the empty unit from the item as the item destructor wants all // units to have valid devices. @@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, // Function library runtime. unit->lib = NewFunctionLibraryRuntime( - worker_env_->device_mgr, worker_env_->env, unit->device, + device_mgr_, worker_env_->env, unit->device, subgraph->versions().producer(), item->lib_def, graph_options.optimizer_options()); @@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, } Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = SendInputsToRendezvous(rendezvous, in); rendezvous->Unref(); return s; } Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = RecvOutputsFromRendezvous(rendezvous, out); rendezvous->Unref(); return s; @@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, 
StatusCallback done) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); RecvOutputsFromRendezvousAsync(rendezvous, out, [done, rendezvous](const Status s) { rendezvous->Unref(); @@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, + WorkerSession* session, + const ExecutorOpts& /*opts*/, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, @@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, return; } - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); + Status s = rendezvous->Initialize(session); // Sends values specified by the caller. - Status s = SendInputsToRendezvous(rendezvous, in); + if (s.ok()) { + s = SendInputsToRendezvous(rendezvous, in); + } + if (!s.ok()) { done(s); item->Unref(); @@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, StatusCallback done) { const int num_units = item->units.size(); CHECK_GE(num_units, 1); - ScopedStepContainer* step_container = - new ScopedStepContainer(step_id, [this](const string& name) { - worker_env_->device_mgr->ClearContainers({name}); - }); + ScopedStepContainer* step_container = new ScopedStepContainer( + step_id, + [this](const string& name) { device_mgr_->ClearContainers({name}); }); // NOTE: Transfer one ref of rendezvous and item. 
ExecutorBarrier* barrier = new ExecutorBarrier(num_units, rendezvous, diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 349af6c54e5..50391f47e4d 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -37,6 +37,8 @@ namespace tensorflow { class ExecutorOpts; class StepStatsCollector; class RendezvousMgrInterface; +class DeviceMgr; +struct WorkerSession; // GraphMgr keeps track of a set of graphs that are registered with a // TensorFlow worker. Each registered graph is identified by a handle @@ -62,8 +64,7 @@ class RendezvousMgrInterface; // EXPECT_EQ(out["c"], Tensor({4, 6})); class GraphMgr { public: - explicit GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr); + explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); ~GraphMgr(); // Registers a graph. Fills in "handle" @@ -78,8 +79,8 @@ class GraphMgr { typedef std::map NamedTensors; typedef std::function StatusCallback; void ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, StepStatsCollector* collector, - CostGraphDef* cost_graph, + WorkerSession* session, const ExecutorOpts& opts, + StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, const NamedTensors& in, StatusCallback done); @@ -131,7 +132,7 @@ class GraphMgr { }; const WorkerEnv* worker_env_; // Not owned. - RendezvousMgrInterface* rendezvous_mgr_; // Not owned. + DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index b4adee3bf6c..e860c99d953 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" @@ -48,12 +49,17 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +namespace { +const char* const kGrpcProtocol = "grpc://"; +} // namespace + Master::Master(MasterEnv* env, double session_gc_seconds) : env_(env), last_1000_steps_(1000), @@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done) { SchedClosure([this, req, resp, done]() { Status status; + WorkerCacheFactoryOptions worker_cache_factory_options; + string grpc_protocol("grpc"); + worker_cache_factory_options.protocol = &grpc_protocol; auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); status = ValidateExternalGraphDefSyntax(req->graph_def()); if (!status.ok()) return; - // Ping all the workers and build the list of devices that the - // session will use. + + // The following 4 variables are set differently, depending on whether this + // session uses a client-provided clusterspec or not. + WorkerCacheInterface* worker_cache = nullptr; + // Note: worker_cache_ptr will be null except if this session is using a + // client-supplied ClusterDef (ClusterSpec propagation). + std::unique_ptr worker_cache_ptr; + std::unique_ptr device_set; // TODO(saeta): Convert to std::make_unique when available. 
std::unique_ptr>> remote_devices( new std::vector>()); - status = DeviceFinder::GetRemoteDevices(req->config().device_filters(), - env_, env_->worker_cache, - remote_devices.get()); - if (!status.ok()) return; + + if (req->config().has_cluster_def()) { + worker_cache_factory_options.cluster_def = &req->config().cluster_def(); + + // Set the server_def's job_name and task_index fields. + string normalized_string; + string grpc_protocol(kGrpcProtocol); + if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) == + 0) { + normalized_string = + req->target().substr(grpc_protocol.length(), string::npos); + } else { + normalized_string = req->target(); + } + for (auto&& job : req->config().cluster_def().job()) { + for (auto&& task : job.tasks()) { + if (task.second == normalized_string) { + if (worker_cache_factory_options.job_name != nullptr) { + status = errors::InvalidArgument( + "Found multiple matching tasks that correspond to " + "to the master. Master target: '", + req->target(), "'. ClusterDef: ", + req->config().cluster_def().ShortDebugString()); + LOG(ERROR) << status; + return; + } + if (env_->local_devices[0]->parsed_name().job == job.name() && + env_->local_devices[0]->parsed_name().task == task.first) { + // TODO(b/37868888): Remove this limitation when resolved + status = errors::InvalidArgument( + "The ClusterSpec names the job and task index to be the same " + "names that were provided when the server booted. This is " + "currently not allowed. Job: ", + job.name(), ", task index: ", task.first); + return; + } + worker_cache_factory_options.job_name = &job.name(); + worker_cache_factory_options.task_index = task.first; + } + } + } + + // Create the worker cache from the computed server_def. + status = env_->worker_cache_factory(worker_cache_factory_options, + &worker_cache); + if (!status.ok()) return; + worker_cache_ptr = std::unique_ptr(worker_cache); + // Ping all the workers and build the list of devices that the + // session will use. 
+ status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + DeviceNameUtils::ParsedName name = d->parsed_name(); + if (name.job == *worker_cache_factory_options.job_name && + name.task == worker_cache_factory_options.task_index && + name.type == "CPU") { + device_set->set_client_device(d.get()); + } + } + } else { + worker_cache = env_->worker_cache; + // Ping all the workers and build the list of devices that the + // session will use. + status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + } + int num_local_devices = 0; + for (Device* d : env_->local_devices) { + device_set->AddDevice(d); + if (num_local_devices == 0) { + // Uses the first local device as the client device. 
+ device_set->set_client_device(d); + } + num_local_devices++; + } + } + + CHECK(device_set->client_device()); + SessionOptions options; options.config = req->config(); - MasterSession* session = - env_->master_session_factory(options, env_, std::move(remote_devices)); + + MasterSession* session = env_->master_session_factory( + options, env_, std::move(remote_devices), std::move(worker_cache_ptr), + std::move(device_set)); + GraphDef* gdef = const_cast(req)->mutable_graph_def(); - status = session->Create(gdef); + + status = session->Create(gdef, worker_cache_factory_options); if (!status.ok()) { session->Close().IgnoreError(); session->Unref(); diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index a155bd384d8..bb548adda15 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -19,17 +19,41 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { class Device; +class DeviceSet; class Env; class MasterSession; class OpRegistryInterface; class WorkerCacheInterface; +// Options passed to the worker_cache_factory function. +struct WorkerCacheFactoryOptions { + const ClusterDef* cluster_def = nullptr; + const string* job_name = nullptr; + int task_index; + const string* protocol = nullptr; + + WorkerCacheFactoryOptions() {} + + // Construct from a ServerDef proto. + // + // Note: server_def must outlive WorkerCacheFactoryOptions! 
+ WorkerCacheFactoryOptions(const ServerDef& server_def) { + if (server_def.has_cluster() && !server_def.job_name().empty()) { + cluster_def = &server_def.cluster(); + job_name = &server_def.job_name(); + task_index = server_def.task_index(); + protocol = &server_def.protocol(); + } + } +}; + // The master environment class, which holds a bag of pointers to // per-master state. // @@ -57,8 +81,14 @@ struct MasterEnv { // `MasterEnv*` is retained by the caller. std::function>>)> + std::unique_ptr>>, + std::unique_ptr, + std::unique_ptr device_set)> master_session_factory; + + std::function + worker_cache_factory; }; } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5257aea1e3a..50c5d90fc98 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -36,11 +36,13 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -528,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( c->req->set_is_partial(is_partial_); c->req->set_is_last_partial_run(is_last_partial_run); } + c->req->set_session_handle(session_handle_); c->req->set_graph_handle(part.graph_handle); c->req->set_step_id(step_id); *c->req->mutable_exec_opts() = exec_opts; @@ -871,6 +874,7 @@ void 
MasterSession::ReffedClientGraph::DeregisterPartitions() { // The graph handle may be empty if we failed during partition registration. if (!part.graph_handle.empty()) { Call* c = new Call; + c->req.set_session_handle(session_handle_); c->req.set_graph_handle(part.graph_handle); // NOTE(mrry): We must capture `worker_cache_` since `this` // could be deleted before the callback is called. @@ -973,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) { MasterSession::MasterSession( const SessionOptions& opt, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory) : session_opts_(opt), env_(env), handle_(strings::FpToString(random::New64())), remote_devs_(std::move(remote_devs)), + worker_cache_(std::move(worker_cache)), + devices_(std::move(device_set)), stats_publisher_factory_(std::move(stats_publisher_factory)), graph_version_(0), run_graphs_(5), partial_run_graphs_(5) { UpdateLastAccessTime(); + CHECK(devices_) << "device_set was null!"; VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() << " #remote " << remote_devs_->size(); - for (auto&& d : *remote_devs_) { - devices_.AddDevice(d.get()); - } - int num_local_devices = 0; - for (Device* d : env->local_devices) { - devices_.AddDevice(d); - if (num_local_devices == 0) { - // Uses the first local device as the client device. 
- devices_.set_client_device(d); - } - num_local_devices++; - } + LOG(INFO) << "Start master session " << handle_ << " with config: " << std::endl << session_opts_.config.DebugString(); @@ -1012,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() { last_access_time_usec_.store(Env::Default()->NowMicros()); } -Status MasterSession::Create(GraphDef* graph_def) { +Status MasterSession::Create(GraphDef* graph_def, + const WorkerCacheFactoryOptions& options) { if (session_opts_.config.graph_options().place_pruned_graph()) { // TODO(b/29900832): Fix this or remove the option. LOG(WARNING) << "Distributed session does not support the " @@ -1020,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) { session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); } - SimpleGraphExecutionStateOptions options; - options.device_set = &devices_; - options.session_options = &session_opts_; + SimpleGraphExecutionStateOptions execution_options; + execution_options.device_set = devices_.get(); + execution_options.session_options = &session_opts_; { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( - graph_def, options, &execution_state_)); + graph_def, execution_options, &execution_state_)); + } + if (options.cluster_def != nullptr) { + return CreateWorkerSessions(options); } return Status::OK(); } +Status MasterSession::CreateWorkerSessions( + const WorkerCacheFactoryOptions& options) { + CHECK(worker_cache_) << "CreateWorkerSessions should be called only with " + << "dynamic cluster membership."; + std::vector worker_names; + worker_cache_->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. 
+ CreateWorkerSessionRequest request; + CreateWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. + for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + *workers[i].request.mutable_server_def()->mutable_cluster() = + *options.cluster_def; + workers[i].request.mutable_server_def()->set_protocol(*options.protocol); + + DeviceNameUtils::ParsedName name; + if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) { + status = errors::Internal("Could not parse name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + if (!name.has_job || !name.has_task) { + status = errors::Internal("Incomplete worker name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + + workers[i].request.mutable_server_def()->set_job_name(name.job); + workers[i].request.mutable_server_def()->set_task_index(name.task); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->CreateWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp) { UpdateLastAccessTime(); @@ -1060,6 +1135,13 @@ Status 
MasterSession::Extend(const ExtendSessionRequest* req, return Status::OK(); } +WorkerCacheInterface* MasterSession::get_worker_cache() const { + if (worker_cache_) { + return worker_cache_.get(); + } + return env_->worker_cache; +} + Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** rcg, bool is_partial) { const uint64 hash = HashBuildGraphOptions(opts); @@ -1083,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, << "\n"; std::unique_ptr client_graph; TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); + WorkerCacheInterface* worker_cache = get_worker_cache(); auto entry = new ReffedClientGraph( handle_, opts, std::move(client_graph), session_opts_, stats_publisher_factory_, execution_state_.get(), is_partial, - env_->worker_cache); - + worker_cache); iter = m->insert({hash, entry}).first; VLOG(1) << "Preparing to execute new graph"; } @@ -1162,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, return errors::FailedPrecondition("Session is closed."); } ++num_running_; + // Note: all code paths must eventually call MarkRunCompletion() + // in order to appropriately decrement the num_running_ counter. } Status status; if (!req.partial_run_handle().empty()) { @@ -1169,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, } else { status = DoRunWithLocalExecution(opts, req, resp); } - { - mutex_lock l(mu_); - --num_running_; - if (num_running_ == 0) { - num_running_is_zero_.notify_all(); - } - } return status; } +// Decrements num_running_ and broadcasts if num_running_ is zero. +void MasterSession::MarkRunCompletion() { + mutex_lock l(mu_); + --num_running_; + if (num_running_ == 0) { + num_running_is_zero_.notify_all(); + } +} + Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { // Registers subgraphs if haven't done so. 
PartitionOptions popts; @@ -1188,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { return strings::StrCat(prefix, "_S", next_node_id_++); }; popts.get_incarnation = [this](const string& name) -> int64 { - Device* d = devices_.FindDeviceByName(name); + Device* d = devices_->FindDeviceByName(name); if (d == nullptr) { return PartitionOptions::kIllegalIncarnation; } else { @@ -1223,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); const string& prun_handle = req.partial_run_handle(); RunState* run_state = nullptr; { @@ -1321,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts, rcg->Ref(); rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), req.options(), resp->mutable_metadata()); + cleanup.release(); // MarkRunCompletion called in done closure. rcg->CleanupPartitionsAsync( run_state->step_id, [this, rcg, prun_handle](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); mutex_lock l(mu_); partial_runs_.erase(prun_handle); @@ -1368,10 +1457,10 @@ Status MasterSession::CreateDebuggerState( Status MasterSession::DoRunWithLocalExecution( CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { - VLOG(2) << "DoRunWithLocalExecution " - << "req: " << req.DebugString(); + VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString(); PerStepState pss; pss.start_micros = Env::Default()->NowMicros(); + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); // Prepare. 
BuildGraphOptions bgopts; @@ -1438,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution( } } rcg->Ref(); - rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { + cleanup.release(); // MarkRunCompletion called in done closure. + rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); return s; } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index d47125be992..3acc5bc5f0a 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted { MasterSession( const SessionOptions& options, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory); // Initialize the MasterSession for "def". Must be called before Extend(), // Run(), or Close(). // // After this method returns, `def` will no longer be valid. - Status Create(GraphDef* def); + Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options); // Returns the session handle. const string& handle() const { return handle_; } @@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted { std::unique_ptr>> remote_devs_; + // The optional session-specific worker cluster. + // TODO(saeta): Convert to std::optional when available. 
+ std::unique_ptr worker_cache_; + // Retrieves either worker_cache_ or the env_->worker_cache as appropriate. + WorkerCacheInterface* get_worker_cache() const; + // The device set used by this session. - DeviceSet devices_; + std::unique_ptr devices_; StatsPublisherFactory stats_publisher_factory_; @@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted { // Private dtor. The client must call Close(). virtual ~MasterSession(); + // Creates sessions on all workers. + // + // If this session is operating using the new ClusterSpec propagation behavior + // call this method in order to propagate the cluster membership to all + // workers. + Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, @@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted { MutableRunStepResponseWrapper* resp); Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp); + void MarkRunCompletion(); void UpdateLastAccessTime(); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 7b58feb93cc..b077975ea50 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const { const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } +const string& InMemoryRunGraphRequest::session_handle() const { + return session_handle_; +} + +void InMemoryRunGraphRequest::set_session_handle(const string& handle) { + session_handle_ = handle; +} + const string& InMemoryRunGraphRequest::graph_handle() const { return graph_handle_; } @@ -320,6 +328,7 @@ void 
InMemoryRunGraphRequest::set_is_last_partial_run( const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { if (!proto_version_) { proto_version_.reset(new RunGraphRequest); + proto_version_->set_session_handle(session_handle()); proto_version_->set_graph_handle(graph_handle()); proto_version_->set_step_id(step_id()); *proto_version_->mutable_exec_opts() = exec_opts(); @@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { return *proto_version_; } +const string& MutableProtoRunGraphRequest::session_handle() const { + return request_.session_handle(); +} + +void MutableProtoRunGraphRequest::set_session_handle(const string& handle) { + request_.set_session_handle(handle); +} + const string& MutableProtoRunGraphRequest::graph_handle() const { return request_.graph_handle(); } @@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const { ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) : request_(request) {} +const string& ProtoRunGraphRequest::session_handle() const { + return request_->session_handle(); +} + const string& ProtoRunGraphRequest::graph_handle() const { return request_->graph_handle(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 02516eabb4a..795a6add0e7 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -223,6 +223,10 @@ class RunGraphRequestWrapper { public: virtual ~RunGraphRequestWrapper() {} + // The session handle used to register the graph. If empty, a single global + // namespace is used. + virtual const string& session_handle() const = 0; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. 
virtual const string& graph_handle() const = 0; @@ -262,6 +266,7 @@ class RunGraphRequestWrapper { // See `RunGraphRequestWrapper` above for a description of the fields. class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { public: + virtual void set_session_handle(const string& handle) = 0; virtual void set_graph_handle(const string& handle) = 0; virtual void set_step_id(int64 step_id) = 0; virtual ExecutorOpts* mutable_exec_opts() = 0; @@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { void set_is_last_partial_run(bool is_last_partial_run) override; private: + string session_handle_; string graph_handle_; int64 step_id_; ExecutorOpts exec_opts_; @@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. 
+ const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { ProtoRunGraphRequest(const RunGraphRequest* request); // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index 9632e9c4398..91c1fb99fef 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -16,11 +16,13 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/remote_device.h" #include + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) { class RemoteDevice : public Device { public: RemoteDevice(Env* env, const DeviceAttributes& da) - : Device(env, da, nullptr), - local_dev_name_(GetLocalDeviceName(da.name())) {} + : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, GetStatusResponse resp; }; Call* call = new Call; - auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { + auto cb = [env, worker_cache, worker_name, done, wi, + call](const Status& status) { + Status s = status; std::vector remote_devices; + auto cleanup = gtl::MakeCleanup( + [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] { + worker_cache->ReleaseWorker(worker_name, wi); + done(s, &remote_devices); + delete call; + }); if (s.ok()) { + DeviceNameUtils::ParsedName worker_name_parsed; + if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) || + !worker_name_parsed.has_job || !worker_name_parsed.has_replica || + !worker_name_parsed.has_task) { + s = errors::InvalidArgument("Could not parse worker name: ", + worker_name); + LOG(WARNING) << s; + return; + } remote_devices.reserve(call->resp.device_attributes_size()); for (const DeviceAttributes& da : 
call->resp.device_attributes()) { - auto d = new RemoteDevice(env, da); - remote_devices.push_back(d); + DeviceNameUtils::ParsedName device_name_parsed; + CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed)) + << "Device attribute name '" << da.name() << "' could not be " + << "parsed. Device Attribute: " << da.DebugString(); + // Preserve the exact name, if possible. + // TODO(b/37868888): Simplify when legacy device name formats removed. + if (device_name_parsed.job == worker_name_parsed.job && + device_name_parsed.replica == worker_name_parsed.replica && + device_name_parsed.task == worker_name_parsed.task) { + auto d = new RemoteDevice(env, da); + remote_devices.push_back(d); + } else { + DeviceAttributes da_rewritten = da; + da_rewritten.set_name(DeviceNameUtils::FullName( + worker_name_parsed.job, worker_name_parsed.replica, + worker_name_parsed.task, device_name_parsed.type, + device_name_parsed.id)); + auto d = new RemoteDevice(env, da_rewritten); + remote_devices.push_back(d); + } } } - worker_cache->ReleaseWorker(worker_name, wi); - done(s, &remote_devices); - delete call; }; wi->GetStatusAsync(&call->req, &call->resp, cb); } diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index 04c1fc248ef..43267d4362f 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -25,6 +25,23 @@ limitations under the License. namespace tensorflow { +struct WorkerSession; + +// RemoteRendezvous follow a 2-part initialization. First the objects are +// constructed. Eventually, they will be initialized. Clients of the +// RendezvousMgrInterface must guarantee to call Initialize on the returned +// RemoteRendezvous eventually. +// +// Partially initialized RemoteRendezvous must respect the Rendezvous interface +// (i.e. 
Send() must never block), however implementations are not expected to +// actually perform the underlying operations until after the RemoteRendezvous +// has been Initialize'd. +class RemoteRendezvous : public Rendezvous { + public: + // Fully construct the RemoteRendezvous. + virtual Status Initialize(WorkerSession* session) = 0; +}; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. Each global unique "step_id" @@ -51,7 +68,10 @@ class RendezvousMgrInterface { // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - virtual Rendezvous* Find(int64 step_id) = 0; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + virtual RemoteRendezvous* Find(int64 step_id) = 0; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 7160962b168..3867dd1f4d0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { }; // static utility function -RendezvousMgrInterface* NewRpcRendezvousMgr( - const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache) { - return new RpcRendezvousMgr(env, worker_name, worker_cache); +RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { + return new RpcRendezvousMgr(env); } } // namespace @@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() { // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. 
+ // Shut down all outstanding rendezvous. + delete worker_env_.rendezvous_mgr; + // We must delete graph_mgr before device_mgr, due to shared // ownership of OpKernels in the executors. (The graph_mgr will // free all stateless OpKernels, and pass over borrowed stateful @@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() { // OpSegments.) if (worker_env_.session_mgr != nullptr) { delete worker_env_.session_mgr; // Deletes graph_mgr's. + } else { + // Note: session_mgr's legacy_session_ deletes device_mgr now. + delete worker_env_.device_mgr; } - delete worker_env_.device_mgr; // Do not delete (as these are not owned by the server): // - master_env_.env @@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() { // - worker_env_.compute_pool } -Status GrpcServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendevous_mgr_func) { +Status GrpcServer::Init( + ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, "/task:", server_def_.task_index()); TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, &master_env_.local_devices)); - worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices); + worker_env_.local_devices = master_env_.local_devices; + worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices); + worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr + ? 
new RpcRendezvousMgr(&worker_env_) + : rendezvous_mgr_func(&worker_env_); string unused; string default_worker_name; if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), @@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func, } WorkerCacheInterface* worker_cache; - TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache)); + WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); + TF_RETURN_IF_ERROR( + WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); CHECK_NE(nullptr, worker_cache); // Set up worker environment. - std::unique_ptr rendezvous_mgr( - rendevous_mgr_func == nullptr ? - new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) : - rendevous_mgr_func(&worker_env_, name_prefix, worker_cache)); worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { - return WorkerCacheFactory(server_def, worker_cache); + WorkerCacheFactoryOptions options(server_def); + return WorkerCacheFactory(options, worker_cache); }); worker_env_.compute_pool = ComputePool(sess_opts); @@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func, master_env_.master_session_factory = [config]( SessionOptions options, const MasterEnv* env, - std::unique_ptr>> remote_devs) { + std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set) { options.config.MergeFrom(config); return new MasterSession(options, env, std::move(remote_devs), + std::move(worker_cache), std::move(device_set), CreateNoOpStatsPublisher); }; + master_env_.worker_cache_factory = + [this](const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache) { + return WorkerCacheFactory(options, worker_cache); + }; // Provide direct access to the master from in-process clients. 
LocalMaster::Register(target(), master_impl_.get(), @@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, return Status::OK(); } -Status GrpcServer::Init() { - return Init(nullptr, nullptr); -} +Status GrpcServer::Init() { return Init(nullptr, nullptr); } -Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, +Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { - for (const auto& job : server_def.cluster().job()) { + for (const auto& job : options.cluster_def->job()) { std::map host_ports; for (const auto& task : job.tasks()) { string& host_port = host_ports[task.first]; @@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, task.first, "\": ", host_port, " and ", task.second); } - if (job.name() == server_def.job_name() && - task.first == server_def.task_index()) { + if (job.name() == *options.job_name && task.first == options.task_index) { host_port = strings::StrCat("localhost:", bound_port_); } else { host_port = task.second; @@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, return Status::OK(); } -Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, +Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache) { - string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", - "/task:", server_def.task_index()); + if (options.job_name == nullptr || options.job_name->empty()) { + Status s = errors::InvalidArgument( + "The master (current machine) is not included in the provided " + "cluster_def. 
", + options.cluster_def->DebugString()); + LOG(WARNING) << s; + return s; + } GrpcChannelSpec channel_spec; - TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); + TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); + + std::unique_ptr channel_cache( + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); + + string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0", + "/task:", options.task_index); - std::unique_ptr channel_cache(NewGrpcChannelCache( - channel_spec, GetChannelCreationFunction(server_def))); const string host_port = channel_cache->TranslateTask(name_prefix); int requested_port; @@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials( return ::grpc::InsecureServerCredentials(); } -ChannelCreationFunction GrpcServer::GetChannelCreationFunction( - const ServerDef& server_def) const { +ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const { // We can do this because SparseGrpcChannelCache is robust to nullptr being // returned by the channel creation function return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 3b66291a9ab..7b54bb84c88 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -37,9 +37,7 @@ class GrpcWorker; class Master; // function that creates a RendezvousMgr. -typedef std::function +typedef std::function RendezvousMgrCreationFunction; // function that registers a service to the server. 
The service needs to @@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface { protected: Status Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func); + const RendezvousMgrCreationFunction& rendezvous_mgr_func); Status Init(); @@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface { virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( const ServerDef& server_def) const; - virtual ChannelCreationFunction GetChannelCreationFunction( - const ServerDef& server_def) const; + virtual ChannelCreationFunction GetChannelCreationFunction() const; virtual std::unique_ptr CreateMaster(MasterEnv* master_env); // Creates a WorkerCacheInterface for a session. - Status WorkerCacheFactory(const ServerDef& server_def, + Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache); - // Parses a ServerDef into a GrpcChannelSpec. - Status ParseChannelSpec(const ServerDef& server_def, + // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. + Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec); // Returns the port to which this server is bound. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 1aacef8a26a..38d59d5bb59 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix); /* static */ Status GrpcSession::Create(const SessionOptions& options, std::unique_ptr* out_session) { - std::unique_ptr ret(new GrpcSession(options)); + std::unique_ptr session(new GrpcSession(options)); std::unique_ptr master; // For testing, we enable the client to disable the use of the local // master registry, so that the RPC stack is exercised. 
@@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options, options.target.substr(kSchemePrefixLength), &master_channel)); master.reset(NewGrpcMaster(master_channel)); } - ret->SetRemoteMaster(std::move(master)); - *out_session = std::move(ret); + session->SetRemoteMaster(std::move(master)); + *out_session = std::move(session); return Status::OK(); } @@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options, CreateSessionRequest req; *req.mutable_config() = options_.config; *req.mutable_graph_def() = graph; + req.set_target(options_.target); ReEncodeConsts(req.mutable_graph_def()); CreateSessionResponse resp; Status s = master_->CreateSession(call_options, &req, &resp); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index c11266587d8..873ef8588f4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // completes, and we may decide to bound some of the request // types. 
ENQUEUE_REQUEST(GetStatus, false); + ENQUEUE_REQUEST(CreateWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(GetStatus, false); } + void CreateWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->CreateWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CreateWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { @@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, ::grpc::ByteBuffer* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; @@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, // of execution of the callback lambda body below, an RPC // cancellation should abort the rendezvous. 
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); - session->rendezvous_mgr->RecvLocalAsync( + env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, [opts, response, done, src_dev](const Status& status, const Rendezvous::Args& send_args, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 7518a289fdb..8265100061e 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -38,9 +38,8 @@ namespace { class RpcRemoteRendezvous : public BaseRemoteRendezvous { public: - RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* cache, int64 step_id) - : BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {} + RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) + : BaseRemoteRendezvous(env, step_id, false) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { private: ~RpcRemoteRendezvous() override {} - WorkerCacheInterface* const cache_; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); }; @@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() { return call_freelist; } -// A private cache that wraps worker_cache and allows reuse of -// WorkerInterface objects. 
-class WorkerFreeListCache : public WorkerCacheInterface { - public: - explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {} - - ~WorkerFreeListCache() { - for (auto p : workers_) { - wrapped_->ReleaseWorker(p.first, p.second.worker); - } - } - - void ListWorkers(std::vector* workers) const override { - wrapped_->ListWorkers(workers); - } - - WorkerInterface* CreateWorker(const string& target) override { - mutex_lock l(mu_); - auto p = workers_.find(target); - if (p != workers_.end()) { - return p->second.worker; - } - WorkerState state; - state.worker = wrapped_->CreateWorker(target); - if (state.worker != nullptr) { - workers_.insert(std::make_pair(target, state)); - } - return state.worker; - } - - void ReleaseWorker(const string& target, WorkerInterface* worker) override { - // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. - } - - bool GetDeviceLocalityNonBlocking(const string& device, - DeviceLocality* locality) override { - return wrapped_->GetDeviceLocalityNonBlocking(device, locality); - } - - void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, - StatusCallback done) override { - wrapped_->GetDeviceLocalityAsync(device, locality, done); - } - - void SetLogging(bool active) override { wrapped_->SetLogging(active); } - - void ClearLogs() override { wrapped_->ClearLogs(); } - - bool RetrieveLogs(int64 step_id, StepStats* ss) override { - return wrapped_->RetrieveLogs(step_id, ss); - } - - private: - WorkerCacheInterface* wrapped_; - - // Information kept per created WorkerInterface. - struct WorkerState { - WorkerInterface* worker; - // TODO(jeff,sanjay): Add reference count if we support eviction. - }; - - // TODO(jeff,sanjay): Eviction when the map becomes too big. 
- mutex mu_; - std::unordered_map workers_ GUARDED_BY(mu_); -}; - void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { + CHECK(is_initialized()); Status s; // Prepare a RecvTensor call that can handle being aborted. @@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = errors::Internal(parsed.src_device, " is invalid remote source device."); } - WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); + WorkerSession* sess = session(); + WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", call->src_worker_); } Device* dst_device; if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { - get_call_freelist()->Release(call, cache_); + if (rwi != nullptr) { + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); + } + get_call_freelist()->Release(call, sess->worker_cache.get()); done(s, Args(), recv_args, Tensor{}, false); return; } @@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // current status should be bad. 
Status s = call->status(); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); - cache_->ReleaseWorker(call->src_worker_, call->wi_); + session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_); call->wi_ = nullptr; - get_call_freelist()->Release(call, cache_); + get_call_freelist()->Release(call, session()->worker_cache.get()); Unref(); }); } } // namespace -RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env, - const string& worker_name, - WorkerCacheInterface* worker_cache) - : BaseRendezvousMgr(env, worker_name), - cache_(new WorkerFreeListCache(worker_cache)) {} +RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env) {} BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) { - return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(), - step_id); + const WorkerEnv* worker_env) { + return new RpcRemoteRendezvous(worker_env, step_id); } } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h index 75dc62d98fd..34c48a79177 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { +class DeviceMgr; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. 
Each global unique "step_id" @@ -44,17 +44,12 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RpcRendezvousMgr : public BaseRendezvousMgr { public: - explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache); + explicit RpcRendezvousMgr(const WorkerEnv* env); protected: - BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& session_name) override; + BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env); private: - // Private cache_ that allows us to reuse WorkerInterface objects. - std::unique_ptr cache_; - TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); }; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 9b778eab3a5..2d0d76623d4 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test { : cache_(new DummyWorkerCache), worker_session_("/job:mnist/replica:1/task:2", std::unique_ptr(cache_), - std::unique_ptr(), + std::unique_ptr(), std::unique_ptr()), - rmgr_(&env, worker_session_.worker_name, cache_) { + rmgr_(&env) { env.env = Env::Default(); } @@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { // Explicit Abort(). 
const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, rendez]() { env.env->SleepForMicroseconds(100 * 1000); @@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } { // Cleanup causes Abort(). const int64 step_id = 321; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, step_id]() { env.env->SleepForMicroseconds(100 * 1000); @@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } } @@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); Rendezvous::Args args; args.device_context = dc; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); } { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc 
b/tensorflow/core/distributed_runtime/session_mgr.cc index e2be62f816c..22551d54821 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -26,23 +27,12 @@ namespace tensorflow { SessionMgr::SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory) - : SessionMgr( - worker_env, default_worker_name, std::move(default_worker_cache), - default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {} - -SessionMgr::SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory) : worker_env_(worker_env), - legacy_session_( - default_worker_name, std::move(default_worker_cache), - std::unique_ptr(default_rendezvous_mgr), - std::unique_ptr( - new GraphMgr(worker_env, default_rendezvous_mgr))), + legacy_session_(default_worker_name, std::move(default_worker_cache), + std::unique_ptr(worker_env->device_mgr), + std::unique_ptr( + new GraphMgr(worker_env, worker_env->device_mgr))), worker_cache_factory_(std::move(worker_cache_factory)) {} string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { @@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { Status SessionMgr::CreateSession(const string& session, const ServerDef& server_def) { mutex_lock l(mu_); + if (session.empty()) { + return 
errors::InvalidArgument("Session must be non-empty."); + } + const string worker_name = WorkerNameFromServerDef(server_def); WorkerCacheInterface* worker_cache = nullptr; TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); - std::unique_ptr rendezvous_mgr( - new RpcRendezvousMgr(worker_env_, worker_name, worker_cache)); + std::vector renamed_devices; + for (Device* d : worker_env_->local_devices) { + renamed_devices.push_back( + RenamedDevice::NewRenamedDevice(worker_name, d, false)); + } + std::unique_ptr device_mgr(new DeviceMgr(renamed_devices)); std::unique_ptr graph_mgr( - new GraphMgr(worker_env_, rendezvous_mgr.get())); + new GraphMgr(worker_env_, device_mgr.get())); std::unique_ptr worker_session(new WorkerSession( worker_name, std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), std::move(graph_mgr))); + std::move(device_mgr), std::move(graph_mgr))); sessions_.insert(std::make_pair(session, std::move(worker_session))); return Status::OK(); @@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) { if (it != sessions_.end()) { sessions_.erase(it); } - std::set graph_handles; - for (auto graph_handle_it = sessions_by_graph_handle_.begin(); - graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) { - if (graph_handle_it->second == session) { - graph_handles.insert(graph_handle_it->first); - graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it); - if (graph_handle_it == sessions_by_graph_handle_.end()) break; - } - } - for (auto step_id_it = graphs_by_step_id_.begin(); - step_id_it != graphs_by_step_id_.end(); ++step_id_it) { - if (graph_handles.find(step_id_it->second) != graph_handles.end()) { - step_id_it = graphs_by_step_id_.erase(step_id_it); - if (step_id_it == graphs_by_step_id_.end()) break; - } - } return Status::OK(); } @@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) { WorkerSession* SessionMgr::LegacySession() { return 
&legacy_session_; } -WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked( - const string& graph_handle) { - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it == sessions_by_graph_handle_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForSessionUnlocked(it->second); - } -} - -WorkerSession* SessionMgr::WorkerSessionForGraphHandle( - const string& graph_handle) { - mutex_lock l(mu_); - return WorkerSessionForGraphHandleUnlocked(graph_handle); -} - -WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it == graphs_by_step_id_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForGraphHandleUnlocked(it->second); - } -} - -void SessionMgr::AssociateGraphWithSession(const string& session, - const string& graph_handle) { - mutex_lock l(mu_); - sessions_by_graph_handle_[graph_handle] = session; -} - -void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) { - mutex_lock l(mu_); - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it != sessions_by_graph_handle_.end()) { - sessions_by_graph_handle_.erase(it); - } -} - -void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id) { - mutex_lock l(mu_); - graphs_by_step_id_[step_id] = graph_handle; -} - -void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it != graphs_by_step_id_.end()) { - graphs_by_step_id_.erase(it); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 455b5c8d9d9..c44bca7b7a4 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -30,6 +30,8 @@ struct WorkerEnv; // SessionMgr keeps track of information related to a given session. 
// +// SessionMgr runs on the workers. +// // SessionMgr is threadsafe. class SessionMgr { public: @@ -39,7 +41,6 @@ class SessionMgr { explicit SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory); ~SessionMgr() {} @@ -50,49 +51,36 @@ class SessionMgr { WorkerSession* WorkerSessionForSession(const string& session); WorkerSession* LegacySession(); - // Locates the worker session for a given graph handle - WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle); - void AssociateGraphWithSession(const string& session, - const string& graph_handle); - void DisassociateGraphFromSession(const string& graph_handle); - - // Locates a worker session for a given step id - WorkerSession* WorkerSessionForStepId(const int64 step_id); - void AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id); - void DisassociateStepIdFromGraph(const int64 step_id); - Status DeleteSession(const string& session); static string WorkerNameFromServerDef(const ServerDef& server_def); private: - // Private constructor to work around std::unique_ptr ownership issues. - explicit SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory); - const WorkerEnv* const worker_env_; // Not owned. + + // A note about destruction: + // We must delete graph_mgr before device_mgr, due to shared + // ownership of OpKernels in the executors. (The graph_mgr will + // free all stateless OpKernels, and pass over borrowed stateful + // OpKernels, which are also held in their respective devices' + // OpSegments.) 
+ // + // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure + // that sessions_'s WorkerSessions are deleted (which do not own the + // underlying devices, but instead own RenamedDevices) before + // legacy_session_ is deleted. Further, we must ensure that WorkerSession's + // device_mgr is deleted after WorkerSession's graph_mgr. + WorkerSession legacy_session_; const WorkerCacheFactory worker_cache_factory_; WorkerSession* WorkerSessionForSessionUnlocked(const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_); - WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle) - EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; // A map from session identifier to internal session structure. std::map> sessions_ GUARDED_BY(mu_); - - // A map from graph handles to the session that they belong to. - std::map sessions_by_graph_handle_ GUARDED_BY(mu_); - - // A map from globally-unique step id's to the corresponding graph handles. - std::map graphs_by_step_id_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index d3f3fa83958..7132f123a59 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test { SessionMgrTest() : mgr_(&env_, "/job:mnist/replica:0/task:0", std::unique_ptr(), - std::unique_ptr(new RpcRendezvousMgr( - &env_, "/job:mnist/replica:0/task:0", nullptr)), factory_), legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {} @@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) { TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - + EXPECT_NE(mgr_.LegacySession(), session); 
TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateGraphWithSession) { +TEST_F(SessionMgrTest, LegacySession) { ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); + string session_handle = ""; WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); + EXPECT_EQ(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateStepWithGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(session, step_session); - ASSERT_EQ(graph_session, step_session); - - TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); -} - -TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, 
graph_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - TEST_F(SessionMgrTest, WorkerNameFromServerDef) { ServerDef server_def; server_def.set_job_name("worker"); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 89639e21b5d..07bb17981d3 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* 
request, Status s = session->graph_mgr->Register( request->session_handle(), request->graph_def(), request->graph_options(), request->debug_options(), response->mutable_graph_handle()); - if (s.ok()) { - env_->session_mgr->AssociateGraphWithSession(request->session_handle(), - response->graph_handle()); - } done(s); } @@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) { WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); Status s = session->graph_mgr->Deregister(request->graph_handle()); - env_->session_mgr->DisassociateGraphFromSession(request->graph_handle()); done(s); } @@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id, } void Worker::AbortStep(int64 step_id) { - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - Rendezvous* rendez = session->rendezvous_mgr->Find(step_id); + Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { // Delay a bit before aborting the step. 
This way, the root // cause may return first back to the client instead of this @@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, const int64 step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); - env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, } CostGraphDef* cost_graph = response->mutable_cost_graph(); session->graph_mgr->ExecuteAsync( - request->graph_handle(), step_id, request->exec_opts(), collector, - cost_graph, cm, in, + request->graph_handle(), step_id, session, request->exec_opts(), + collector, cost_graph, cm, in, [this, step_id, response, session, cm, out, token, collector, opts, done](Status s) { if (s.ok()) { @@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, const string& graph_handle = request->graph_handle(); TRACEPRINTF("PartialRunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(graph_handle); - env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); + GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, [cm]() { cm->StartCancel(); }); } session->graph_mgr->ExecuteAsync( - graph_handle, step_id, request->exec_opts(), nullptr /* collector */, - nullptr /* cost_graph */, cm, in, + graph_handle, step_id, session, request->exec_opts(), + nullptr /* collector */, nullptr 
/* cost_graph */, cm, in, [this, token, graph_handle, step_id, cm](Status s) { { mutex_lock l(mu_); @@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - session->rendezvous_mgr->Cleanup(step_id); + env_->rendezvous_mgr->Cleanup(step_id); done(Status::OK()); } @@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request, Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, Device** src_dev) { // Figures out which device the tensor is hosted on. - TF_RETURN_IF_ERROR( - env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); + string local_name = DeviceNameUtils::LocalName(parsed.src_device); + TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); // Does the device have the right incarnation number we expect? if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index 24fb5948a71..f09bea328fd 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ +#include #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -24,8 +25,10 @@ namespace thread { class ThreadPool; } // namespace thread +class Device; class DeviceMgr; class Env; +class RendezvousMgrInterface; class SessionMgr; // The worker environment class, which holds a bag of pointers to @@ -38,10 +41,18 @@ struct WorkerEnv { // session_mgr encapsulates state for each session. SessionMgr* session_mgr = nullptr; + // The local devices of this worker. Devices are owned by the device_mgr. 
+ // + // REQUIRES: !local_devices.empty(). + std::vector local_devices; + // device_mgr manages local devices (cpu and gpu). The WorkerService // is the network interface for managed devices. DeviceMgr* device_mgr = nullptr; + // A set of rendezvous keyed by step ids. + RendezvousMgrInterface* rendezvous_mgr = nullptr; + // A pool of threads for scheduling compute work. thread::ThreadPool* compute_pool = nullptr; }; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 508bc7f4680..c9db28ec67f 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -113,6 +113,11 @@ class WorkerInterface { return CallAndWait(&ME::GetStatusAsync, request, response); } + Status CreateWorkerSession(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response) { + return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 8298e169595..8691450e9bc 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -17,14 +17,84 @@ limitations under the License. namespace tensorflow { -WorkerSession::WorkerSession( - const string& worker_name, - std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, - std::unique_ptr graph_mgr) +namespace { + +// A private cache that wraps worker_cache and allows reuse of +// WorkerInterface objects. 
+class WorkerFreeListCache : public WorkerCacheInterface { + public: + explicit WorkerFreeListCache(std::unique_ptr w) + : wrapped_(std::move(w)) {} + + ~WorkerFreeListCache() final { + for (auto p : workers_) { + wrapped_->ReleaseWorker(p.first, p.second.worker); + } + } + + void ListWorkers(std::vector* workers) const override { + wrapped_->ListWorkers(workers); + } + + WorkerInterface* CreateWorker(const string& target) override { + mutex_lock l(mu_); + auto p = workers_.find(target); + if (p != workers_.end()) { + return p->second.worker; + } + WorkerState state; + state.worker = wrapped_->CreateWorker(target); + if (state.worker != nullptr) { + workers_.insert(std::make_pair(target, state)); + } + return state.worker; + } + + void ReleaseWorker(const string& target, WorkerInterface* worker) override { + // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. + } + + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return wrapped_->GetDeviceLocalityNonBlocking(device, locality); + } + + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + wrapped_->GetDeviceLocalityAsync(device, locality, done); + } + + void SetLogging(bool active) override { wrapped_->SetLogging(active); } + + void ClearLogs() override { wrapped_->ClearLogs(); } + + bool RetrieveLogs(int64 step_id, StepStats* ss) override { + return wrapped_->RetrieveLogs(step_id, ss); + } + + private: + std::unique_ptr wrapped_; + + // Information kept per created WorkerInterface. + struct WorkerState { + WorkerInterface* worker; + // TODO(jeff,sanjay): Add reference count if we support eviction. + }; + + // TODO(jeff,sanjay): Eviction when the map becomes too big. 
+ mutex mu_; + std::unordered_map workers_ GUARDED_BY(mu_); +}; + +} // namespace + +WorkerSession::WorkerSession(const string& worker_name, + std::unique_ptr worker_cache, + std::unique_ptr device_mgr, + std::unique_ptr graph_mgr) : worker_name(worker_name), - worker_cache(std::move(worker_cache)), - rendezvous_mgr(std::move(rendezvous_mgr)), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + device_mgr(std::move(device_mgr)), graph_mgr(std::move(graph_mgr)) {} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index e6ebe883298..77cf4de8f74 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -18,14 +18,13 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" namespace tensorflow { class GraphMgr; -class RendezvousMgrInterface; class WorkerCacheInterface; // WorkerSession encapsulates all of the state relating to a given session. @@ -36,17 +35,20 @@ struct WorkerSession { // Object from which WorkerInterface instances can be obtained. const std::unique_ptr worker_cache; - // A set of rendezvous keyed by step ids. - const std::unique_ptr rendezvous_mgr; + // Collection of local devices. These devices are typically RenamedDevices + // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr + // == worker_env_.device_mgr, which holds the true devices. + const std::unique_ptr device_mgr; // graph_mgr keeps track of the registered graphs of this session. // // Note: graph_mgr must be deleted before rendezvous_mgr! + // Note: graph_mgr must be deleted before device_mgr! 
const std::unique_ptr graph_mgr; WorkerSession(const string& worker_name, std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, + std::unique_ptr device_mgr, std::unique_ptr graph_mgr); }; diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 8894671fdf3..27fe28fe60a 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -115,7 +115,7 @@ class DeviceBase { cpu_worker_threads_ = t; } - const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { CHECK(cpu_worker_threads_ != nullptr); return cpu_worker_threads_; } @@ -140,7 +140,7 @@ class DeviceBase { gpu_device_info_ = g; } - const GpuDeviceInfo* tensorflow_gpu_device_info() const { + virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const { return gpu_device_info_; } @@ -170,13 +170,13 @@ class DeviceBase { return GetAllocator(attr); } - const Eigen::ThreadPoolDevice* eigen_cpu_device() { + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { CHECK(eigen_cpu_device_ != nullptr); return eigen_cpu_device_; } #ifdef TENSORFLOW_USE_SYCL - const Eigen::SyclDevice* eigen_sycl_device() const { + virtual const Eigen::SyclDevice* eigen_sycl_device() const { CHECK(eigen_sycl_device_ != nullptr); return eigen_sycl_device_; } diff --git a/tensorflow/core/protobuf/cluster.proto b/tensorflow/core/protobuf/cluster.proto new file mode 100644 index 00000000000..33c87eefe02 --- /dev/null +++ b/tensorflow/core/protobuf/cluster.proto @@ -0,0 +1,82 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ClusterProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +// This file contains protos to be used when defining a TensorFlow +// cluster. +// +// EXAMPLES +// -------- +// +// 1. A single-process cluster, containing "/job:local/task:0". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } +// +// Server: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// +// 2. A two-process cluster, containing "/job:local/task:{0,1}". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } +// tasks { key: 1 value: 'localhost:2223' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// cluster { $CLUSTER } job_name: 'local' task_index: 1 +// +// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and +// "/job:ps/task:{0,1}". 
+// +// Cluster: +// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } +// tasks { key: 1 value: 'worker2:2222' } +// tasks { key: 2 value: 'worker3:2222' } } +// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } +// tasks { key: 1 value: 'ps1:2222' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'worker' task_index: 0 +// cluster { $CLUSTER } job_name: 'worker' task_index: 1 +// cluster { $CLUSTER } job_name: 'worker' task_index: 2 +// cluster { $CLUSTER } job_name: 'ps' task_index: 0 +// cluster { $CLUSTER } job_name: 'ps' task_index: 1 + +// Defines a single job in a TensorFlow cluster. +message JobDef { + // The name of this job. + string name = 1; + + // Mapping from task ID to "hostname:port" string. + // + // If the `name` field contains "worker", and the `tasks` map contains a + // mapping from 7 to "example.org:2222", then the device prefix + // "/job:worker/task:7" will be assigned to "example.org:2222". + map tasks = 2; +} + +// Defines a TensorFlow cluster as a set of jobs. +message ClusterDef { + // The jobs that comprise the cluster. + repeated JobDef job = 1; +} diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 5c0f7232ebd..630f47633f8 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/protobuf/debug.proto"; +import "tensorflow/core/protobuf/cluster.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; message GPUOptions { @@ -259,6 +260,11 @@ message ConfigProto { // Options that apply when this session uses the distributed runtime. RPCOptions rpc_options = 13; + + // Optional list of all workers to use in this session. + ClusterDef cluster_def = 14; + + // Next: 15 }; // Options for a single Run() call. 
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto index de91b6133e4..e607b1c42a5 100644 --- a/tensorflow/core/protobuf/master.proto +++ b/tensorflow/core/protobuf/master.proto @@ -38,6 +38,9 @@ message CreateSessionRequest { // Configuration options. ConfigProto config = 2; + + // The target string used from the client's perspective. + string target = 3; } message CreateSessionResponse { diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto index c4077bd98e4..6199e707e5a 100644 --- a/tensorflow/core/protobuf/tensorflow_server.proto +++ b/tensorflow/core/protobuf/tensorflow_server.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/core/protobuf/config.proto"; +import "tensorflow/core/protobuf/cluster.proto"; package tensorflow; option cc_enable_arenas = true; @@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.distruntime"; -// This file contains protos to be used when defining a TensorFlow -// cluster, and a server within that cluster. -// -// EXAMPLES -// -------- -// -// 1. A single-process cluster, containing "/job:local/task:0". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } -// -// Server: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// -// 2. A two-process cluster, containing "/job:local/task:{0,1}". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } -// tasks { key: 1 value: 'localhost:2223' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// cluster { $CLUSTER } job_name: 'local' task_index: 1 -// -// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and -// "/job:ps/task:{0,1}". 
-// -// Cluster: -// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } -// tasks { key: 1 value: 'worker2:2222' } -// tasks { key: 2 value: 'worker3:2222' } } -// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } -// tasks { key: 1 value: 'ps1:2222' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'worker' task_index: 0 -// cluster { $CLUSTER } job_name: 'worker' task_index: 1 -// cluster { $CLUSTER } job_name: 'worker' task_index: 2 -// cluster { $CLUSTER } job_name: 'ps' task_index: 0 -// cluster { $CLUSTER } job_name: 'ps' task_index: 1 - -// Defines a single job in a TensorFlow cluster. -message JobDef { - // The name of this job. - string name = 1; - - // Mapping from task ID to "hostname:port" string. - // - // If the `name` field contains "worker", and the `tasks` map contains a - // mapping from 7 to "example.org:2222", then the device prefix - // "/job:worker/task:7" will be assigned to "example.org:2222". - // - // NOTE(mrry): Currently, only a dense task ID space starting at 0 is - // supported. - map tasks = 2; -} - -// Defines a TensorFlow cluster as a set of jobs. -message ClusterDef { - // The jobs that comprise the cluster. - repeated JobDef job = 1; -} - // Defines the configuration of a single TensorFlow server. message ServerDef { // The cluster of which this server is a member. diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 661327847c1..cf05aece39a 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -119,6 +119,10 @@ message RegisterGraphResponse { //////////////////////////////////////////////////////////////////////////////// message DeregisterGraphRequest { + // The session_handle used when registering the graph. If session_handle is + // empty, a single global namespace is used. + string session_handle = 2; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. 
string graph_handle = 1; @@ -167,6 +171,12 @@ message ExecutorOpts { }; message RunGraphRequest { + // session_handle is the master-generated unique id for this session. + // If session_handle is non-empty, it must be the same as used when + // registering the graph. If it is empty, a single global namespace is used to + // search for the graph_handle. + string session_handle = 8; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -193,6 +203,8 @@ message RunGraphRequest { bool is_partial = 6; // True if this is the last partial run request in a sequence of requests. bool is_last_partial_run = 7; + + // Next: 9 } message RunGraphResponse { diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 864a96ef348..6336ca23105 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import * from tensorflow.core.framework.attr_value_pb2 import * from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo from tensorflow.core.protobuf.config_pb2 import * +from tensorflow.core.protobuf.tensorflow_server_pb2 import * from tensorflow.core.protobuf.rewriter_config_pb2 import * from tensorflow.core.util.event_pb2 import * @@ -131,6 +132,7 @@ _allowed_symbols = [ 'AttrValue', 'AutoParallelOptions', 'ConfigProto', + 'ClusterDef', 'DeviceSpec', 'Event', 'GPUOptions', diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 9add5bd3cde..040cc333158 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -29,6 +29,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from
tensorflow.python.framework import common_shapes @@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase): with CaptureStderr() as log: sess.run(c) # Ensure that we did log device placement. - self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log)) + self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log)) def testLocalMasterSessionTimeout(self): # Test that the timeout passed in a config to the session works correctly. @@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase): server = server_lib.Server.create_local_server() self.runTestBuildGraphError(session.Session(server.target)) + def testClusterSpecPropagationSimple(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationWorker2Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + output = sess.run(const, options=run_options, run_metadata=run_metadata) + 
self.assertEqual(17, output) + self.assertEqual(1, + len([ + node_stats + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:1/device:CPU:0' == + dev_stats.device and 'Const' == node_stats.node_name + ])) + + def testClusterSpecPropagationWorker1Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationThreeServers2Graphs(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. 
+ """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker1' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker2' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.Graph().as_default() as g1: + with ops.device('/job:worker1/task:1'): + var1 = variables.Variable(array_ops.zeros([2]), name='var1') + update_op1 = state_ops.assign_add( + var1, array_ops.ones([2]), name='var1_assign_add') + init1 = variables.global_variables_initializer() + + with ops.Graph().as_default() as g2: + with ops.device('/job:worker2/task:1'): + var2 = variables.Variable(array_ops.zeros([2]), name='var2') + update_op2 = state_ops.assign_add( + var2, array_ops.ones([2]), name='var2_assign_add') + init2 = variables.global_variables_initializer() + + sess1 = session.Session(server2.target, graph=g1, config=config1) + sess2 = session.Session(server2.target, graph=g2, config=config2) + + init1.run(session=sess1) + init2.run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + + self.assertAllEqual(expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + self.assertAllEqual(expected_ones, sess2.run(update_op2)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess2.run(var2)) + 
self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1)) + + def testClusterSpecPropagationThreeServers(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.device('/job:worker/task:1'): + var = variables.Variable(array_ops.zeros([2]), name='var') + feed = array_ops.placeholder(dtypes.float32, shape=(2)) + update_op = var.assign_add(feed) + + sess1 = session.Session(server2.target, config=config1) + sess2 = session.Session(server2.target, config=config2) + + variables.global_variables_initializer().run(session=sess1) + variables.global_variables_initializer().run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess2.run(update_op, feed_dict={feed: expected_ones})) + 
self.assertAllEqual(expected_ones + expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess2.run(var)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(var)) + + def testClusterSpecPropagationThreeServersOneCluster(self): + """Boots 3 servers, ensures appropriate communication across workers. + + Additionally, in this cluster, we ensure the master is not the 0-th worker. + + Note: this test only uses one session. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server3.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + job.tasks[2] = server1.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + # Add ops to the devices in non-linear order. + + with ops.device('/job:worker/task:1'): + feed1 = array_ops.placeholder(dtypes.float32, shape=(2)) + const1 = constant_op.constant(2.0) + mul1 = const1 * feed1 + + with ops.device('/job:worker/task:2'): + feed2 = array_ops.placeholder(dtypes.float32, shape=(2)) + const2 = constant_op.constant(2.0) + mul2 = const2 * feed2 + + with ops.device('/job:worker/task:0'): + feed0 = array_ops.placeholder(dtypes.float32, shape=(2)) + const0 = constant_op.constant(2.0) + mul0 = const0 * feed0 + + sum_op = mul0 + mul1 + mul2 + + ones = np.ones([2]) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + + # Run! 
+ with session.Session(server1.target, config=config) as sess: + output = sess.run( + sum_op, + options=run_options, + run_metadata=run_metadata, + feed_dict={feed1: ones, + feed2: ones, + feed0: ones}) + self.assertAllEqual(6 * ones, output) + + self.assertEqual( + 3, + len([ + dev_stats.device + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:' in dev_stats.device and + node_stats.node_name.startswith('Const') + ]), run_metadata) + + def testClusterSpecPropagationPartialRun(self): + """Test successful partial run with ClusterSpec propagation.""" + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.device('/job:worker/task:0'): + a = array_ops.placeholder(dtypes.float32, shape=[]) + with ops.device('/job:worker/task:1'): + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + with ops.device('/job:worker/task:0'): + r2 = math_ops.multiply(r1, c) + + with session.Session(server1.target, config=config) as sess: + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + res = sess.partial_run(h, r2, feed_dict={c: 3}) + self.assertEqual(9, res) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index d2ccf37d885..2091eca0b9c 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import 
print_function +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors @@ -276,14 +277,14 @@ class ClusterSpec(object): "from integers to strings." % job_name) self._cluster_spec[job_name] = job_tasks self._make_cluster_def() - elif isinstance(cluster, tensorflow_server_pb2.ClusterDef): + elif isinstance(cluster, cluster_pb2.ClusterDef): self._cluster_def = cluster self._cluster_spec = {} for job_def in self._cluster_def.job: self._cluster_spec[job_def.name] = { i: t for i, t in job_def.tasks.items()} elif isinstance(cluster, ClusterSpec): - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() self._cluster_def.MergeFrom(cluster.as_cluster_def()) self._cluster_spec = {} for job_def in self._cluster_def.job: @@ -440,7 +441,7 @@ class ClusterSpec(object): TypeError: If `cluster_spec` is not a dictionary mapping strings to lists of strings. """ - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() # NOTE(mrry): Sort by job_name to produce deterministic protobufs. for job_name, tasks in sorted(self._cluster_spec.items()): diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index bdf3d9c0175..f4ac3c97587 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import * # pylint: enable=wildcard-import # Distributed computing support. 
-from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef -from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.cluster_pb2 import JobDef from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import Server @@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server _allowed_symbols = [ # TODO(cwhipkey): review these and move to contrib or expose through # documentation. - "generate_checkpoint_state_proto", # Used internally by saver. + "generate_checkpoint_state_proto", # Used internally by saver. "checkpoint_exists", # Only used in test? "get_checkpoint_mtimes", # Only used in test? # Legacy: remove. "do_quantize_training_on_graphdef", # At least use grah_def, not graphdef. - # No uses within tensorflow. + # No uses within tensorflow. "queue_runner", # Use tf.train.start_queue_runner etc directly. - # This is also imported internally. + # This is also imported internally. # TODO(drpng): document these. The reference in howtos/distributed does # not link. "SyncReplicasOptimizer", # Protobufs: - "BytesList", # from example_pb2. + "BytesList", # from example_pb2. "ClusterDef", - "Example", # from example_pb2 - "Feature", # from example_pb2 - "Features", # from example_pb2 - "FeatureList", # from example_pb2 - "FeatureLists", # from example_pb2 - "FloatList", # from example_pb2. - "Int64List", # from example_pb2. + "Example", # from example_pb2 + "Feature", # from example_pb2 + "Features", # from example_pb2 + "FeatureList", # from example_pb2 + "FeatureLists", # from example_pb2 + "FloatList", # from example_pb2. + "Int64List", # from example_pb2. "JobDef", - "SaverDef", # From saver_pb2. - "SequenceExample", # from example_pb2. + "SaverDef", # From saver_pb2. + "SequenceExample", # from example_pb2. 
"ServerDef", ] # Include extra modules for docstrings because: diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 805a9bdd4f1..da6af3919e9 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" mtype: "" } + member { + name: "CLUSTER_DEF_FIELD_NUMBER" + mtype: "" + } member { name: "DESCRIPTOR" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt index feb73bd7d4f..93ff856b09d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.ClusterDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt index 2d7fcbe5456..ac6d81541a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef.TasksEntry" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt index fc5b76341d2..ce34537fa13 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR"