Populate some additional parameters when running eager operations.

PiperOrigin-RevId: 224391662
This commit is contained in:
Akshay Modi 2018-12-06 13:01:00 -08:00 committed by TensorFlower Gardener
parent 6427950eae
commit 844dbd8730
7 changed files with 46 additions and 7 deletions

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
@ -71,6 +74,13 @@ EagerContext::EagerContext(const SessionOptions& opts,
runner_ = [this](std::function<void()> closure) {
this->thread_pool_->Schedule(std::move(closure));
};
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(local_device_mgr()));
std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
local_device_mgr(), drl.get(), "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
}
void EagerContext::InitDeviceMapAndAsync() {

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#endif
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -147,6 +148,11 @@ class EagerContext {
bool LogMemory() { return log_memory_; }
Rendezvous* GetRendezvous() { return rendezvous_; }
// Returns an owning handle to the context-wide collective executor.
// FindOrCreate(0) looks up (or lazily creates) the executor keyed by
// step_id 0, i.e. a single executor shared across eager op executions.
// The `true /*inherit_ref*/` flag presumably makes the Handle adopt the
// reference returned by FindOrCreate instead of taking an extra one --
// TODO(review): confirm against CollectiveExecutor::Handle's contract.
std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
return std::unique_ptr<CollectiveExecutor::Handle>(
new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(0), true /*inherit_ref*/));
}
const tensorflow::DeviceMgr* local_device_mgr() const {
return (local_device_manager_ != nullptr) ? local_device_manager_.get()
@ -273,6 +279,8 @@ class EagerContext {
Env* const env_;
std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
#ifndef __ANDROID__
void CloseRemoteContexts();

View File

@ -284,7 +284,8 @@ Status EagerLocalExecute(EagerOperation* op,
"Unable to find a FunctionLibraryRuntime corresponding to device ",
device->name());
}
kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory(),
ctx->GetCollectiveExecutorHandle());
status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel);
if (!status.ok()) {
delete kernel;

View File

@ -84,6 +84,15 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
tensorflow::HOST_MEMORY);
}
gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
for (int i = 0; i < inputs->size(); i++) {
DeviceContext* device_context = nullptr;
if (device_->tensorflow_gpu_device_info() != nullptr) {
device_context = device_->tensorflow_gpu_device_info()->default_context;
}
input_device_contexts.push_back(device_context);
}
OpKernelContext::Params params;
params.device = device_;
params.frame_iter = FrameAndIter(0, 0);
@ -110,6 +119,9 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
}
params.step_container = step_container;
params.collective_executor =
collective_executor_ ? collective_executor_->get() : nullptr;
params.input_device_contexts = &input_device_contexts;
OpKernelContext context(&params);

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@ -55,10 +56,16 @@ class KernelAndDevice {
KernelAndDevice* out);
// Convenience constructor for kernels that do not participate in
// collective ops: delegates with a null collective executor handle.
KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
: KernelAndDevice(rendez, log_memory, nullptr) {}
KernelAndDevice(
tensorflow::Rendezvous* rendez, bool log_memory,
std::unique_ptr<CollectiveExecutor::Handle> collective_executor)
: device_(nullptr),
flr_(nullptr),
rendez_(rendez),
log_memory_(log_memory) {}
log_memory_(log_memory),
collective_executor_(std::move(collective_executor)) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@ -92,6 +99,7 @@ class KernelAndDevice {
std::function<void(std::function<void()>)>* runner_;
std::function<void(std::function<void()>)> default_runner_;
const bool log_memory_;
const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
};
} // namespace tensorflow

View File

@ -1527,6 +1527,7 @@ T* OpKernelContext::op_device_context() {
template <typename T>
T* OpKernelContext::input_device_context(int index) {
DCHECK_NE(params_->input_device_contexts, nullptr);
DCHECK_GE(index, 0);
DCHECK_LT(index, params_->input_device_contexts->size());
static_assert(std::is_base_of<DeviceContext, T>::value,
@ -1535,6 +1536,7 @@ T* OpKernelContext::input_device_context(int index) {
}
inline DeviceContext* OpKernelContext::input_device_context(int index) {
DCHECK_NE(params_->input_device_contexts, nullptr);
DCHECK_GE(index, 0);
DCHECK_LT(index, params_->input_device_contexts->size());
return (*params_->input_device_contexts)[index];

View File

@ -478,10 +478,6 @@ class Context(object):
Raises:
ValueError: If name is not a string or is an invalid device name.
"""
devices = self._context_devices
if devices is None:
self._initialize_handle_and_devices()
devices = self._context_devices
eager_context = self._eager_context
old_device_name = eager_context.device_name
old_device_spec = eager_context.device_spec
@ -502,7 +498,9 @@ class Context(object):
if old_device_name:
new_device_spec = copy.copy(old_device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string(devices[0])
self._initialize_handle_and_devices()
new_device_spec = pydev.DeviceSpec.from_string(
self._context_devices[0])
new_device_spec.merge_from(device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string("")