Add lazy_remote_inputs_copy to TFE_ContextOptions to control lazy remote tensor copy. Disable it by default.
PiperOrigin-RevId: 279212487
Change-Id: Ie46de71fd2902b79281e6257ff28c06d9aaa73d4
commit 206d6af149
parent 40164d7a26
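For orientation, here is a minimal sketch of how the new option is intended to be used from Python once this change is in: the flag has to be set on the eager context before the context is initialized, e.g. before connecting to a cluster. This is illustrative only; the cluster resolver and the `/job:worker/task:0` device string are placeholders, and the module paths mirror the test changes further down in this commit.

```python
# Illustrative sketch; mirrors the patterns used in remote_test.py below.
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

# Must be set before the context is initialized; afterwards the setter added
# in this change raises ValueError.
context.context().lazy_remote_inputs_copy = True

remote.connect_to_cluster(cluster_resolver)  # placeholder cluster resolver

@def_function.function
def add_one(x):
  return x + 1

with ops.device('/job:worker/task:0/device:CPU:0'):
  # Remote inputs of the function are copied lazily to their target devices
  # instead of eagerly to the default function device.
  print(add_one(constant_op.constant([1.0])))
```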
@@ -264,6 +264,7 @@ tensorflow::Status CreateRemoteContexts(
     tensorflow::uint64 context_view_id, int keep_alive_secs,
     const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+    const bool lazy_copy_remote_function_inputs,
     const tensorflow::eager::CreateContextRequest& base_request) {
   int num_remote_workers = remote_workers.size();
   tensorflow::BlockingCounter counter(num_remote_workers);
@@ -300,6 +301,8 @@ tensorflow::Status CreateRemoteContexts(
     request.mutable_server_def()->set_task_index(parsed_name.task);
     request.set_async(async);
     request.set_keep_alive_secs(keep_alive_secs);
+    request.set_lazy_copy_remote_function_inputs(
+        lazy_copy_remote_function_inputs);
 
     eager_client->CreateContextAsync(
         &request, response,
@@ -319,7 +322,7 @@ tensorflow::Status CreateRemoteContexts(
 tensorflow::Status UpdateRemoteContexts(
     const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
-    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+    tensorflow::eager::EagerClientCache* remote_eager_workers,
     const tensorflow::eager::CreateContextRequest& base_request) {
   int num_remote_workers = remote_workers.size();
   tensorflow::BlockingCounter counter(num_remote_workers);
@@ -527,7 +530,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         remote_workers, context_id, context_view_id, keep_alive_secs,
         server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(), base_request));
+        ctx->context->Executor().Async(),
+        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
   } else {
     // The master's context_view_id will be incremented by one
     // in the UpdateRemoteMaster call later. We want all new workers and
@@ -537,7 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         added_workers, context_id, context_view_id + 1, keep_alive_secs,
         server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(), base_request));
+        ctx->context->Executor().Async(),
+        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
     if (!existing_workers.empty()) {
       if (VLOG_IS_ON(1)) {
         for (const string& w : existing_workers) {
@@ -546,8 +551,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       }
       LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
           existing_workers, context_id, context_view_id + 1, server_def,
-          remote_eager_workers.get(), ctx->context->Executor().Async(),
-          base_request));
+          remote_eager_workers.get(), base_request));
     }
   }
 
@@ -713,7 +717,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
 
   return new TFE_Context(opts->session_options.options,
                          opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, device_mgr.release(),
+                         opts->async, opts->lazy_remote_inputs_copy,
+                         device_mgr.release(),
                          /*device_mgr_owned*/ true, r,
                          tensorflow::GetDefaultCustomKernelCreator());
 }
@@ -728,7 +733,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
 
   return new TFE_Context(opts->session_options.options,
                          opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, device_mgr, /*device_mgr_owned*/ false, r,
+                         opts->async, opts->lazy_remote_inputs_copy, device_mgr,
+                         /*device_mgr_owned*/ false, r,
                          tensorflow::GetDefaultCustomKernelCreator());
 }
@@ -557,6 +557,11 @@ extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
       ctx->context->GetMirroringPolicy());
 }
 
+void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
+                                               bool lazy_copy) {
+  options->lazy_remote_inputs_copy = lazy_copy;
+}
+
 TFE_CancellationManager* TFE_NewCancellationManager() {
   return new TFE_CancellationManager;
 }
@@ -336,6 +336,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
 TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
     TFE_Context*);
 
+// Sets whether to copy the remote inputs of a function lazily.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
+    TFE_ContextOptions*, bool lazy_copy);
+
 // -----------------------------------------------------------------------------
 // Cancellation APIs.
 
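To illustrate the C API surface declared above, here is a hedged sketch of the pywrap-level call sequence that the context.py change later in this commit performs; the wrapper names follow the SWIG renames at the bottom of the diff, but treat this as an internal, unsupported path rather than a public API.

```python
# Sketch only: assumes the pywrap_tensorflow wrappers exposed via tensorflow.i;
# not a supported public API.
from tensorflow.python import pywrap_tensorflow

opts = pywrap_tensorflow.TFE_NewContextOptions()
try:
  # New call added by this commit; the option is off by default.
  pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(opts, True)
  ctx_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
  pywrap_tensorflow.TFE_DeleteContextOptions(opts)
```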
@@ -57,12 +57,15 @@ struct TFE_ContextOptions {
   TFE_ContextDevicePlacementPolicy device_placement_policy{
       TFE_DEVICE_PLACEMENT_SILENT};
   TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
+  // If true, lazily copy the remote inputs of a function to the target devices.
+  bool lazy_remote_inputs_copy = false;
 };
 
 struct TFE_Context {
   TFE_Context(const tensorflow::SessionOptions& opts,
               TFE_ContextDevicePlacementPolicy default_device_placement_policy,
               TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
+              const bool lazy_remote_inputs_copy,
               const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
               tensorflow::Rendezvous* rendezvous,
               const tensorflow::CustomKernelCreator* custom_kernel_creator)
@@ -72,8 +75,8 @@ struct TFE_Context {
                 default_device_placement_policy),
             static_cast<tensorflow::ContextMirroringPolicy>(
                 default_mirroring_policy),
-            async, device_mgr, device_mgr_owned, rendezvous,
-            custom_kernel_creator)) {}
+            async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
+            rendezvous, custom_kernel_creator)) {}
 
   ~TFE_Context() {
     // TODO(iga): Add a separate API method to shutdown TFE_Context so that we
@@ -68,13 +68,12 @@ auto* eager_context_created =
 
 }  // namespace
 
-// TODO(b/134094971): Make lazily_copy_function_remote_inputs_ configurable once
-// it's ready to enable.
 EagerContext::EagerContext(
     const SessionOptions& opts,
     ContextDevicePlacementPolicy default_device_placement_policy,
     ContextMirroringPolicy default_mirroring_policy, bool async,
-    const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
+    const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
+    bool device_mgr_owned, Rendezvous* rendezvous,
     const CustomKernelCreator* custom_kernel_creator,
     DistributedFunctionLibraryRuntime* cluster_flr)
     : default_device_placement_policy_(default_device_placement_policy),
@@ -91,7 +90,7 @@ EagerContext::EagerContext(
       default_executor_(async),
       log_memory_(LogMemory::IsEnabled()),
       env_(opts.env),
-      lazily_copy_function_remote_inputs_(false),
+      lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
       use_send_tensor_rpc_(false),
       pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
           "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
@@ -130,7 +129,7 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                              thread::ThreadPool* thread_pool,
                              DistributedFunctionLibraryRuntime* cluster_flr,
                              const CustomKernelCreator* custom_kernel_creator) {
-  if (lazily_copy_function_remote_inputs_) {
+  if (lazy_copy_function_remote_inputs_) {
     pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
         device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
         thread_pool, cluster_flr, custom_kernel_creator));
@@ -164,7 +163,7 @@ void EagerContext::InitDeviceMapAndAsync() {
 
 void EagerContext::ResetClusterFLR(
     DistributedFunctionLibraryRuntime* cluster_flr) {
-  cluster_flr_.Reset(cluster_flr, lazily_copy_function_remote_inputs_);
+  cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
 }
 
 EagerExecutor& EagerContext::Executor() {
@@ -239,8 +238,8 @@ bool EagerContext::MirrorTensors() const {
   return GetMirroringPolicy() == MIRRORING_ALL;
 }
 
-bool EagerContext::LazilyCopyFunctionRemoteInputs() const {
-  return lazily_copy_function_remote_inputs_;
+bool EagerContext::LazyCopyFunctionRemoteInputs() const {
+  return lazy_copy_function_remote_inputs_;
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -121,6 +121,7 @@ class EagerContext : public core::RefCounted {
   EagerContext(const SessionOptions& opts,
                ContextDevicePlacementPolicy default_device_placement_policy,
                ContextMirroringPolicy default_mirroring_policy, bool async,
+               const bool lazy_copy_function_remote_inputs,
                const DeviceMgr* device_mgr, bool device_mgr_owned,
                Rendezvous* rendezvous,
                const CustomKernelCreator* custom_kernel_creator,
@@ -168,7 +169,7 @@ class EagerContext : public core::RefCounted {
 
   bool MirrorTensors() const;
 
-  bool LazilyCopyFunctionRemoteInputs() const;
+  bool LazyCopyFunctionRemoteInputs() const;
 
   bool FindFunctionByName(const string& name);
 
@@ -461,7 +462,7 @@ class EagerContext : public core::RefCounted {
 
   // EagerContext owns the DistributedFunctionLibraryRuntime(
   // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
-  // function execution (lazily_copy_function_remote_inputs_=true).
+  // function execution (lazy_copy_function_remote_inputs_=true).
   OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
   // One FunctionLibraryRuntime per device.
   // func_libs[i] is the FunctionLibraryRuntime corresponding to
@@ -553,7 +554,12 @@ class EagerContext : public core::RefCounted {
   bool is_master_ GUARDED_BY(remote_state_mu_);
 #endif  // IS_MOBILE_PLATFORM
 
-  bool lazily_copy_function_remote_inputs_;
+  // For a multi device function, the target device of each input is unknown
+  // until the function is instantiated on the default function device.
+  // If false, eagerly copy all remote inputs to the default function device;
+  // if true, lazily copy remote inputs to their target devices to avoid
+  // redundant copies.
+  bool lazy_copy_function_remote_inputs_ = false;
   bool use_send_tensor_rpc_;
   const bool pin_small_ops_to_cpu_;
 
@@ -212,7 +212,7 @@ Status ValidateInputTypeAndPlacement(
                                " inputs, got ", n_inputs);
   }
   const bool skip_remote_copy =
-      ctx->LazilyCopyFunctionRemoteInputs() && kernel->IsFunction();
+      ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction();
   for (int i = 0; i < n_inputs; ++i) {
     TensorHandle* handle = op->Inputs()[i];
     Device* expected_device = kernel->InputDevice(i);
@@ -499,14 +499,12 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
                                profiler::TraceMeLevel::kInfo);
     input_dev_ptrs.reserve(op->Inputs().size());
-    // When LazilyCopyFunctionRemoteInputs is disabled, all inputs need to be on
+    // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on
     // local devices, since we execute a remote function through worker service,
     // which doesn't accept remote inputs.
-    // TODO(b/134094971): Make resource_dtypes_and_shapes avaliable without
-    // remote tensor copy.
     for (int i = 0; i < op->Inputs().size(); i++) {
       TensorHandle* input = op->Inputs()[i];
-      if (!ctx->LazilyCopyFunctionRemoteInputs() && input->IsRemote()) {
+      if (!ctx->LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
         TensorHandle* handle = nullptr;
         TF_RETURN_IF_ERROR(EagerCopyToDevice(
             input, ctx, &executor, device == nullptr ? ctx->HostCPU() : device,
@@ -603,7 +601,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
           << ". Full node_def=" << ndef.DebugString();
   std::function<int64()> get_op_id = nullptr;
 #if !defined(IS_MOBILE_PLATFORM)
-  if (ctx->LazilyCopyFunctionRemoteInputs()) {
+  if (ctx->LazyCopyFunctionRemoteInputs()) {
     get_op_id = [ctx]() { return ctx->RemoteMgr()->NextOpId(); };
   }
 #endif  // IS_MOBILE_PLATFORM
@@ -750,7 +748,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
     profiler::TraceMe activity("CopyInputToExpectedDevice",
                                profiler::TraceMeLevel::kInfo);
     const bool eagerly_copy_function_remote_inputs =
-        !ctx->LazilyCopyFunctionRemoteInputs() || !op->is_function();
+        !ctx->LazyCopyFunctionRemoteInputs() || !op->is_function();
     for (int i = 0; i < op->Inputs().size(); i++) {
       tensorflow::TensorHandle* input = op->Inputs()[i];
       tensorflow::Device* input_device = input->device();
@@ -834,12 +832,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
     }
   }
 
-  if (ctx->LazilyCopyFunctionRemoteInputs()) {
+  if (ctx->LazyCopyFunctionRemoteInputs()) {
     // Store the data type and shape of a remote resource variable on the
     // corresponding remote TensorHandle (output of 'VarHandleOp').
     // If the variable is an input of a remote function, the function may need
     // the type and shape during function instantiation. When
-    // LazilyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
+    // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
     // handle (contains the type and shape) of the variable to the default
     // function device. Instead, we store the type and shape on eager master
     // and send them to the default function device along with the
@@ -216,7 +216,7 @@ void EagerClusterFunctionLibraryRuntime::CleanUp(
 
 DistributedFunctionLibraryRuntime* CreateClusterFLR(
     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
-  if (ctx->LazilyCopyFunctionRemoteInputs()) {
+  if (ctx->LazyCopyFunctionRemoteInputs()) {
     return new EagerClusterFunctionLibraryRuntime(
         context_id, ctx, worker_session->remote_device_mgr());
   } else {
@@ -162,8 +162,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
-      device_mgr, false, r, GetDefaultCustomKernelCreator(),
-      worker_session->cluster_flr());
+      request->lazy_copy_remote_function_inputs(), device_mgr, false, r,
+      GetDefaultCustomKernelCreator(), worker_session->cluster_flr());
   // Ownership will be transferred to the ServerContext, or else in an error
   // case ctx will be deleted by this unref.
   core::ScopedUnref unref_ctx(ctx);
@@ -710,8 +710,9 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
-      device_mgr_.get(), false, rendezvous, GetDefaultCustomKernelCreator());
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
+      /*lazy_copy_function_remote_inputs=*/false, device_mgr_.get(), false,
+      rendezvous, GetDefaultCustomKernelCreator());
   const uint64 context_id = random::New64();
 
   // Set RemoteMgr to ctx.
@@ -55,9 +55,9 @@ class RemoteMgrTest : public ::testing::Test {
     ctx_ = new tensorflow::EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
-        device_mgr.release(), true, rendezvous, GetDefaultCustomKernelCreator(),
-        nullptr);
+        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
+        /*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), true,
+        rendezvous, GetDefaultCustomKernelCreator(), nullptr);
   }
 
   ~RemoteMgrTest() override { ctx_->Unref(); }
@@ -90,6 +90,11 @@ message CreateContextRequest {
   // The view ID of the context.
   fixed64 context_view_id = 8;
 
+  // For a multi device function, if false, eagerly copy all remote inputs to
+  // the default function device; if true, lazily copy remote inputs to their
+  // target devices after function instantiation to avoid redundant copies.
+  bool lazy_copy_remote_function_inputs = 9;
+
   reserved 5;
 }
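For concreteness, a hedged sketch of filling the new request field from Python; the generated module path is an assumption based on the proto living under tensorflow/core/protobuf/, and the field names are taken from the message above (the C++ client code earlier in this commit sets the same field via request.set_lazy_copy_remote_function_inputs(...)).

```python
# Assumed generated-module path; field names come from CreateContextRequest.
from tensorflow.core.protobuf import eager_service_pb2

request = eager_service_pb2.CreateContextRequest()
request.keep_alive_secs = 600                    # existing field
request.lazy_copy_remote_function_inputs = True  # new field, tag 9
```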
@@ -47,8 +47,8 @@ tensorflow::Status DelegateData::Prepare(
       session_options,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
-      /*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
-      rendezvous, nullptr);
+      /*async=*/false, /*lazy_copy_function_remote_inputs=*/false,
+      device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr);
   return tensorflow::Status();
 }
 
@@ -775,9 +775,11 @@ cuda_py_test(
         ":def_function",
         ":test",
         ":remote",
+        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:resource_variable_ops",
    ],
    grpc_enabled = True,
+    shard_count = 2,
    tags = [
        "no_oss",  # This test launches local server.
        "optonly",  # times out
@@ -408,6 +408,7 @@ class Context(object):
     if execution_mode is None:
       execution_mode = SYNC
     self._default_is_async = execution_mode == ASYNC
+    self._lazy_remote_inputs_copy = False
     self._server_def = server_def
     self._collective_ops_server_def = None
     self._collective_leader = None
@@ -502,6 +503,9 @@
               opts, self._mirroring_policy)
         if self._default_is_async == ASYNC:
           pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
+        if self._lazy_remote_inputs_copy:
+          pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(
+              opts, True)
         context_handle = pywrap_tensorflow.TFE_NewContext(opts)
       finally:
         pywrap_tensorflow.TFE_DeleteContextOptions(opts)
@@ -1445,6 +1449,22 @@
       pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
           self._handle, self._mirroring_policy)
 
+  @property
+  def lazy_remote_inputs_copy(self):
+    return self._lazy_remote_inputs_copy
+
+  @lazy_remote_inputs_copy.setter
+  def lazy_remote_inputs_copy(self, lazy_copy):
+    """Sets whether to copy remote inputs lazily for functions."""
+    if not isinstance(lazy_copy, bool):
+      raise ValueError("Expecting a boolean but got %s" % type(lazy_copy))
+
+    if self._lazy_remote_inputs_copy != lazy_copy:
+      if self._initialized:
+        raise ValueError(
+            "lazy_remote_inputs_copy should be set before being initialized.")
+      self._lazy_remote_inputs_copy = lazy_copy
+
   def enable_run_metadata(self):
     """Enables tracing of op execution via RunMetadata.
 
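The property added above is a plain Python-side staging flag: it accepts only booleans and can only be changed while the context is still uninitialized, since the value is baked into the TFE_ContextOptions when the context handle is created. A short sketch of that contract, grounded in the setter's checks:

```python
from tensorflow.python.eager import context

ctx = context.context()
ctx.lazy_remote_inputs_copy = True      # fine before the context handle exists
print(ctx.lazy_remote_inputs_copy)      # True

try:
  ctx.lazy_remote_inputs_copy = 1       # not a bool
except ValueError as e:
  print(e)                              # "Expecting a boolean but got ..."

# Once the context has been initialized (e.g. after the first op runs),
# flipping the flag also raises ValueError.
```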
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import random
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@@ -39,7 +40,7 @@ from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
 
 
-class SingleWorkerTest(test.TestCase):
+class SingleWorkerTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(SingleWorkerTest, self).setUp()
@@ -55,6 +56,7 @@ class SingleWorkerTest(test.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionBasic(self):
 
     @def_function.function
@@ -69,6 +71,7 @@ class SingleWorkerTest(test.TestCase):
     self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
     self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionVariable(self):
     with ops.device('/job:worker/replica:0/task:0/cpu:0'):
       variable_b = variables.Variable(1)
@@ -79,6 +82,7 @@ class SingleWorkerTest(test.TestCase):
 
     self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionRemoteOutput(self):
     with ops.device('/job:worker/replica:0/task:0/cpu:0'):
       variable_b = variables.Variable(1)
@@ -134,6 +138,7 @@ class SingleWorkerTest(test.TestCase):
 
     self.assertIn('Dimensions must be equal', cm.exception.message)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testShapeError_Function(self):
 
     @def_function.function
@@ -150,7 +155,7 @@ class SingleWorkerTest(test.TestCase):
     self.assertIn('Dimensions must be equal', cm.exception.message)
 
 
-class MultiWorkersTest(test.TestCase):
+class MultiWorkersTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(MultiWorkersTest, self).setUp()
@@ -167,6 +172,7 @@ class MultiWorkersTest(test.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnLocalDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
       variable_b = variables.Variable(1.0)
@@ -180,6 +186,7 @@ class MultiWorkersTest(test.TestCase):
 
     self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
       variable_b = variables.Variable(1.0)
@@ -209,6 +216,7 @@ class MultiWorkersTest(test.TestCase):
     with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceWhileLoopOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):
       variable_b = variables.Variable(1.0)
@@ -241,6 +249,7 @@ class MultiWorkersTest(test.TestCase):
     with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testSimpleParameterServer(self):
 
     with ops.device('/job:worker/task:2/device:CPU:0'):
@@ -263,7 +272,7 @@ class MultiWorkersTest(test.TestCase):
 _GRPC_PREFIX = 'grpc://'
 
 
-class MultiJobsTest(test.TestCase):
+class MultiJobsTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(MultiJobsTest, self).setUp()
@@ -288,6 +297,7 @@ class MultiJobsTest(test.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testSimpleParameterServer(self):
     remote.connect_to_cluster(self._cluster)
 
@@ -307,6 +317,7 @@ class MultiJobsTest(test.TestCase):
     with ops.device('/job:my_worker/task:1/device:CPU:0'):
       self.assertAllEqual(worker_fn(), 8)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectWithClusterResolver(self):
     remote.connect_to_cluster(self._cluster_resolver)
 
@@ -325,10 +336,12 @@ class MultiJobsTest(test.TestCase):
     with ops.device('/job:my_worker/task:1/device:CPU:0'):
       self.assertAllEqual(worker_fn(), 8)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectToClusterTwiceOk(self):
     remote.connect_to_cluster(self._cluster_resolver)
     remote.connect_to_cluster(self._cluster_resolver)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectToClusterOnMismatchedDevice(self):
     remote.connect_to_cluster(self._cluster_resolver)
 
@@ -338,6 +351,7 @@ class MultiJobsTest(test.TestCase):
     with self.assertRaises(ValueError):
       remote.connect_to_cluster(self._cluster_resolver)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectToClusterWithLocalMaster(self):
     local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
     remote.connect_to_cluster(local_resolver)
@@ -1019,6 +1019,21 @@ def build_as_function_and_v1_graph(func=None):
   return decorator
 
 
+def eager_lazy_remote_copy_on_and_off(f):
+  """Run the test with and without lazy remote copy of function inputs."""
+
+  @parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)])
+  @functools.wraps(f)
+  def decorator(self, lazily_remote_copy, *args, **kwargs):
+    if lazily_remote_copy:
+      context.context().lazy_remote_inputs_copy = True
+    else:
+      context.context().lazy_remote_inputs_copy = False
+    f(self, *args, **kwargs)
+
+  return decorator
+
+
 def run_in_graph_and_eager_modes(func=None,
                                  config=None,
                                  use_gpu=True,
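Usage of the new decorator matches the remote_test.py changes above: the test class mixes in parameterized.TestCase, and each decorated test method runs twice, once with lazy_remote_inputs_copy set to True and once with it left False. The class name and test body below are placeholders.

```python
from absl.testing import parameterized

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class SomeRemoteTest(test.TestCase, parameterized.TestCase):

  @test_util.eager_lazy_remote_copy_on_and_off
  def testSomething(self):
    # Runs as "testSomethingWithLazyRemoteCopy" (flag True) and as
    # "testSomething" (flag False).
    pass
```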
@@ -110,6 +110,7 @@ limitations under the License.
 %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
 %rename("%s") TFE_ContextOptionsSetMirroringPolicy;
 %rename("%s") TFE_ContextOptionsSetAsync;
+%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
 %rename("%s") TFE_DeleteContextOptions;
 %rename("%s") TFE_Py_TensorShapeSlice;
 %rename("%s") TFE_Py_TensorShapeOnDevice;