Populate some additional parameters when running eager operations.
PiperOrigin-RevId: 224391662
parent 6427950eae
commit 844dbd8730
@@ -15,6 +15,9 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/eager/context.h"
 
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -71,6 +74,13 @@ EagerContext::EagerContext(const SessionOptions& opts,
   runner_ = [this](std::function<void()> closure) {
     this->thread_pool_->Schedule(std::move(closure));
   };
+
+  std::unique_ptr<DeviceResolverInterface> drl(
+      new DeviceResolverLocal(local_device_mgr()));
+  std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
+      local_device_mgr(), drl.get(), "/job:localhost/replica:0/task:0"));
+  collective_executor_mgr_.reset(new CollectiveExecutorMgr(
+      opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
 }
 
 void EagerContext::InitDeviceMapAndAsync() {
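For orientation: the constructor hunk above builds a purely local collective stack. A DeviceResolverLocal is created over the local DeviceMgr, a CollectiveParamResolverLocal is scoped to /job:localhost/replica:0/task:0 and borrows the device resolver, and a CollectiveExecutorMgr then takes ownership of both via std::move. The sketch below mirrors only that ownership wiring, with stand-in types (DeviceResolverStub, ParamResolverStub, CollectiveExecutorMgrStub are illustrative, not the TensorFlow classes):

    // Standalone sketch of the ownership wiring above. The *Stub types are
    // placeholders, not the TensorFlow classes.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>

    struct DeviceResolverStub {};

    struct ParamResolverStub {
      ParamResolverStub(DeviceResolverStub* devices, std::string task)
          : devices_(devices), task_(std::move(task)) {}
      DeviceResolverStub* devices_;  // Borrowed; owned by the executor mgr below.
      std::string task_;
    };

    struct CollectiveExecutorMgrStub {
      CollectiveExecutorMgrStub(std::unique_ptr<DeviceResolverStub> dev,
                                std::unique_ptr<ParamResolverStub> param)
          : dev_(std::move(dev)), param_(std::move(param)) {}
      std::unique_ptr<DeviceResolverStub> dev_;
      std::unique_ptr<ParamResolverStub> param_;
    };

    int main() {
      // Build the resolvers first; the param resolver keeps a raw pointer to the
      // device resolver, so the executor mgr must end up owning both of them.
      auto drl = std::make_unique<DeviceResolverStub>();
      auto cprl = std::make_unique<ParamResolverStub>(
          drl.get(), "/job:localhost/replica:0/task:0");
      auto mgr = std::make_unique<CollectiveExecutorMgrStub>(std::move(drl),
                                                             std::move(cprl));
      std::cout << "collective stack owned by mgr, task=" << mgr->param_->task_
                << "\n";
      return 0;
    }

The raw drl.get() handed to the param resolver is safe only because the executor mgr ends up owning both objects and therefore outlives that borrow; the sketch keeps that invariant explicit.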
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #endif
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -147,6 +148,11 @@ class EagerContext {
   bool LogMemory() { return log_memory_; }
 
   Rendezvous* GetRendezvous() { return rendezvous_; }
+  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
+    return std::unique_ptr<CollectiveExecutor::Handle>(
+        new CollectiveExecutor::Handle(
+            collective_executor_mgr_->FindOrCreate(0), true /*inherit_ref*/));
+  }
 
   const tensorflow::DeviceMgr* local_device_mgr() const {
     return (local_device_manager_ != nullptr) ? local_device_manager_.get()
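GetCollectiveExecutorHandle() wraps whatever collective_executor_mgr_->FindOrCreate(0) returns in a CollectiveExecutor::Handle and hands it out as a unique_ptr; the true /*inherit_ref*/ argument indicates that the handle adopts the reference produced by FindOrCreate rather than taking an extra one, so destroying the handle is what is expected to release the executor. A minimal standalone sketch of that idiom with a toy ref-counted type (RefCountedThing and ExecHandle are stand-ins, not TensorFlow code):

    // Toy version of the "handle adopts a reference" idiom; not TensorFlow code.
    #include <cassert>
    #include <iostream>
    #include <memory>

    struct RefCountedThing {
      int refs = 1;  // The creator starts with one reference.
      void Ref() { ++refs; }
      void Unref() {
        if (--refs == 0) delete this;
      }
    };

    class ExecHandle {
     public:
      // inherit_ref=true: the handle adopts the caller's reference instead of
      // taking an additional one.
      ExecHandle(RefCountedThing* thing, bool inherit_ref) : thing_(thing) {
        if (!inherit_ref) thing_->Ref();
      }
      ~ExecHandle() { thing_->Unref(); }
      ExecHandle(const ExecHandle&) = delete;
      ExecHandle& operator=(const ExecHandle&) = delete;
      RefCountedThing* get() const { return thing_; }

     private:
      RefCountedThing* thing_;
    };

    int main() {
      auto* exec = new RefCountedThing;  // Stands in for FindOrCreate(0).
      {
        // Like GetCollectiveExecutorHandle(): wrap it, inheriting the reference.
        auto handle = std::make_unique<ExecHandle>(exec, /*inherit_ref=*/true);
        assert(handle->get()->refs == 1);
      }  // Handle destroyed -> Unref -> the executor stand-in is deleted.
      std::cout << "handle released its reference on destruction\n";
      return 0;
    }

Because the handle comes back as a unique_ptr, storing it in KernelAndDevice (as the later hunks do) ties the executor reference to the kernel's lifetime with no extra bookkeeping.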
@@ -273,6 +279,8 @@ class EagerContext {
 
   Env* const env_;
 
+  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
+
 #ifndef __ANDROID__
   void CloseRemoteContexts();
 
@@ -284,7 +284,8 @@ Status EagerLocalExecute(EagerOperation* op,
           "Unable to find a FunctionLibraryRuntime corresponding to device ",
           device->name());
     }
-    kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
+    kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory(),
+                                 ctx->GetCollectiveExecutorHandle());
     status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel);
     if (!status.ok()) {
       delete kernel;
@@ -84,6 +84,15 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
                        tensorflow::HOST_MEMORY);
   }
 
+  gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
+  for (int i = 0; i < inputs->size(); i++) {
+    DeviceContext* device_context = nullptr;
+    if (device_->tensorflow_gpu_device_info() != nullptr) {
+      device_context = device_->tensorflow_gpu_device_info()->default_context;
+    }
+    input_device_contexts.push_back(device_context);
+  }
+
   OpKernelContext::Params params;
   params.device = device_;
   params.frame_iter = FrameAndIter(0, 0);
@@ -110,6 +119,9 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
   }
 
   params.step_container = step_container;
+  params.collective_executor =
+      collective_executor_ ? collective_executor_->get() : nullptr;
+  params.input_device_contexts = &input_device_contexts;
 
   OpKernelContext context(&params);
 
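Taken together, the two Run() hunks populate OpKernelContext::Params with a per-input table of DeviceContext pointers (the device's default GPU context when one exists, otherwise nullptr) and with a collective executor that may legitimately be absent. The op_kernel.h hunks further down read the same table back by index behind DCHECKs. A small standalone sketch of that per-input-table-plus-indexed-accessor pattern, with stand-in types rather than the real OpKernelContext API:

    // Stand-in types illustrating the per-input table + indexed accessor
    // pattern; this is not the real OpKernelContext API.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    struct DeviceCtx {
      const char* name;
    };

    struct Params {
      // One entry per op input; nullptr means "no special device context".
      const std::vector<DeviceCtx*>* input_device_contexts = nullptr;
    };

    DeviceCtx* input_device_context(const Params& params, int index) {
      assert(params.input_device_contexts != nullptr);
      assert(index >= 0 &&
             index < static_cast<int>(params.input_device_contexts->size()));
      return (*params.input_device_contexts)[index];
    }

    int main() {
      DeviceCtx gpu_default{"gpu_default_context"};
      const bool input_on_gpu[] = {true, false, true};  // Pretend placements.

      // Mirrors the loop in Run(): one pointer per input, the GPU default
      // context where available, nullptr otherwise.
      std::vector<DeviceCtx*> contexts;
      for (bool on_gpu : input_on_gpu) {
        contexts.push_back(on_gpu ? &gpu_default : nullptr);
      }

      Params params;
      params.input_device_contexts = &contexts;
      for (int i = 0; i < 3; ++i) {
        DeviceCtx* ctx = input_device_context(params, i);
        std::printf("input %d -> %s\n", i, ctx ? ctx->name : "(null)");
      }
      return 0;
    }

nullptr entries are expected for host-resident inputs, which is why the accessor only asserts that the table itself exists, not that every entry is non-null.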
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
@@ -55,10 +56,16 @@ class KernelAndDevice {
                      KernelAndDevice* out);
 
   KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+      : KernelAndDevice(rendez, log_memory, nullptr) {}
+
+  KernelAndDevice(
+      tensorflow::Rendezvous* rendez, bool log_memory,
+      std::unique_ptr<CollectiveExecutor::Handle> collective_executor)
       : device_(nullptr),
         flr_(nullptr),
         rendez_(rendez),
-        log_memory_(log_memory) {}
+        log_memory_(log_memory),
+        collective_executor_(std::move(collective_executor)) {}
 
   // TODO(ashankar): Handle list-valued inputs.
   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
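Note how the existing two-argument KernelAndDevice constructor is preserved and simply delegates to the new three-argument form with a null collective executor, so existing call sites keep compiling while only the new form stores the handle. A minimal standalone sketch of this delegating-constructor pattern (Widget and Helper are illustrative names, not TensorFlow types):

    // Delegating constructor that keeps an older two-argument signature alive;
    // Widget and Helper are illustrative stand-ins, not TensorFlow types.
    #include <iostream>
    #include <memory>
    #include <utility>

    struct Helper {};

    class Widget {
     public:
      // Old signature: callers that do not care about the helper keep compiling.
      Widget(int rendez, bool log_memory) : Widget(rendez, log_memory, nullptr) {}

      // New signature: optionally takes ownership of a helper.
      Widget(int rendez, bool log_memory, std::unique_ptr<Helper> helper)
          : rendez_(rendez),
            log_memory_(log_memory),
            helper_(std::move(helper)) {}

      bool has_helper() const { return helper_ != nullptr; }
      int rendez() const { return rendez_; }
      bool log_memory() const { return log_memory_; }

     private:
      int rendez_;
      bool log_memory_;
      const std::unique_ptr<Helper> helper_;
    };

    int main() {
      Widget old_style(42, /*log_memory=*/false);
      Widget new_style(42, /*log_memory=*/false, std::make_unique<Helper>());
      std::cout << old_style.has_helper() << " " << new_style.has_helper() << "\n";
      return 0;
    }

Declaring the owning member const, as the next hunk does for collective_executor_, makes the ownership one-shot: it is set in the constructor and never reseated afterwards.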
@@ -92,6 +99,7 @@ class KernelAndDevice {
   std::function<void(std::function<void()>)>* runner_;
   std::function<void(std::function<void()>)> default_runner_;
   const bool log_memory_;
+  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
 };
 
 }  // namespace tensorflow
@@ -1527,6 +1527,7 @@ T* OpKernelContext::op_device_context() {
 
 template <typename T>
 T* OpKernelContext::input_device_context(int index) {
+  DCHECK_NE(params_->input_device_contexts, nullptr);
   DCHECK_GE(index, 0);
   DCHECK_LT(index, params_->input_device_contexts->size());
   static_assert(std::is_base_of<DeviceContext, T>::value,
@@ -1535,6 +1536,7 @@ T* OpKernelContext::input_device_context(int index) {
 }
 
 inline DeviceContext* OpKernelContext::input_device_context(int index) {
+  DCHECK_NE(params_->input_device_contexts, nullptr);
   DCHECK_GE(index, 0);
   DCHECK_LT(index, params_->input_device_contexts->size());
   return (*params_->input_device_contexts)[index];
@@ -478,10 +478,6 @@ class Context(object):
     Raises:
       ValueError: If name is not a string or is an invalid device name.
     """
-    devices = self._context_devices
-    if devices is None:
-      self._initialize_handle_and_devices()
-      devices = self._context_devices
     eager_context = self._eager_context
     old_device_name = eager_context.device_name
     old_device_spec = eager_context.device_spec
@@ -502,7 +498,9 @@ class Context(object):
         if old_device_name:
           new_device_spec = copy.copy(old_device_spec)
         else:
-          new_device_spec = pydev.DeviceSpec.from_string(devices[0])
+          self._initialize_handle_and_devices()
+          new_device_spec = pydev.DeviceSpec.from_string(
+              self._context_devices[0])
         new_device_spec.merge_from(device_spec)
       else:
         new_device_spec = pydev.DeviceSpec.from_string("")