Populate some additional parameters when running eager operations.

PiperOrigin-RevId: 224391662

Author: Akshay Modi (2018-12-06 13:01:00 -08:00), committed by TensorFlower Gardener
Parent commit: 6427950eae
This commit:  844dbd8730
7 changed files with 46 additions and 7 deletions

View File

@@ -15,6 +15,9 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/context.h"

+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -71,6 +74,13 @@ EagerContext::EagerContext(const SessionOptions& opts,
   runner_ = [this](std::function<void()> closure) {
     this->thread_pool_->Schedule(std::move(closure));
   };
+
+  std::unique_ptr<DeviceResolverInterface> drl(
+      new DeviceResolverLocal(local_device_mgr()));
+  std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
+      local_device_mgr(), drl.get(), "/job:localhost/replica:0/task:0"));
+  collective_executor_mgr_.reset(new CollectiveExecutorMgr(
+      opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
 }

 void EagerContext::InitDeviceMapAndAsync() {

View File

@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #endif
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -147,6 +148,11 @@ class EagerContext {
   bool LogMemory() { return log_memory_; }

   Rendezvous* GetRendezvous() { return rendezvous_; }

+  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
+    return std::unique_ptr<CollectiveExecutor::Handle>(
+        new CollectiveExecutor::Handle(
+            collective_executor_mgr_->FindOrCreate(0), true /*inherit_ref*/));
+  }
+
   const tensorflow::DeviceMgr* local_device_mgr() const {
     return (local_device_manager_ != nullptr) ? local_device_manager_.get()
@@ -273,6 +279,8 @@ class EagerContext {
   Env* const env_;

+  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
+
 #ifndef __ANDROID__
   void CloseRemoteContexts();

View File

@@ -284,7 +284,8 @@ Status EagerLocalExecute(EagerOperation* op,
           "Unable to find a FunctionLibraryRuntime corresponding to device ",
           device->name());
     }
-    kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
+    kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory(),
+                                 ctx->GetCollectiveExecutorHandle());
     status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel);
     if (!status.ok()) {
       delete kernel;

View File

@@ -84,6 +84,15 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
                              tensorflow::HOST_MEMORY);
   }

+  gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
+  for (int i = 0; i < inputs->size(); i++) {
+    DeviceContext* device_context = nullptr;
+    if (device_->tensorflow_gpu_device_info() != nullptr) {
+      device_context = device_->tensorflow_gpu_device_info()->default_context;
+    }
+    input_device_contexts.push_back(device_context);
+  }
+
   OpKernelContext::Params params;
   params.device = device_;
   params.frame_iter = FrameAndIter(0, 0);
@@ -110,6 +119,9 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
   }
   params.step_container = step_container;
+  params.collective_executor =
+      collective_executor_ ? collective_executor_->get() : nullptr;
+  params.input_device_contexts = &input_device_contexts;
   OpKernelContext context(&params);

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
@@ -55,10 +56,16 @@ class KernelAndDevice {
                      KernelAndDevice* out);

   KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+      : KernelAndDevice(rendez, log_memory, nullptr) {}
+
+  KernelAndDevice(
+      tensorflow::Rendezvous* rendez, bool log_memory,
+      std::unique_ptr<CollectiveExecutor::Handle> collective_executor)
       : device_(nullptr),
         flr_(nullptr),
         rendez_(rendez),
-        log_memory_(log_memory) {}
+        log_memory_(log_memory),
+        collective_executor_(std::move(collective_executor)) {}

   // TODO(ashankar): Handle list-valued inputs.
   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -92,6 +99,7 @@ class KernelAndDevice {
   std::function<void(std::function<void()>)>* runner_;
   std::function<void(std::function<void()>)> default_runner_;
   const bool log_memory_;
+  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
 };

 }  // namespace tensorflow

View File

@@ -1527,6 +1527,7 @@ T* OpKernelContext::op_device_context() {
 template <typename T>
 T* OpKernelContext::input_device_context(int index) {
+  DCHECK_NE(params_->input_device_contexts, nullptr);
   DCHECK_GE(index, 0);
   DCHECK_LT(index, params_->input_device_contexts->size());
   static_assert(std::is_base_of<DeviceContext, T>::value,
@@ -1535,6 +1536,7 @@ T* OpKernelContext::input_device_context(int index) {
 }

 inline DeviceContext* OpKernelContext::input_device_context(int index) {
+  DCHECK_NE(params_->input_device_contexts, nullptr);
   DCHECK_GE(index, 0);
   DCHECK_LT(index, params_->input_device_contexts->size());
   return (*params_->input_device_contexts)[index];

View File

@@ -478,10 +478,6 @@ class Context(object):
     Raises:
       ValueError: If name is not a string or is an invalid device name.
     """
-    devices = self._context_devices
-    if devices is None:
-      self._initialize_handle_and_devices()
-      devices = self._context_devices
     eager_context = self._eager_context
     old_device_name = eager_context.device_name
     old_device_spec = eager_context.device_spec
@@ -502,7 +498,9 @@ class Context(object):
       if old_device_name:
         new_device_spec = copy.copy(old_device_spec)
       else:
-        new_device_spec = pydev.DeviceSpec.from_string(devices[0])
+        self._initialize_handle_and_devices()
+        new_device_spec = pydev.DeviceSpec.from_string(
+            self._context_devices[0])
       new_device_spec.merge_from(device_spec)
     else:
       new_device_spec = pydev.DeviceSpec.from_string("")