Remove the internal lazy_remote_inputs_copy flag.
PiperOrigin-RevId: 350591934
Change-Id: Ia1023a6dc6d20309248d2eb2e2300a6a55a7c2ac
parent ce9122eb7b
commit 13d37279f1
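For embedders driving TensorFlow through the C API, the visible effect of this change is that the lazy-copy knob disappears: TFE_ContextOptionsSetLazyRemoteInputsCopy is deleted and remote function inputs are always copied lazily. A minimal before/after sketch of context creation (status handling abbreviated; every call shown is either removed by this diff or a pre-existing C API entry point):

    TF_Status* status = TF_NewStatus();
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_ContextOptionsSetAsync(opts, 0);
    // Before this change a caller could toggle the flag:
    //   TFE_ContextOptionsSetLazyRemoteInputsCopy(opts, true);
    // After this change the setter no longer exists and the lazy
    // behavior is unconditional.
    TFE_Context* ctx = TFE_NewContext(opts, status);
    TFE_DeleteContextOptions(opts);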
@@ -142,7 +142,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
       opts->session_options.options,
       static_cast<tensorflow::ContextDevicePlacementPolicy>(
           opts->device_placement_policy),
-      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
+      opts->async, device_mgr.release(),
       /*device_mgr_owned*/ true, r);
 #if !defined(IS_MOBILE_PLATFORM)
   eager_context->SetDistributedManager(
@@ -482,11 +482,6 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
       static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
 }
 
-void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
-                                               bool lazy_copy) {
-  options->lazy_remote_inputs_copy = lazy_copy;
-}
-
 void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
   options->use_tfrt = use_tfrt;
 }
@@ -265,10 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
     TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
 
-// Sets whether to copy the remote inputs of a function lazily.
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
-    TFE_ContextOptions*, bool lazy_copy);
-
 // Sets whether to use TFRT
 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
                                                      bool use_tfrt);
@@ -32,8 +32,6 @@ struct TFE_ContextOptions {
   bool async = false;
   TFE_ContextDevicePlacementPolicy device_placement_policy{
       TFE_DEVICE_PLACEMENT_SILENT};
-  // If true, lazily copy the remote inputs of a function to the target devices.
-  bool lazy_remote_inputs_copy = true;
   // If true, use TFRT backend
   bool use_tfrt = false;
 };
@@ -45,8 +45,7 @@ EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
   return EagerContextPtr(new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /* async= */ false,
-      /* lazy_copy_function_remote_inputs= */ false, device_mgr,
+      /* async= */ false, device_mgr,
      /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
      /* cluster_flr= */ nullptr));
 }
@@ -76,8 +76,7 @@ auto* eager_context_created =
 EagerContext::EagerContext(
     const SessionOptions& opts,
     ContextDevicePlacementPolicy default_device_placement_policy, bool async,
-    const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
-    bool device_mgr_owned, Rendezvous* rendezvous,
+    const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
     DistributedFunctionLibraryRuntime* cluster_flr)
     : ImmediateExecutionContext(kEager),
       opts_(opts),
@@ -95,7 +94,6 @@ EagerContext::EagerContext(
       default_executor_(async),
       log_memory_(LogMemory::IsEnabled()),
       env_(opts.env),
-      lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
       use_send_tensor_rpc_(false),
       pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
           "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
@@ -326,7 +324,7 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
 
 void EagerContext::ResetClusterFLR(
     DistributedFunctionLibraryRuntime* cluster_flr) {
-  cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
+  cluster_flr_.Reset(cluster_flr, /*owned=*/true);
 }
 
 EagerExecutor& EagerContext::Executor() {
@@ -410,10 +408,6 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
   return default_device_placement_policy_;
 }
 
-bool EagerContext::LazyCopyFunctionRemoteInputs() const {
-  return lazy_copy_function_remote_inputs_;
-}
-
 #if !defined(IS_MOBILE_PLATFORM)
 std::vector<string> EagerContext::GetRemoteContexts() {
   tf_shared_lock l(remote_state_mu_);
@@ -96,8 +96,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 
   EagerContext(const SessionOptions& opts,
                ContextDevicePlacementPolicy default_device_placement_policy,
-               bool async, const bool lazy_copy_function_remote_inputs,
-               const DeviceMgr* device_mgr, bool device_mgr_owned,
+               bool async, const DeviceMgr* device_mgr, bool device_mgr_owned,
                Rendezvous* rendezvous,
                DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
 
@@ -190,8 +189,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
   Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                       const NodeDef& ndef, Device** out) const;
 
-  bool LazyCopyFunctionRemoteInputs() const;
-
   bool FindFunctionByName(const string& name) const;
 
   Status FindFunctionOpData(const string& name,
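With the constructor parameter gone, every call site simply drops the lazy_copy_function_remote_inputs argument. A minimal sketch of the new call shape, mirroring the updated CreateTestingEagerContext hunk above (the DeviceMgr* is assumed to be set up by the caller):

    tensorflow::EagerContext* MakeContext(tensorflow::DeviceMgr* device_mgr) {
      return new tensorflow::EagerContext(
          tensorflow::SessionOptions(),
          tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
          /*async=*/false, device_mgr,
          /*device_mgr_owned=*/false, /*rendezvous=*/nullptr,
          /*cluster_flr=*/nullptr);
    }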
@@ -217,7 +217,6 @@ tensorflow::Status CreateRemoteContexts(
     tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
     int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
-    const bool lazy_copy_remote_function_inputs,
     const tensorflow::eager::CreateContextRequest& base_request) {
   int num_remote_workers = remote_workers.size();
   tensorflow::BlockingCounter counter(num_remote_workers);
@@ -269,8 +268,9 @@ tensorflow::Status CreateRemoteContexts(
     }
     request.set_async(async);
     request.set_keep_alive_secs(keep_alive_secs);
-    request.set_lazy_copy_remote_function_inputs(
-        lazy_copy_remote_function_inputs);
+    // TODO(b/134094971): deprecate lazy_copy_remote_function_inputs when server
+    // doesn't try to get the value of lazy_copy_remote_function_inputs.
+    request.set_lazy_copy_remote_function_inputs(true);
 
     eager_client->CreateContextAsync(
         &request, response,
@@ -557,7 +557,7 @@ tensorflow::Status UpdateContextWithServerDef(
   const tensorflow::Status s = CreateRemoteContexts(
       context, remote_workers, context_id, context_view_id, keep_alive_secs,
       server_def, remote_eager_workers.get(), context->Executor().Async(),
-      context->LazyCopyFunctionRemoteInputs(), base_request);
+      base_request);
   // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
   // the CreateRemoteContexts to fail. We currently only log instead of
   // directly returning the error, since returning here will cause the server
@@ -582,8 +582,7 @@ tensorflow::Status UpdateContextWithServerDef(
     sg.Update(CreateRemoteContexts(
         context, added_workers, context_id, context_view_id + 1,
         keep_alive_secs, server_def, remote_eager_workers.get(),
-        context->Executor().Async(), context->LazyCopyFunctionRemoteInputs(),
-        base_request));
+        context->Executor().Async(), base_request));
   }
   if (!existing_workers.empty()) {
     if (VLOG_IS_ON(1)) {
@@ -58,12 +58,11 @@ class EagerContextTest : public ::testing::Test {
                    ContextDevicePlacementPolicy policy) {
     ASSERT_EQ(context_, nullptr);
     InitDeviceManager();
-    context_ = new EagerContext(
-        opts, policy,
-        /* async */ false,
-        /* lazy_copy_function_remote_inputs */ false, device_manager_,
-        /* device_mgr_owned */ false, /* rendezvous */ nullptr,
-        /* cluster_flr */ nullptr);
+    context_ =
+        new EagerContext(opts, policy,
+                         /* async */ false, device_manager_,
+                         /* device_mgr_owned */ false, /* rendezvous */ nullptr,
+                         /* cluster_flr */ nullptr);
   }
 
  protected:
@@ -77,7 +77,7 @@ TEST(CustomDevice, TestTensorHandle) {
   core::RefCountPtr<EagerContext> ctx(new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &device_mgr, false, nullptr, nullptr));
+      &device_mgr, false, nullptr, nullptr));
   std::string device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:15";
   TestCustomDevice device(device_name);
   core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
@@ -48,7 +48,7 @@ TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) {
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &device_mgr, false, nullptr, nullptr);
+      &device_mgr, false, nullptr, nullptr);
   EagerOperation orig_op(ctx);
   std::unique_ptr<tensorflow::EagerOperation> out_op;
   EXPECT_EQ(Status::OK(),
@@ -28,7 +28,7 @@ TEST(EagerOperationTest, DeviceName) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &device_mgr, false, nullptr, nullptr);
+      &device_mgr, false, nullptr, nullptr);
 
   auto op = new EagerOperation(ctx);
 
@@ -197,8 +197,7 @@ Status ValidateInputTypeAndPlacement(
     return errors::InvalidArgument("expected ", kernel->num_inputs(),
                                    " inputs, got ", n_inputs);
   }
-  const bool skip_remote_copy =
-      ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction();
+  const bool is_function = kernel->IsFunction();
   if (n_inputs > 0) {
     const DataType* input_types = &kernel->input_dtypes()[0];
     TensorHandle* const* handles = &op->Inputs()[0];
@@ -229,7 +228,7 @@ Status ValidateInputTypeAndPlacement(
       }
       Device* handle_device = absl::get<Device*>(handle_device_variant);
       const bool maybe_copy =
-          !skip_remote_copy || handle->Type() != TensorHandle::REMOTE;
+          !is_function || handle->Type() != TensorHandle::REMOTE;
       // If the input is already on the right device, then nothing to do.
       if (expected_device != handle_device && maybe_copy) {
         TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
@@ -432,23 +431,8 @@ Status GetOrCreateKernelAndDevice(
     profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
                                profiler::TraceMeLevel::kInfo);
     input_dev_ptrs.reserve(op->Inputs().size());
-    // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on
-    // local devices, since we execute a remote function through worker service,
-    // which doesn't accept remote inputs.
    for (int i = 0, end = op->Inputs().size(); i < end; i++) {
      TensorHandle* input = op->Inputs()[i];
-      if (!ctx.LazyCopyFunctionRemoteInputs() &&
-          input->Type() == TensorHandle::REMOTE) {
-        TensorHandle* handle = nullptr;
-        TF_RETURN_IF_ERROR(
-            EagerCopyToDevice(input, &ctx, &op->Executor(),
-                              device == nullptr ? ctx.HostCPU() : device,
-                              /*mirror=*/true, &handle));
-        op->UpdateInput(i, handle);
-        // Unref handle since it has a ref as an input now
-        handle->Unref();
-        input = handle;
-      }
 
       // Get device for this input, and add it to 'cache_key'.
       Device* input_device;
@@ -549,9 +533,7 @@ Status GetOrCreateKernelAndDevice(
             << "Full node_def=" << ndef.DebugString();
     std::function<int64()> get_op_id = nullptr;
 #if !defined(IS_MOBILE_PLATFORM)
-    if (ctx.LazyCopyFunctionRemoteInputs()) {
-      get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
-    }
+    get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
 #endif  // IS_MOBILE_PLATFORM
     kernel.reset(new KernelAndDeviceFunc(
         flr, ctx.pflr(), std::move(input_dev_ptrs),
@@ -569,9 +551,8 @@ Status GetOrCreateKernelAndDevice(
         ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
   }
 
-  TF_RETURN_IF_ERROR(kernel->Init(
-      {ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef,
-      graph_collector));
+  TF_RETURN_IF_ERROR(
+      kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
 
   if (op->is_function()) {
     ctx.AddKernelToCache(cache_key, kernel.get());
@@ -873,8 +854,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   {
     profiler::TraceMe activity("CopyInputToExpectedDevice",
                                profiler::TraceMeLevel::kInfo);
-    const bool eagerly_copy_function_remote_inputs =
-        !ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
+    const bool is_function = op->is_function();
    for (int i = 0, end = op->Inputs().size(); i < end; i++) {
      tensorflow::TensorHandle* input = op->Inputs()[i];
      tensorflow::Device* input_device = absl::get<Device*>(input->device());
@@ -887,8 +867,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
           // explicitly copy, and instead depend on the copy to happen locally
           // when the op is executed on the device.
           !ctx.OnSameTask(op_device, input_device)) {
-        if (eagerly_copy_function_remote_inputs ||
-            input_device_or_cpu->IsLocal()) {
+        if (!is_function || input_device_or_cpu->IsLocal()) {
           tensorflow::Device* remote_cpu_device;
           TF_RETURN_IF_ERROR(
               ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
@@ -967,19 +946,14 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
         id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
   }
 
-  if (ctx.LazyCopyFunctionRemoteInputs()) {
-    // Store the data type and shape of a remote resource variable on the
-    // corresponding remote TensorHandle (output of 'VarHandleOp').
-    // If the variable is an input of a remote function, the function may need
-    // the type and shape during function instantiation. When
-    // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
-    // handle (contains the type and shape) of the variable to the default
-    // function device. Instead, we store the type and shape on eager master
-    // and sent them to the default function device along with the
-    // EnqueueRequest.
-    TF_RETURN_IF_ERROR(
-        StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
-  }
+  // Store the data type and shape of a remote resource variable on the
+  // corresponding remote TensorHandle (output of 'VarHandleOp').
+  // If the variable is an input of a remote function, the function may need
+  // the type and shape during function instantiation. Store the type and
+  // shape on eager master and sent them to the default function device along
+  // with the EnqueueRequest.
+  TF_RETURN_IF_ERROR(
+      StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
 
   auto& executor = op->Executor();
   DVLOG(4) << "Execute remote eager op: " << op->Name()
@@ -68,7 +68,7 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &device_mgr, false, nullptr, nullptr);
+      &device_mgr, false, nullptr, nullptr);
 
   // Set a RemoteMgr to the EagerContext.
   auto remote_mgr = absl::make_unique<eager::RemoteMgr>(
@@ -97,7 +97,8 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() {
   }
 }
 
-Status KernelAndDeviceOp::Init(const Context& ctx, const NodeDef& ndef,
+Status KernelAndDeviceOp::Init(const bool log_device_placement,
+                               const NodeDef& ndef,
                                GraphCollector* graph_collector) {
   OpKernel* k = nullptr;
   if (flr_ == nullptr) {
@@ -129,7 +130,7 @@ Status KernelAndDeviceOp::Init(const Context& ctx, const NodeDef& ndef,
   return Status::OK();
 }
 
-Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
+Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement,
                                             const NodeDef& ndef,
                                             GraphCollector* graph_collector) {
   const OpDef* op_def = nullptr;
@@ -212,18 +213,19 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
       ->mutable_optimizer_options()
       ->set_do_function_inlining(true);
 
-  options.config_proto.set_log_device_placement(ctx.log_device_placement);
+  options.config_proto.set_log_device_placement(log_device_placement);
 
   TF_RETURN_IF_ERROR(
       pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
   return pflr_->IsCrossProcess(handle_, &is_cross_process_);
 }
 
-Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
+Status KernelAndDeviceFunc::Init(const bool log_device_placement,
+                                 const NodeDef& ndef,
                                  GraphCollector* graph_collector) {
-  TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
-  return pflr_->GetOutputDevices(handle_, &output_devices_,
-                                 ctx.eager_lazy_copy);
+  TF_RETURN_IF_ERROR(
+      InstantiateFunc(log_device_placement, ndef, graph_collector));
+  return pflr_->GetOutputDevices(handle_, &output_devices_);
 }
 
 namespace {
@@ -97,16 +97,11 @@ typedef absl::variant<Tensor, TensorShape> EagerKernelRet;
 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
 class KernelAndDevice : public core::RefCounted {
  public:
-  struct Context {
-    bool log_device_placement = false;
-    bool eager_lazy_copy = false;
-  };
-
   // Populates this with a kernel appropriate for 'ndef'.
   //
   // The provided FunctionLibraryRuntime MUST outlive all calls to
   // Run() on the returned KernelAndDevice.
-  virtual Status Init(const Context& ctx, const NodeDef& ndef,
+  virtual Status Init(const bool log_device_placement, const NodeDef& ndef,
                       GraphCollector* graph_collector) = 0;
 
   // Non-multi-device functions are run using regular CallOp and look like
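Callers of KernelAndDevice::Init now pass the log_device_placement bit directly instead of building the small Context aggregate. A sketch of the call-site migration, modeled on the GetOrCreateKernelAndDevice and CreateUncachedKernelAndDeviceOp hunks in this diff (kernel, ctx, and ndef are assumed to be set up as in those functions):

    // Before: Init took a KernelAndDevice::Context aggregate.
    //   kernel->Init({ctx.LogDevicePlacement(),
    //                 ctx.LazyCopyFunctionRemoteInputs()},
    //                ndef, graph_collector);
    // After: the lazy-copy bit is gone and the remaining flag is a plain bool.
    TF_RETURN_IF_ERROR(
        kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));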
@@ -205,7 +200,7 @@ class KernelAndDeviceOp final : public KernelAndDevice {
 
   ~KernelAndDeviceOp() override {}
 
-  Status Init(const Context& ctx, const NodeDef& ndef,
+  Status Init(const bool log_device_placement, const NodeDef& ndef,
               GraphCollector* graph_collector) override;
 
   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
@@ -290,10 +285,10 @@ class KernelAndDeviceFunc : public KernelAndDevice {
 
   bool IsCrossProcess() override { return is_cross_process_; }
 
-  Status InstantiateFunc(const Context& ctx, const NodeDef& ndef,
+  Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef,
                          GraphCollector* graph_collector);
 
-  Status Init(const Context& ctx, const NodeDef& ndef,
+  Status Init(const bool log_device_placement, const NodeDef& ndef,
               GraphCollector* graph_collector) override;
 
   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
@@ -38,13 +38,12 @@ class EagerOpRewriteTest : public ::testing::Test {
         absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
             "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
     bool async = false;
-    bool lazy_remote_tensor_copy = false;
     tensorflow::Rendezvous* rendezvous =
         new tensorflow::IntraProcessRendezvous(device_mgr.get());
     eager_ctx_ = new tensorflow::EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        async, lazy_remote_tensor_copy, device_mgr.get(), false, rendezvous);
+        async, device_mgr.get(), false, rendezvous);
 
     EagerExecutor executor_(false);
     std::unique_ptr<tensorflow::EagerOperation> op(
@@ -83,12 +83,11 @@ class PlacementTest : public ::testing::Test {
                    ContextDevicePlacementPolicy policy) {
     ASSERT_EQ(context_, nullptr);
     InitDeviceManager();
-    context_ = new EagerContext(
-        opts, policy,
-        /* async */ false,
-        /* lazy_copy_function_remote_inputs */ false, device_manager_,
-        /* device_mgr_owned */ false, /* rendezvous */ nullptr,
-        /* cluster_flr */ nullptr);
+    context_ =
+        new EagerContext(opts, policy,
+                         /* async */ false, device_manager_,
+                         /* device_mgr_owned */ false, /* rendezvous */ nullptr,
+                         /* cluster_flr */ nullptr);
   }
 
  protected:
@@ -39,7 +39,7 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &device_mgr, false, nullptr, nullptr);
+      &device_mgr, false, nullptr, nullptr);
   TensorHandle* sync_th =
       TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, ctx);
   TensorHandle* async_th = TensorHandle::CreateEmptyLocalHandle(
@@ -105,8 +105,7 @@ class PackedTensorHandleTest : public ::testing::Test {
     context_ = new EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        /* async= */ false,
-        /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
+        /* async= */ false, device_mgr_,
        /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
        /* cluster_flr= */ nullptr);
   }
@@ -256,7 +255,7 @@ TEST(TensorHandle_ResourceDeviceTest, OnLocalDevice) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &local_device_mgr, false, nullptr, nullptr);
+      &local_device_mgr, false, nullptr, nullptr);
 
   tensorflow::DataType dtype = DT_RESOURCE;
   TensorShape shape = {2};
@@ -288,7 +287,7 @@ TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &local_device_mgr, false, nullptr, nullptr);
+      &local_device_mgr, false, nullptr, nullptr);
 
   std::unique_ptr<Device> d0(
       CreateDevice("CPU", "/job:worker/task:0/device:CPU:0", false));
@@ -342,8 +341,7 @@ class RemoteTensorHandleTest : public ::testing::Test {
     context_ = new EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        /* async= */ false,
-        /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
+        /* async= */ false, device_mgr_,
        /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
        /* cluster_flr= */ nullptr);
   }
@@ -382,8 +380,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
   EagerContext* context = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /* async= */ false,
-      /* lazy_copy_function_remote_inputs= */ false, &device_mgr,
+      /* async= */ false, &device_mgr,
      /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
      /* cluster_flr= */ nullptr);
 
@@ -418,7 +415,7 @@ TEST(TensorHandle_DeviceNameTest, OnLocalDevice) {
   auto ctx = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
-      false, &local_device_mgr, false, nullptr, nullptr);
+      &local_device_mgr, false, nullptr, nullptr);
 
   Device* dcpu = local_device_mgr.ListDevices()[0];
   Device* dgpu = local_device_mgr.ListDevices()[1];
@@ -995,8 +995,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 }
 
 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
-    FunctionLibraryRuntime::Handle handle, std::vector<Device*>* output_devices,
-    const bool eager_lazy_copy) const {
+    FunctionLibraryRuntime::Handle handle,
+    std::vector<Device*>* output_devices) const {
   MultiDeviceFunctionData* data = IsMultiDevice(handle);
   if (data == nullptr) {
     return errors::InvalidArgument(
@@ -1015,16 +1015,6 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     Device* target_device = nullptr;
     Device* host = nullptr;
     if (target_flr == nullptr) {
-      if (!eager_lazy_copy) {
-        return errors::Unimplemented(
-            "Currently, outputting tensors on remote devices is not supported."
-            "The ",
-            comp_data.ret_indices[0],
-            "-th return value of the function outputs to target_device: ",
-            target,
-            " Please copy the tensor to local device explicitly using "
-            "tf.identity and return the new Tensor instead.");
-      }
       if (!data->has_remote_outputs) {
         data->has_remote_outputs = true;
       }
@@ -150,8 +150,7 @@ class ProcessFunctionLibraryRuntime {
   // is set to the device backing the resource.
   // REQUIRES: `handle` identifies a multi-device function.
   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
-                          std::vector<Device*>* output_devices,
-                          const bool eager_lazy_copy) const;
+                          std::vector<Device*>* output_devices) const;
 
   // Returns true if function with handle `handle` was instantiated on device
   // `device_name`. Returns false for multi-device functions.
@@ -284,12 +284,8 @@ void EagerClusterFunctionLibraryRuntime::CleanUp(
 
 DistributedFunctionLibraryRuntime* CreateClusterFLR(
     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
-  if (ctx->LazyCopyFunctionRemoteInputs()) {
-    return new EagerClusterFunctionLibraryRuntime(
-        context_id, ctx, worker_session->remote_device_mgr());
-  } else {
-    return worker_session->cluster_flr();
-  }
+  return new EagerClusterFunctionLibraryRuntime(
+      context_id, ctx, worker_session->remote_device_mgr());
 }
 
 }  // namespace eager
@@ -274,8 +274,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   opts.config = request->server_def().default_session_config();
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      request->async(), request->lazy_copy_remote_function_inputs(), device_mgr,
-      false, r, worker_session->cluster_flr());
+      request->async(), device_mgr, false, r, worker_session->cluster_flr());
   // Ownership will be transferred to the ServerContext, or else in an error
   // case ctx will be deleted by this unref.
   core::ScopedUnref unref_ctx(ctx);
@@ -1220,9 +1220,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /*async=*/false,
-      /*lazy_copy_function_remote_inputs=*/false, device_mgr_.get(), false,
-      rendezvous);
+      /*async=*/false, device_mgr_.get(), false, rendezvous);
   const uint64 context_id = random::New64();
 
   // Set RemoteMgr to ctx.
@@ -58,7 +58,7 @@ Status CreateUncachedKernelAndDeviceOp(
       ctx.HostCPU()));
 
   const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
-  return kernel->get()->Init({ctx.LogDevicePlacement()}, ndef,
+  return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
                              /*graph_collector=*/nullptr);
 }
 
@@ -54,9 +54,7 @@ class RemoteMgrTest : public ::testing::Test {
    ctx_ = new tensorflow::EagerContext(
        SessionOptions(),
        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        /*async=*/false,
-        /*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), true,
-        rendezvous, nullptr);
+        /*async=*/false, device_mgr.release(), true, rendezvous, nullptr);
   }
 
   ~RemoteMgrTest() override { ctx_->Unref(); }
@@ -46,8 +46,8 @@ tensorflow::Status DelegateData::Prepare(
   eager_context_ = new tensorflow::EagerContext(
       session_options,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /*async=*/false, /*lazy_copy_function_remote_inputs=*/false,
-      device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr);
+      /*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
+      rendezvous, nullptr);
   return tensorflow::Status();
 }
 
@@ -419,7 +419,6 @@ class Context(object):
    if execution_mode is None:
      execution_mode = SYNC
    self._default_is_async = execution_mode == ASYNC
-    self._lazy_remote_inputs_copy = None
    self._use_tfrt = is_tfrt_enabled()
    self._server_def = server_def
    self._collective_ops_server_def = None
@@ -521,9 +520,6 @@ class Context(object):
            opts, self._mirroring_policy)
      if self._default_is_async == ASYNC:
        pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
-      if self._lazy_remote_inputs_copy is not None:
-        pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
-            opts, self._lazy_remote_inputs_copy)
      if self._use_tfrt is not None:
        pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
      context_handle = pywrap_tfe.TFE_NewContext(opts)
@@ -1177,10 +1173,6 @@ class Context(object):
      A packed EagerTensor.
    """
    self.ensure_initialized()
-    if self._lazy_remote_inputs_copy is not None and (
-        not self._lazy_remote_inputs_copy):
-      raise ValueError("Packing eager tensors is not supported when "
-                       "lazy_remote_inputs_copy is disabled.")
    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)

  def remove_function(self, name):
@@ -1669,22 +1661,6 @@ class Context(object):
      pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
          self._handle, self._device_policy)

-  @property
-  def lazy_remote_inputs_copy(self):
-    return self._lazy_remote_inputs_copy
-
-  @lazy_remote_inputs_copy.setter
-  def lazy_remote_inputs_copy(self, lazy_copy):
-    """Sets whether to copy remote inputs lazily for functions."""
-    if not isinstance(lazy_copy, bool):
-      raise ValueError("Expecting a boolean but got %s" % type(lazy_copy))
-
-    if self._lazy_remote_inputs_copy != lazy_copy:
-      if self._initialized:
-        raise ValueError(
-            "lazy_remote_inputs_copy should be set before being initialized.")
-      self._lazy_remote_inputs_copy = lazy_copy
-
  @property
  def use_tfrt(self):
    return self._use_tfrt
@@ -233,21 +233,6 @@ class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
        "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)


-class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest):
-
-  @classmethod
-  def setUpClass(cls):
-    super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).setUpClass()
-    context._reset_context()
-    context.context().lazy_remote_inputs_copy = False
-
-  @classmethod
-  def tearDownClass(cls):
-    super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).tearDownClass()
-    context._reset_context()
-    context.context().lazy_remote_inputs_copy = True
-
-
 if __name__ == "__main__":
   ops.enable_eager_execution()
   test.main()
@@ -66,7 +66,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionBasic(self):

    @def_function.function
@@ -81,7 +80,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)
@@ -148,7 +146,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):

    self.assertIn('Dimensions must be equal', cm.exception.message)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testShapeError_Function(self):

    @def_function.function
@@ -179,7 +176,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testRemoteCall(self):

    @def_function.function(
@@ -306,7 +302,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testReturnRemoteArgument(self):

    @def_function.function
@@ -376,7 +371,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    else:
      os.environ[remote_async_env_var] = default_streaming

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)
@@ -444,7 +438,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    # Run the function on a local worker
    self.assertAllEqual(add_variables().numpy(), 3.0)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])
@@ -480,7 +473,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)
@@ -518,7 +510,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)
@@ -540,7 +531,6 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
    with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):

    with ops.device('/job:worker/task:2/device:CPU:0'):
@@ -585,7 +575,6 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

@@ -606,7 +595,6 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
      self.assertAllEqual(worker_fn(), 8)

  # TODO(b/152224115): Re-enable this test.
-  @test_util.eager_lazy_remote_copy_on_and_off
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
@@ -653,7 +641,6 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
    # subsequent tests.
    del v1, v2

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

@@ -672,12 +659,10 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

@@ -687,12 +672,10 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

-  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
@@ -1107,21 +1107,6 @@ def run_in_async_and_sync_mode(f):
  return decorator


-def eager_lazy_remote_copy_on_and_off(f):
-  """Execute the test method w/o lazy tensor copy for function remote inputs."""
-
-  @parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)])
-  @functools.wraps(f)
-  def decorator(self, lazily_remote_copy, *args, **kwargs):
-    if lazily_remote_copy:
-      context.context().lazy_remote_inputs_copy = True
-    else:
-      context.context().lazy_remote_inputs_copy = False
-    f(self, *args, **kwargs)
-
-  return decorator
-
-
 def run_in_graph_and_eager_modes(func=None,
                                  config=None,
                                  use_gpu=True,
@@ -1006,8 +1006,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
  });
  m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
        &TFE_ContextOptionsSetDevicePlacementPolicy);
-  m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy",
-        &TFE_ContextOptionsSetLazyRemoteInputsCopy);
  m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
  m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
  m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,