diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 583ae64edd1..1727c045604 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -71,6 +74,13 @@ EagerContext::EagerContext(const SessionOptions& opts, runner_ = [this](std::function<void()> closure) { this->thread_pool_->Schedule(std::move(closure)); }; + + std::unique_ptr<DeviceResolverInterface> drl( + new DeviceResolverLocal(local_device_mgr())); + std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal( + local_device_mgr(), drl.get(), "/job:localhost/replica:0/task:0")); + collective_executor_mgr_.reset(new CollectiveExecutorMgr( + opts.config, local_device_mgr(), std::move(drl), std::move(cprl))); } void EagerContext::InitDeviceMapAndAsync() { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 51109f8f1ae..cdef9478933 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #endif +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -147,6 +148,11 @@ class EagerContext { bool LogMemory() { return log_memory_; } Rendezvous* GetRendezvous() { return rendezvous_; } + std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() { + return std::unique_ptr<CollectiveExecutor::Handle>( + new CollectiveExecutor::Handle( + collective_executor_mgr_->FindOrCreate(0), true /*inherit_ref*/)); + } const tensorflow::DeviceMgr* local_device_mgr() const { return (local_device_manager_ != nullptr) ? local_device_manager_.get() @@ -273,6 +279,8 @@ class EagerContext { Env* const env_; + std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_; + #ifndef __ANDROID__ void CloseRemoteContexts(); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 5bf7888fad5..a6199f2aeb2 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -284,7 +284,8 @@ Status EagerLocalExecute(EagerOperation* op, "Unable to find a FunctionLibraryRuntime corresponding to device ", device->name()); } - kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); + kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory(), + ctx->GetCollectiveExecutorHandle()); status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel); if (!status.ok()) { delete kernel; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 192d22dfd5a..317e9a16074 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -84,6 +84,15 @@ Status KernelAndDevice::Run(ScopedStepContainer* 
step_container, tensorflow::HOST_MEMORY); } + gtl::InlinedVector<DeviceContext*, 4> input_device_contexts; + for (int i = 0; i < inputs->size(); i++) { + DeviceContext* device_context = nullptr; + if (device_->tensorflow_gpu_device_info() != nullptr) { + device_context = device_->tensorflow_gpu_device_info()->default_context; + } + input_device_contexts.push_back(device_context); + } + OpKernelContext::Params params; params.device = device_; params.frame_iter = FrameAndIter(0, 0); @@ -110,6 +119,9 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, } params.step_container = step_container; + params.collective_executor = + collective_executor_ ? collective_executor_->get() : nullptr; + params.input_device_contexts = &input_device_contexts; OpKernelContext context(&params); diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 52dac94ccca..ee430b7fc70 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -55,10 +56,16 @@ class KernelAndDevice { KernelAndDevice* out); KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory) + : KernelAndDevice(rendez, log_memory, nullptr) {} + + KernelAndDevice( + tensorflow::Rendezvous* rendez, bool log_memory, + std::unique_ptr<CollectiveExecutor::Handle> collective_executor) : device_(nullptr), flr_(nullptr), rendez_(rendez), - log_memory_(log_memory) {} + log_memory_(log_memory), + collective_executor_(std::move(collective_executor)) {} // TODO(ashankar): Handle list-valued inputs. 
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, @@ -92,6 +99,7 @@ class KernelAndDevice { std::function<void(std::function<void()>)>* runner_; std::function<void(std::function<void()>)> default_runner_; const bool log_memory_; + const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 9f4c57e880a..19a0c5e5be2 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1527,6 +1527,7 @@ T* OpKernelContext::op_device_context() { template <typename T> T* OpKernelContext::input_device_context(int index) { + DCHECK_NE(params_->input_device_contexts, nullptr); DCHECK_GE(index, 0); DCHECK_LT(index, params_->input_device_contexts->size()); static_assert(std::is_base_of<DeviceContext, T>::value, @@ -1535,6 +1536,7 @@ T* OpKernelContext::input_device_context(int index) { } inline DeviceContext* OpKernelContext::input_device_context(int index) { + DCHECK_NE(params_->input_device_contexts, nullptr); DCHECK_GE(index, 0); DCHECK_LT(index, params_->input_device_contexts->size()); return (*params_->input_device_contexts)[index]; diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 2f6b038dda9..cbbe5cf49e2 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -478,10 +478,6 @@ class Context(object): Raises: ValueError: If name is not a string or is an invalid device name. 
""" - devices = self._context_devices - if devices is None: - self._initialize_handle_and_devices() - devices = self._context_devices eager_context = self._eager_context old_device_name = eager_context.device_name old_device_spec = eager_context.device_spec @@ -502,7 +498,9 @@ class Context(object): if old_device_name: new_device_spec = copy.copy(old_device_spec) else: - new_device_spec = pydev.DeviceSpec.from_string(devices[0]) + self._initialize_handle_and_devices() + new_device_spec = pydev.DeviceSpec.from_string( + self._context_devices[0]) new_device_spec.merge_from(device_spec) else: new_device_spec = pydev.DeviceSpec.from_string("")