Add mirroring for remote tensor handles
When executing on a remote worker, we may have to copy the TensorHandle for each executed op. To avoid duplicated work, we expand the TensorHandle to keep track of mirrors which are tied to the lifetime of the TensorHandle. If a mirror already exists on a remote worker, no additional copy is needed. The change consists of the following: - Add map of remote mirrors in TensorHandle. - Add `mirror` boolean argument to EagerCopyToDevice which indicates to try configuring a mirror if possible. - Add Device argument to RemoteAddress to handle mirrors. - Expose a ContextMirroringPolicy for the EagerContext. We plan to add additional policies in the future, such as local tensor mirroring. - Rename ContextDevicePlacementPolicy variables to be consistent with ContextMirroringPolicy. PiperOrigin-RevId: 253945140
This commit is contained in:
parent
fd65ff0e37
commit
e75d8dc058
@ -995,3 +995,23 @@ TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(
|
||||
<< handle->DebugString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
||||
TFE_ContextMirroringPolicy policy) {
|
||||
options->mirroring_policy = policy;
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
ctx->context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
|
||||
// Note: this function looks up a thread local policy. So it should be called in
|
||||
// the appropriate client thread. In particular, in async mode, it may not be
|
||||
// safe to call this function from the async EagerExecutor threads.
|
||||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
return static_cast<TFE_ContextMirroringPolicy>(
|
||||
ctx->context->GetMirroringPolicy());
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ tf_cuda_library(
|
||||
srcs = [
|
||||
"c_api.cc",
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
@ -81,6 +82,7 @@ tf_cuda_library(
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api_experimental.h"],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
@ -274,7 +276,6 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_experimental.h",
|
||||
"*test*",
|
||||
],
|
||||
),
|
||||
|
@ -369,7 +369,7 @@ void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
|
||||
|
||||
void TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
|
||||
options->policy = policy;
|
||||
options->device_placement_policy = policy;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||
@ -392,7 +392,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
@ -406,7 +407,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
}
|
||||
@ -576,7 +578,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h->handle->IsRemote()) {
|
||||
status->status = EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(),
|
||||
h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
|
||||
h->handle->Context()->HostCPU()->name().c_str(), false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -924,9 +926,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
device_name, &handle);
|
||||
device_name, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
}
|
||||
@ -997,9 +999,9 @@ TFE_TensorHandle* TFE_TensorHandleMaybeCopyToHostCPU(TFE_TensorHandle* h,
|
||||
// TensorHandles created by PyFuncOp lack context and therefore could
|
||||
// not be copied.
|
||||
if (!h->handle->OnHostCPU() && h->handle->Context() != nullptr) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
h->handle, h->handle->Context(), "CPU:0", &handle);
|
||||
h->handle, h->handle->Context(), "CPU:0", false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
} else {
|
||||
|
@ -60,6 +60,8 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
|
||||
|
||||
// Controls how to act when we try to run an operation on a given device but
|
||||
// some input tensors are not on that device.
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with internal copy of enum in eager/context.h.
|
||||
typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
@ -72,6 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
} TFE_ContextDevicePlacementPolicy;
|
||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
||||
|
||||
// Sets the default execution mode (sync/async). Note that this can be
|
||||
// overridden per thread using TFE_ContextSetAsyncForThread.
|
||||
|
@ -320,6 +320,29 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
|
||||
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
|
||||
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with internal copy of enum in eager/context.h.
|
||||
typedef enum TFE_ContextMirroringPolicy {
|
||||
// Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
|
||||
// copies with their own lifetime.
|
||||
TFE_MIRRORING_NONE = 0,
|
||||
// Mirroring any remote tensor handles, associating them with the lifetime of
|
||||
// the local TensorHandle.
|
||||
TFE_MIRRORING_ALL = 1,
|
||||
} TFE_ContextMirroringPolicy;
|
||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
||||
|
||||
// Sets a thread-local mirroring policy. After this call, other calls to
|
||||
// TFE_Execute in the same thread will use the mirroring policy specified here
|
||||
// instead of the mirroring policy used to construct the context. This has no
|
||||
// effect on the mirroring policy used by other program threads.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context*, TFE_ContextMirroringPolicy);
|
||||
|
||||
// Returns the mirroring policy to be used by this context in the current
|
||||
// thread.
|
||||
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context*);
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -54,19 +55,24 @@ struct TFE_ContextOptions {
|
||||
TF_SessionOptions session_options;
|
||||
// true if async execution is enabled.
|
||||
bool async = false;
|
||||
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
|
||||
TFE_ContextDevicePlacementPolicy device_placement_policy{
|
||||
TFE_DEVICE_PLACEMENT_SILENT};
|
||||
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
TFE_Context(const tensorflow::SessionOptions& opts,
|
||||
TFE_ContextDevicePlacementPolicy default_policy, bool async,
|
||||
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
|
||||
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
|
||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
tensorflow::Rendezvous* rendezvous,
|
||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
||||
: context(new tensorflow::EagerContext(
|
||||
opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_policy),
|
||||
default_device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
||||
default_mirroring_policy),
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator)) {}
|
||||
|
||||
|
@ -59,13 +59,16 @@ auto* eager_context_created =
|
||||
} // namespace
|
||||
|
||||
EagerContext::EagerContext(
|
||||
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
|
||||
bool async, const DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator,
|
||||
const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_device_placement_policy,
|
||||
ContextMirroringPolicy default_mirroring_policy, bool async,
|
||||
const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
|
||||
const CustomKernelCreator* custom_kernel_creator,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator,
|
||||
const DeviceMgr* remote_device_mgr)
|
||||
: policy_(default_policy),
|
||||
: default_device_placement_policy_(default_device_placement_policy),
|
||||
default_mirroring_policy_(default_mirroring_policy),
|
||||
remote_unowned_device_manager_(remote_device_mgr),
|
||||
devices_(device_mgr->ListDevices()),
|
||||
rendezvous_(rendezvous),
|
||||
@ -171,16 +174,36 @@ void EagerContext::ClearCaches() {
|
||||
void EagerContext::SetThreadLocalDevicePlacementPolicy(
|
||||
ContextDevicePlacementPolicy policy) {
|
||||
mutex_lock ml(policy_map_mu_);
|
||||
thread_local_policies_[std::this_thread::get_id()] = policy;
|
||||
device_placement_policy_[std::this_thread::get_id()] = policy;
|
||||
}
|
||||
|
||||
ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
|
||||
mutex_lock ml(policy_map_mu_);
|
||||
auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
|
||||
if (policy_map_it != thread_local_policies_.end()) {
|
||||
ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
|
||||
tf_shared_lock l(policy_map_mu_);
|
||||
auto policy_map_it =
|
||||
device_placement_policy_.find(std::this_thread::get_id());
|
||||
if (policy_map_it != device_placement_policy_.end()) {
|
||||
return policy_map_it->second;
|
||||
}
|
||||
return policy_;
|
||||
return default_device_placement_policy_;
|
||||
}
|
||||
|
||||
void EagerContext::SetThreadLocalMirroringPolicy(
|
||||
ContextMirroringPolicy policy) {
|
||||
mutex_lock ml(policy_map_mu_);
|
||||
mirroring_policy_[std::this_thread::get_id()] = policy;
|
||||
}
|
||||
|
||||
ContextMirroringPolicy EagerContext::GetMirroringPolicy() const {
|
||||
tf_shared_lock l(policy_map_mu_);
|
||||
auto policy_map_it = mirroring_policy_.find(std::this_thread::get_id());
|
||||
if (policy_map_it != mirroring_policy_.end()) {
|
||||
return policy_map_it->second;
|
||||
}
|
||||
return default_mirroring_policy_;
|
||||
}
|
||||
|
||||
bool EagerContext::MirrorTensors() const {
|
||||
return GetMirroringPolicy() == MIRRORING_ALL;
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
|
@ -62,7 +62,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note: there's a copy enum in eager/c_api.h. It should be kept in sync.
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||
enum ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
@ -74,6 +75,19 @@ enum ContextDevicePlacementPolicy {
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with exported copy of enum in eager/c_api_experimental.h.
|
||||
enum ContextMirroringPolicy {
|
||||
// Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
|
||||
// copies with their own lifetime.
|
||||
MIRRORING_NONE = 0,
|
||||
// Mirroring any remote tensor handles, associating them with the lifetime of
|
||||
// the local TensorHandle.
|
||||
MIRRORING_ALL = 1,
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/c/eager/c_api_experimental.h)
|
||||
|
||||
class RunMetadataListener {
|
||||
public:
|
||||
@ -84,8 +98,10 @@ class RunMetadataListener {
|
||||
class EagerContext : public core::RefCounted {
|
||||
public:
|
||||
EagerContext(
|
||||
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
|
||||
bool async, const DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_device_placement_policy,
|
||||
ContextMirroringPolicy default_mirroring_policy, bool async,
|
||||
const DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator = nullptr,
|
||||
@ -128,7 +144,15 @@ class EagerContext : public core::RefCounted {
|
||||
void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
|
||||
|
||||
// Returns the device placement policy for the current thread.
|
||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy();
|
||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
||||
|
||||
// Sets the implicit copy policy for the current thread.
|
||||
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
|
||||
|
||||
// Returns the implicit copy policy for the current thread.
|
||||
ContextMirroringPolicy GetMirroringPolicy() const;
|
||||
|
||||
bool MirrorTensors() const;
|
||||
|
||||
Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }
|
||||
|
||||
@ -284,13 +308,16 @@ class EagerContext : public core::RefCounted {
|
||||
void InitDeviceMapAndAsync();
|
||||
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
||||
|
||||
const ContextDevicePlacementPolicy policy_;
|
||||
const ContextDevicePlacementPolicy default_device_placement_policy_;
|
||||
const ContextMirroringPolicy default_mirroring_policy_;
|
||||
|
||||
// Note: we cannot use C++11 thread_local here as there is no concept of a
|
||||
// thread-local-object-local variable in C++11.
|
||||
mutex policy_map_mu_;
|
||||
mutable mutex policy_map_mu_;
|
||||
std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
|
||||
thread_local_policies_ GUARDED_BY(policy_map_mu_);
|
||||
device_placement_policy_ GUARDED_BY(policy_map_mu_);
|
||||
std::unordered_map<std::thread::id, ContextMirroringPolicy> mirroring_policy_
|
||||
GUARDED_BY(policy_map_mu_);
|
||||
|
||||
// Only one of the below is set.
|
||||
std::unique_ptr<const DeviceMgr> local_device_manager_;
|
||||
|
@ -150,8 +150,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op,
|
||||
// trigger a copy.
|
||||
auto pre_time_nanos = Env::Default()->NowNanos();
|
||||
TensorHandle* result_handle = nullptr;
|
||||
Status status = EagerCopyToDevice(
|
||||
handle, ctx, expected_input_device->name().c_str(), &result_handle);
|
||||
Status status =
|
||||
EagerCopyToDevice(handle, ctx, expected_input_device->name().c_str(),
|
||||
ctx->MirrorTensors(), &result_handle);
|
||||
if (run_metadata != nullptr) {
|
||||
auto* step_stats = run_metadata->mutable_step_stats();
|
||||
MaybeInitializeStepStats(step_stats, ctx);
|
||||
@ -510,7 +511,7 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
||||
input, ctx, device == nullptr ? "" : device->name().c_str(),
|
||||
&handle));
|
||||
ctx->MirrorTensors(), &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
@ -677,11 +678,20 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
// this function enables sending tensors using the EagerService.SendTensor RPC
|
||||
// *on the receiver*.
|
||||
Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
|
||||
Device* recv_device, TensorHandle** result) {
|
||||
Device* recv_device, bool mirror,
|
||||
TensorHandle** result) {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
return errors::Unimplemented(
|
||||
"Eager's remote execution is not available on mobile devices.");
|
||||
#else // !IS_MOBILE_PLATFORM
|
||||
if (mirror) {
|
||||
if (h->HasRemoteMirror(recv_device)) {
|
||||
h->Ref();
|
||||
*result = h;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
eager::EagerClient* eager_client;
|
||||
uint64 context_id;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -726,9 +736,17 @@ Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
|
||||
n.WaitForNotification();
|
||||
if (!status.ok()) return status;
|
||||
|
||||
status = TensorHandle::CreateRemoteHandle(
|
||||
id, 0, tensor->shape(), eager_client, context_id, tensor->dtype(),
|
||||
recv_device, nullptr, ctx, result);
|
||||
auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
|
||||
id, 0, tensor->shape(), eager_client, context_id, ctx);
|
||||
if (mirror) {
|
||||
status = h->AddRemoteMirror(std::move(tensor_handle_data), recv_device);
|
||||
h->Ref();
|
||||
*result = h;
|
||||
} else {
|
||||
status = TensorHandle::CreateRemoteHandle(std::move(tensor_handle_data),
|
||||
tensor->dtype(), recv_device,
|
||||
nullptr, ctx, result);
|
||||
}
|
||||
|
||||
actual_handle->Unref();
|
||||
|
||||
@ -757,6 +775,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
auto* remote_op = request->add_queue()->mutable_operation();
|
||||
|
||||
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||
tensorflow::Device* input_device = op->Inputs()[i]->device();
|
||||
if (op->Device() != input_device &&
|
||||
// If the expected and actual devices are on the same task, don't
|
||||
@ -776,15 +795,14 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
op, op->Device()->name(), i, remote_cpu_device,
|
||||
/* run_metadata= */ nullptr, &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
input = handle;
|
||||
// Unref handle since it has a ref as an input now
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||
|
||||
tensorflow::int64 op_id;
|
||||
int32 output_num;
|
||||
TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));
|
||||
TF_RETURN_IF_ERROR(input->RemoteAddress(op->Device(), &op_id, &output_num));
|
||||
|
||||
auto* remote_op_input = remote_op->add_inputs();
|
||||
remote_op_input->set_op_id(op_id);
|
||||
@ -849,6 +867,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
handle->Ref();
|
||||
}
|
||||
|
||||
// TODO(gjn): If the retval TensorHandle is simply going to be used as a
|
||||
// mirror then there should be no need to call SetRemoteShape
|
||||
// Unable to capture via std::move, so bind instead.
|
||||
auto* node = new eager::RemoteExecuteNode(
|
||||
remote_node_id, std::move(request), eager_client,
|
||||
@ -858,8 +878,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
const eager::EnqueueResponse& response) {
|
||||
for (int i = 0; i < retvals.size(); i++) {
|
||||
if (status.ok()) {
|
||||
retvals[i]->SetRemoteShape(
|
||||
Status s = retvals[i]->SetRemoteShape(
|
||||
response.queue_response(0).shape(i));
|
||||
if (!s.ok()) {
|
||||
retvals[i]->Poison(s);
|
||||
}
|
||||
} else {
|
||||
retvals[i]->Poison(status);
|
||||
}
|
||||
@ -1250,7 +1273,8 @@ string GetUniqueWireID() {
|
||||
} // namespace
|
||||
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
const char* device_name, TensorHandle** result) {
|
||||
const char* device_name, bool mirror,
|
||||
TensorHandle** result) {
|
||||
tensorflow::Device* send_device = h->device();
|
||||
|
||||
if (send_device == nullptr) {
|
||||
@ -1267,7 +1291,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
if (sender_is_local && recver_is_local) {
|
||||
return LocalEagerCopyToDevice(h, ctx, recv_device, result);
|
||||
} else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
|
||||
return EagerRemoteSendTensor(ctx, h, recv_device, result);
|
||||
return EagerRemoteSendTensor(ctx, h, recv_device, mirror, result);
|
||||
} else {
|
||||
string wire_id = GetUniqueWireID();
|
||||
|
||||
|
@ -50,9 +50,14 @@ Status EagerKernelExecute(EagerContext* ctx,
|
||||
GraphCollector* graph_collector,
|
||||
TensorHandle** retvals, int num_retvals);
|
||||
|
||||
// Low-level utility to copy a tensor handle from one device to another.
|
||||
// Low-level utility to copy a tensor handle from one device to another. If
|
||||
// successful, result TensorHandle will be populated. If the caller requests for
|
||||
// the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
|
||||
// original handle and update *result to point to h. Since this is not
|
||||
// guaranteed, callers should always use the value in *result.
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
const char* device_name, TensorHandle** result);
|
||||
const char* device_name, bool mirror,
|
||||
TensorHandle** result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -140,6 +140,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
VLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
@ -160,6 +161,7 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
resource_handle_container_(resource_handle.container()),
|
||||
resource_handle_name_(resource_handle.name()),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
VLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
@ -189,9 +191,19 @@ TensorHandle::TensorHandle(std::unique_ptr<AsyncLocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
VLOG(3) << "Creating Async Local TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
Status TensorHandle::CreateRemoteHandle(
|
||||
std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
|
||||
Device* resource_device, EagerContext* ctx, TensorHandle** h) {
|
||||
*h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
|
||||
const TensorShape& shape,
|
||||
eager::EagerClient* eager_client,
|
||||
@ -217,6 +229,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
VLOG(3) << "Creating Remote TensorHandle: " << this << " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
@ -247,7 +260,10 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
|
||||
remote_context_id_(t->context_id()),
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
tensor_handle_data_(std::move(t)) {}
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
VLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
}
|
||||
#endif
|
||||
|
||||
TensorHandle::TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype)
|
||||
@ -262,6 +278,7 @@ TensorHandle::TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype)
|
||||
ctx_(nullptr),
|
||||
is_remote_(false),
|
||||
symbolic_tensor_(new OutputGraphNode(symbolic_tensor)) {
|
||||
VLOG(3) << "Creating Symbolic TensorHandle: " << this;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
@ -305,10 +322,19 @@ Status TensorHandle::NumElements(int64* num_elements) {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) const {
|
||||
if (!is_remote_) {
|
||||
Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
|
||||
int32* output_num) const {
|
||||
if (d != device_) {
|
||||
mutex_lock l(remote_mirrors_mutex_);
|
||||
auto mirror = remote_mirrors_.find(d);
|
||||
if (mirror != remote_mirrors_.end()) {
|
||||
*op_id = mirror->second->op_id();
|
||||
*output_num = mirror->second->output_num();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return errors::FailedPrecondition(
|
||||
"This TensorHandle refers to a local tensor handle");
|
||||
"Could not find remote mirror for specified device");
|
||||
}
|
||||
|
||||
*op_id = remote_op_id_;
|
||||
@ -316,7 +342,28 @@ Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TensorHandle::SetRemoteShape(const TensorShape& shape) {
|
||||
bool TensorHandle::HasRemoteMirror(Device* d) {
|
||||
mutex_lock l(remote_mirrors_mutex_);
|
||||
auto mirror = remote_mirrors_.find(d);
|
||||
if (mirror != remote_mirrors_.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
Device* d) {
|
||||
mutex_lock l(remote_mirrors_mutex_);
|
||||
auto ret = remote_mirrors_.insert(std::make_pair(d, std::move(t)));
|
||||
if (!ret.second) {
|
||||
return errors::Internal("Attempted to duplicate a remote mirror.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::SetRemoteShape(const TensorShape& shape) {
|
||||
DCHECK(is_remote_) << "SeRemoteShape is only called on remote handles.";
|
||||
DCHECK(!is_ready_notification_.HasBeenNotified())
|
||||
<< "SetRemoteShape is only called on non-ready handles.";
|
||||
@ -330,6 +377,8 @@ void TensorHandle::SetRemoteShape(const TensorShape& shape) {
|
||||
remote_context_id_, ctx_);
|
||||
is_poisoned_ = Status::OK();
|
||||
is_ready_notification_.Notify();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -105,6 +105,10 @@ class TensorHandle : public core::RefCounted {
|
||||
uint64 context_id, DataType dtype, Device* d,
|
||||
Device* resource_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
DataType dtype, Device* d,
|
||||
Device* resource_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
eager::EagerClient* eager_client,
|
||||
uint64 context_id, DataType dtype,
|
||||
@ -115,9 +119,7 @@ class TensorHandle : public core::RefCounted {
|
||||
// Symbolic tensor constructor.
|
||||
TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype);
|
||||
|
||||
~TensorHandle() override {
|
||||
VLOG(3) << "Deleting internal TensorHandle " << this;
|
||||
}
|
||||
~TensorHandle() override { VLOG(3) << "Deleting TensorHandle " << this; }
|
||||
|
||||
Status Tensor(const tensorflow::Tensor** t);
|
||||
|
||||
@ -134,8 +136,13 @@ class TensorHandle : public core::RefCounted {
|
||||
Status NumElements(int64* num_elements);
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool HasRemoteMirror(Device* d);
|
||||
// TODO(gjn): Add Unshaped remote mirrors once EagerRemoteSendTensor supports
|
||||
// async execution and EagerRemoteExecute is mirror-aware.
|
||||
Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
|
||||
|
||||
// Return the op_id and output num if the handle refers to a remote tensor.
|
||||
Status RemoteAddress(int64* op_id, int32* output_num) const;
|
||||
Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
|
||||
|
||||
// Called on an async remote tensor once it's shape has been determined. This
|
||||
// transitions the tensor handle from a non-ready to a ready state by
|
||||
@ -143,7 +150,7 @@ class TensorHandle : public core::RefCounted {
|
||||
// queried.
|
||||
// This method or Poison must be called exactly once for remote tensors that
|
||||
// were created without a known shape.
|
||||
void SetRemoteShape(const TensorShape& shape);
|
||||
Status SetRemoteShape(const TensorShape& shape);
|
||||
#endif
|
||||
|
||||
// Sets the `tensor` for this async non-ready handle making it ready.
|
||||
@ -212,6 +219,10 @@ class TensorHandle : public core::RefCounted {
|
||||
tensorflow::Device* const resource_device_;
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
mutable mutex remote_mirrors_mutex_;
|
||||
std::map<tensorflow::Device*, std::unique_ptr<RemoteTensorHandleData>>
|
||||
remote_mirrors_ GUARDED_BY(remote_mirrors_mutex_);
|
||||
|
||||
// IDs required when this class is representing a remote tensor handle.
|
||||
const int64 remote_op_id_;
|
||||
const int32 remote_output_num_;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
@ -121,9 +122,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
request->async(), device_mgr, false, r, nullptr,
|
||||
worker_session->cluster_flr.get(), std::move(rendezvous_creator),
|
||||
worker_session->remote_device_mgr());
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
|
||||
device_mgr, false, r, nullptr, worker_session->cluster_flr.get(),
|
||||
std::move(rendezvous_creator), worker_session->remote_device_mgr());
|
||||
|
||||
std::vector<DeviceAttributes> device_attributes;
|
||||
device_mgr->ListDeviceAttributes(&device_attributes);
|
||||
@ -303,7 +304,7 @@ Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
|
||||
TensorHandle* copied_handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
|
||||
request->device_name().c_str(),
|
||||
request->device_name().c_str(), false,
|
||||
&copied_handle));
|
||||
tensors.push_back(copied_handle);
|
||||
tensor_handle->Unref();
|
||||
|
@ -77,6 +77,7 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData {
|
||||
int32 output_num() const { return output_num_; }
|
||||
eager::EagerClient* eager_client() const { return eager_client_; }
|
||||
uint64 context_id() const { return context_id_; }
|
||||
EagerContext* ctx() const { return ctx_; }
|
||||
|
||||
// When constructed, UnshapedRemoteTensorHandleData owns the remote
|
||||
// TensorHandle and should delete it by issuing an RPC. Once the remote
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -45,6 +46,7 @@ tensorflow::Status DelegateData::Prepare(
|
||||
eager_context_ = new tensorflow::EagerContext(
|
||||
session_options,
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
|
||||
rendezvous, nullptr);
|
||||
return tensorflow::Status();
|
||||
|
@ -480,6 +480,7 @@ cuda_py_test(
|
||||
":function",
|
||||
":test",
|
||||
":profiler",
|
||||
":remote",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
|
@ -25,6 +25,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
@ -40,6 +41,7 @@ from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import profiler
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -52,6 +54,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
CPU = "/device:CPU:0"
|
||||
GPU = "/device:GPU:0"
|
||||
@ -132,6 +135,23 @@ def make_sequential_keras_model(initializer="ones"):
|
||||
return model
|
||||
|
||||
|
||||
def run_benchmark(func, num_iters, execution_mode=None):
|
||||
ctx = context.context()
|
||||
with context.execution_mode(execution_mode):
|
||||
# call func to maybe warm up the GPU
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.async_wait()
|
||||
start = time.time()
|
||||
for _ in xrange(num_iters):
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.async_wait()
|
||||
end = time.time()
|
||||
|
||||
return end - start
|
||||
|
||||
|
||||
class MicroBenchmarks(test.Benchmark):
|
||||
|
||||
def __init__(self):
|
||||
@ -145,23 +165,12 @@ class MicroBenchmarks(test.Benchmark):
|
||||
self._num_iters_100_by_784 = 30000
|
||||
|
||||
def _run(self, func, num_iters, execution_mode=None):
|
||||
# call func to maybe warm up the GPU
|
||||
ctx = context.context()
|
||||
with context.execution_mode(execution_mode):
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.async_wait()
|
||||
start = time.time()
|
||||
for _ in xrange(num_iters):
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.async_wait()
|
||||
end = time.time()
|
||||
mean_us = (end - start) * 1e6 / num_iters
|
||||
self.report_benchmark(
|
||||
iters=num_iters,
|
||||
wall_time=mean_us,
|
||||
extras={"examples_per_sec": num_iters / (end - start)})
|
||||
total_time = run_benchmark(func, num_iters, execution_mode)
|
||||
mean_us = total_time * 1e6 / num_iters
|
||||
self.report_benchmark(
|
||||
iters=num_iters,
|
||||
wall_time=mean_us,
|
||||
extras={"examples_per_sec": num_iters / total_time})
|
||||
|
||||
def benchmark_create_np_array(self):
|
||||
func = lambda: np.array([3.0])
|
||||
@ -943,5 +952,54 @@ class MicroBenchmarks(test.Benchmark):
|
||||
self._benchmarkFunctionWithResourceInputs(500, 100)
|
||||
|
||||
|
||||
class RemoteWorkerMicroBenchmarks(test.Benchmark):
|
||||
|
||||
def __init__(self):
|
||||
# used for remote benchmarks
|
||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||
self._cached_server = server_lib.Server.create_local_server()
|
||||
self._cached_server_target = self._cached_server.target[len("grpc://"):]
|
||||
|
||||
def _run(self, func, num_iters=10000, execution_mode=None):
|
||||
total_time = run_benchmark(func, num_iters, execution_mode)
|
||||
mean_us = total_time * 1e6 / num_iters
|
||||
self.report_benchmark(
|
||||
iters=num_iters,
|
||||
wall_time=mean_us,
|
||||
extras={"examples_per_sec": num_iters / total_time})
|
||||
|
||||
def benchmark_mirroring_off(self):
|
||||
remote.connect_to_remote_host(self._cached_server_target)
|
||||
|
||||
x = random_ops.random_uniform((2, 2)).cpu()
|
||||
|
||||
@def_function.function
|
||||
def remote_func(m):
|
||||
return math_ops.matmul(m, m)
|
||||
|
||||
def func(m):
|
||||
with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
|
||||
return remote_func(m)
|
||||
|
||||
context.context().mirroring_policy = context.MIRRORING_NONE
|
||||
self._run(lambda: func(x))
|
||||
|
||||
def benchmark_mirroring_on(self):
|
||||
remote.connect_to_remote_host(self._cached_server_target)
|
||||
|
||||
x = random_ops.random_uniform((2, 2)).cpu()
|
||||
|
||||
@def_function.function
|
||||
def remote_func(m):
|
||||
return math_ops.matmul(m, m)
|
||||
|
||||
def func(m):
|
||||
with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
|
||||
return remote_func(m)
|
||||
|
||||
context.context().mirroring_policy = context.MIRRORING_ALL
|
||||
self._run(lambda: func(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -57,9 +57,13 @@ DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
|
||||
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
|
||||
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
|
||||
|
||||
SYNC = 0
|
||||
ASYNC = 1
|
||||
|
||||
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
|
||||
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
|
||||
|
||||
_tf2_gauge = monitoring.BoolGauge("/tensorflow/api/tf2_enable",
|
||||
"Whether tf2.enable() is called.")
|
||||
|
||||
@ -334,6 +338,7 @@ class Context(object):
|
||||
if device_policy is None:
|
||||
device_policy = DEVICE_PLACEMENT_SILENT
|
||||
self._device_policy = device_policy
|
||||
self._mirroring_policy = None
|
||||
if execution_mode not in (None, SYNC, ASYNC):
|
||||
raise ValueError(
|
||||
"execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
|
||||
@ -425,6 +430,9 @@ class Context(object):
|
||||
if self._device_policy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
opts, self._device_policy)
|
||||
if self._mirroring_policy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy(
|
||||
opts, self._mirroring_policy)
|
||||
if self._execution_mode == ASYNC:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
|
||||
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
|
||||
@ -1333,6 +1341,27 @@ class Context(object):
|
||||
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
self._handle, self._device_policy)
|
||||
|
||||
@property
|
||||
def mirroring_policy(self):
|
||||
# Only get the policy from the context if it has already been initialized
|
||||
if self._context_handle is not None:
|
||||
return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle)
|
||||
|
||||
return self._device_policy
|
||||
|
||||
@mirroring_policy.setter
|
||||
def mirroring_policy(self, policy):
|
||||
if policy is None:
|
||||
policy = MIRRORING_NONE
|
||||
|
||||
if self._mirroring_policy != policy:
|
||||
self._mirroring_policy = policy
|
||||
|
||||
# Only set the policy if the context has already been initialized
|
||||
if self._context_handle is not None:
|
||||
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
self._handle, self._mirroring_policy)
|
||||
|
||||
def enable_run_metadata(self):
|
||||
"""Enables tracing of op execution via RunMetadata.
|
||||
|
||||
@ -1623,6 +1652,18 @@ def device_policy(policy):
|
||||
ctx.device_policy = old_policy
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def mirroring_policy(policy):
|
||||
"""Context manager for setting mirroring policy for current thread."""
|
||||
ctx = context()
|
||||
old_policy = ctx.mirroring_policy
|
||||
try:
|
||||
ctx.mirroring_policy = policy
|
||||
yield
|
||||
finally:
|
||||
ctx.mirroring_policy = old_policy
|
||||
|
||||
|
||||
def set_execution_mode(mode):
|
||||
"""Sets execution mode for the current thread."""
|
||||
context().execution_mode = mode
|
||||
|
@ -38,7 +38,9 @@ limitations under the License.
|
||||
%rename("%s") TFE_ContextExportRunMetadata;
|
||||
%rename("%s") TFE_ContextClearCaches;
|
||||
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextGetMirroringPolicy;
|
||||
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
|
||||
%rename("%s") TFE_ContextSetAsyncForThread;
|
||||
%rename("%s") TFE_ContextSetServerDef;
|
||||
%rename("%s") TFE_ContextAsyncWait;
|
||||
@ -86,6 +88,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_NewContextOptions;
|
||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
|
||||
%rename("%s") TFE_ContextOptionsSetAsync;
|
||||
%rename("%s") TFE_DeleteContextOptions;
|
||||
%rename("%s") TFE_Py_TensorShapeSlice;
|
||||
@ -321,6 +324,10 @@ static PyObject* TF_ListPhysicalDevices(TF_Status* status);
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
|
||||
|
||||
%rename("%s") TFE_ContextMirroringPolicy;
|
||||
%rename("%s") TFE_MIRRORING_NONE;
|
||||
%rename("%s") TFE_MIRRORING_ALL;
|
||||
|
||||
%include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
||||
|
Loading…
Reference in New Issue
Block a user