Add lazy_remote_inputs_copy to TFE_ContextOptions to control lazy remote tensor copy. Disable it by default.

PiperOrigin-RevId: 279212487
Change-Id: Ie46de71fd2902b79281e6257ff28c06d9aaa73d4
This commit is contained in:
Yujing Zhang 2019-11-07 18:27:11 -08:00 committed by TensorFlower Gardener
parent 40164d7a26
commit 206d6af149
18 changed files with 121 additions and 42 deletions

View File

@ -264,6 +264,7 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::uint64 context_view_id, int keep_alive_secs,
const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
@ -300,6 +301,8 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
lazy_copy_remote_function_inputs);
eager_client->CreateContextAsync(
&request, response,
@ -319,7 +322,7 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
@ -527,7 +530,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(), base_request));
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
} else {
// The master's context_view_id will be incremented by one
// the UpdateRemoteMaster call later. We want all new workers and
@ -537,7 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(), base_request));
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
@ -546,8 +551,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, context_view_id + 1, server_def,
remote_eager_workers.get(), ctx->context->Executor().Async(),
base_request));
remote_eager_workers.get(), base_request));
}
}
@ -713,7 +717,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr.release(),
opts->async, opts->lazy_remote_inputs_copy,
device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
}
@ -728,7 +733,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
}

View File

@ -557,6 +557,11 @@ extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
ctx->context->GetMirroringPolicy());
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
bool lazy_copy) {
options->lazy_remote_inputs_copy = lazy_copy;
}
TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
}

View File

@ -336,6 +336,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context*);
// Sets whether to copy the remote inputs of a function lazily.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);
// -----------------------------------------------------------------------------
// Cancellation APIs.

View File

@ -57,12 +57,15 @@ struct TFE_ContextOptions {
TFE_ContextDevicePlacementPolicy device_placement_policy{
TFE_DEVICE_PLACEMENT_SILENT};
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = false;
};
struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_remote_inputs_copy,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
@ -72,8 +75,8 @@ struct TFE_Context {
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
rendezvous, custom_kernel_creator)) {}
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we

View File

@ -68,13 +68,12 @@ auto* eager_context_created =
} // namespace
// TODO(b/134094971): Make lazily_copy_function_remote_inputs_ configurable once
// it's ready to enable.
EagerContext::EagerContext(
const SessionOptions& opts,
ContextDevicePlacementPolicy default_device_placement_policy,
ContextMirroringPolicy default_mirroring_policy, bool async,
const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
bool device_mgr_owned, Rendezvous* rendezvous,
const CustomKernelCreator* custom_kernel_creator,
DistributedFunctionLibraryRuntime* cluster_flr)
: default_device_placement_policy_(default_device_placement_policy),
@ -91,7 +90,7 @@ EagerContext::EagerContext(
default_executor_(async),
log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
lazily_copy_function_remote_inputs_(false),
lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
use_send_tensor_rpc_(false),
pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
"TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
@ -130,7 +129,7 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
thread::ThreadPool* thread_pool,
DistributedFunctionLibraryRuntime* cluster_flr,
const CustomKernelCreator* custom_kernel_creator) {
if (lazily_copy_function_remote_inputs_) {
if (lazy_copy_function_remote_inputs_) {
pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator));
@ -164,7 +163,7 @@ void EagerContext::InitDeviceMapAndAsync() {
void EagerContext::ResetClusterFLR(
DistributedFunctionLibraryRuntime* cluster_flr) {
cluster_flr_.Reset(cluster_flr, lazily_copy_function_remote_inputs_);
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
}
EagerExecutor& EagerContext::Executor() {
@ -239,8 +238,8 @@ bool EagerContext::MirrorTensors() const {
return GetMirroringPolicy() == MIRRORING_ALL;
}
bool EagerContext::LazilyCopyFunctionRemoteInputs() const {
return lazily_copy_function_remote_inputs_;
bool EagerContext::LazyCopyFunctionRemoteInputs() const {
return lazy_copy_function_remote_inputs_;
}
#if !defined(IS_MOBILE_PLATFORM)

View File

@ -121,6 +121,7 @@ class EagerContext : public core::RefCounted {
EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_device_placement_policy,
ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_copy_function_remote_inputs,
const DeviceMgr* device_mgr, bool device_mgr_owned,
Rendezvous* rendezvous,
const CustomKernelCreator* custom_kernel_creator,
@ -168,7 +169,7 @@ class EagerContext : public core::RefCounted {
bool MirrorTensors() const;
bool LazilyCopyFunctionRemoteInputs() const;
bool LazyCopyFunctionRemoteInputs() const;
bool FindFunctionByName(const string& name);
@ -461,7 +462,7 @@ class EagerContext : public core::RefCounted {
// EagerContext owns the DistributedFunctionLibraryRuntime(
// EagerClusterFunctionLibraryRuntime) if using EagerService for remote
// function execution (lazily_copy_function_remote_inputs_=true).
// function execution (lazy_copy_function_remote_inputs_=true).
OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
@ -553,7 +554,12 @@ class EagerContext : public core::RefCounted {
bool is_master_ GUARDED_BY(remote_state_mu_);
#endif // IS_MOBILE_PLATFORM
bool lazily_copy_function_remote_inputs_;
// For a multi device function, the target device of each input is unknown
// until the function is instantiated on the default function device.
// If false, eagerly copy all remote inputs to the default function device;
// if true, lazily copy remote inputs to their target devices to avoid
// redundant copies.
bool lazy_copy_function_remote_inputs_ = false;
bool use_send_tensor_rpc_;
const bool pin_small_ops_to_cpu_;

View File

@ -212,7 +212,7 @@ Status ValidateInputTypeAndPlacement(
" inputs, got ", n_inputs);
}
const bool skip_remote_copy =
ctx->LazilyCopyFunctionRemoteInputs() && kernel->IsFunction();
ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction();
for (int i = 0; i < n_inputs; ++i) {
TensorHandle* handle = op->Inputs()[i];
Device* expected_device = kernel->InputDevice(i);
@ -499,14 +499,12 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
profiler::TraceMeLevel::kInfo);
input_dev_ptrs.reserve(op->Inputs().size());
// When LazilyCopyFunctionRemoteInputs is disabled, all inputs need to be on
// When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on
// local devices, since we execute a remote function through worker service,
// which doesn't accept remote inputs.
// TODO(b/134094971): Make resource_dtypes_and_shapes avaliable without
// remote tensor copy.
for (int i = 0; i < op->Inputs().size(); i++) {
TensorHandle* input = op->Inputs()[i];
if (!ctx->LazilyCopyFunctionRemoteInputs() && input->IsRemote()) {
if (!ctx->LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
TensorHandle* handle = nullptr;
TF_RETURN_IF_ERROR(EagerCopyToDevice(
input, ctx, &executor, device == nullptr ? ctx->HostCPU() : device,
@ -603,7 +601,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
<< ". Full node_def=" << ndef.DebugString();
std::function<int64()> get_op_id = nullptr;
#if !defined(IS_MOBILE_PLATFORM)
if (ctx->LazilyCopyFunctionRemoteInputs()) {
if (ctx->LazyCopyFunctionRemoteInputs()) {
get_op_id = [ctx]() { return ctx->RemoteMgr()->NextOpId(); };
}
#endif // IS_MOBILE_PLATFORM
@ -750,7 +748,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
profiler::TraceMe activity("CopyInputToExpectedDevice",
profiler::TraceMeLevel::kInfo);
const bool eagerly_copy_function_remote_inputs =
!ctx->LazilyCopyFunctionRemoteInputs() || !op->is_function();
!ctx->LazyCopyFunctionRemoteInputs() || !op->is_function();
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::TensorHandle* input = op->Inputs()[i];
tensorflow::Device* input_device = input->device();
@ -834,12 +832,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
}
if (ctx->LazilyCopyFunctionRemoteInputs()) {
if (ctx->LazyCopyFunctionRemoteInputs()) {
// Store the data type and shape of a remote resource variable on the
// corresponding remote TensorHandle (output of 'VarHandleOp').
// If the variable is an input of a remote function, the function may need
// the type and shape during function instantiation. When
// LazilyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
// LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource
// handle (contains the type and shape) of the variable to the default
// function device. Instead, we store the type and shape on eager master
// and sent them to the default function device along with the

View File

@ -216,7 +216,7 @@ void EagerClusterFunctionLibraryRuntime::CleanUp(
DistributedFunctionLibraryRuntime* CreateClusterFLR(
const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
if (ctx->LazilyCopyFunctionRemoteInputs()) {
if (ctx->LazyCopyFunctionRemoteInputs()) {
return new EagerClusterFunctionLibraryRuntime(
context_id, ctx, worker_session->remote_device_mgr());
} else {

View File

@ -162,8 +162,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
device_mgr, false, r, GetDefaultCustomKernelCreator(),
worker_session->cluster_flr());
request->lazy_copy_remote_function_inputs(), device_mgr, false, r,
GetDefaultCustomKernelCreator(), worker_session->cluster_flr());
// Ownership will be transferred to the ServerContext, or else in an error
// case ctx will be deleted by this unref.
core::ScopedUnref unref_ctx(ctx);

View File

@ -710,8 +710,9 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
device_mgr_.get(), false, rendezvous, GetDefaultCustomKernelCreator());
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
/*lazy_copy_function_remote_inputs=*/false, device_mgr_.get(), false,
rendezvous, GetDefaultCustomKernelCreator());
const uint64 context_id = random::New64();
// Set RemoteMgr to ctx.

View File

@ -55,9 +55,9 @@ class RemoteMgrTest : public ::testing::Test {
ctx_ = new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false,
device_mgr.release(), true, rendezvous, GetDefaultCustomKernelCreator(),
nullptr);
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
/*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), true,
rendezvous, GetDefaultCustomKernelCreator(), nullptr);
}
~RemoteMgrTest() override { ctx_->Unref(); }

View File

@ -90,6 +90,11 @@ message CreateContextRequest {
// The view ID of the context.
fixed64 context_view_id = 8;
// For a multi device function, if false, eagerly copy all remote inputs to
// the default function device; if true, lazily copy remote inputs to their
// target devices after function instantiation to avoid redundant copies.
bool lazy_copy_remote_function_inputs = 9;
reserved 5;
}

View File

@ -47,8 +47,8 @@ tensorflow::Status DelegateData::Prepare(
session_options,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
rendezvous, nullptr);
/*async=*/false, /*lazy_copy_function_remote_inputs=*/false,
device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr);
return tensorflow::Status();
}

View File

@ -775,9 +775,11 @@ cuda_py_test(
":def_function",
":test",
":remote",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:resource_variable_ops",
],
grpc_enabled = True,
shard_count = 2,
tags = [
"no_oss", # This test launches local server.
"optonly", # times out

View File

@ -408,6 +408,7 @@ class Context(object):
if execution_mode is None:
execution_mode = SYNC
self._default_is_async = execution_mode == ASYNC
self._lazy_remote_inputs_copy = False
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None
@ -502,6 +503,9 @@ class Context(object):
opts, self._mirroring_policy)
if self._default_is_async == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
if self._lazy_remote_inputs_copy:
pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(
opts, True)
context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
@ -1445,6 +1449,22 @@ class Context(object):
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
self._handle, self._mirroring_policy)
@property
def lazy_remote_inputs_copy(self):
return self._lazy_remote_inputs_copy
@lazy_remote_inputs_copy.setter
def lazy_remote_inputs_copy(self, lazy_copy):
"""Sets whether to copy remote inputs lazily for functions."""
if not isinstance(lazy_copy, bool):
raise ValueError("Expecting a boolean but got %s" % type(lazy_copy))
if self._lazy_remote_inputs_copy != lazy_copy:
if self._initialized:
raise ValueError(
"lazy_remote_inputs_copy should be set before being initialized.")
self._lazy_remote_inputs_copy = lazy_copy
def enable_run_metadata(self):
"""Enables tracing of op execution via RunMetadata.

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import random
from absl.testing import parameterized
import numpy as np
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@ -39,7 +40,7 @@ from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
class SingleWorkerTest(test.TestCase):
class SingleWorkerTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(SingleWorkerTest, self).setUp()
@ -55,6 +56,7 @@ class SingleWorkerTest(test.TestCase):
# Reset the context to avoid polluting other test cases.
context._reset_context()
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionBasic(self):
@def_function.function
@ -69,6 +71,7 @@ class SingleWorkerTest(test.TestCase):
self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionVariable(self):
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
variable_b = variables.Variable(1)
@ -79,6 +82,7 @@ class SingleWorkerTest(test.TestCase):
self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionRemoteOutput(self):
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
variable_b = variables.Variable(1)
@ -134,6 +138,7 @@ class SingleWorkerTest(test.TestCase):
self.assertIn('Dimensions must be equal', cm.exception.message)
@test_util.eager_lazy_remote_copy_on_and_off
def testShapeError_Function(self):
@def_function.function
@ -150,7 +155,7 @@ class SingleWorkerTest(test.TestCase):
self.assertIn('Dimensions must be equal', cm.exception.message)
class MultiWorkersTest(test.TestCase):
class MultiWorkersTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(MultiWorkersTest, self).setUp()
@ -167,6 +172,7 @@ class MultiWorkersTest(test.TestCase):
# Reset the context to avoid polluting other test cases.
context._reset_context()
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionOnLocalDevice(self):
with ops.device('/job:worker/replica:0/task:1'):
variable_b = variables.Variable(1.0)
@ -180,6 +186,7 @@ class MultiWorkersTest(test.TestCase):
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionOnRemoteDevice(self):
with ops.device('/job:worker/replica:0/task:1'):
variable_b = variables.Variable(1.0)
@ -209,6 +216,7 @@ class MultiWorkersTest(test.TestCase):
with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceWhileLoopOnRemoteDevice(self):
with ops.device('/job:worker/replica:0/task:1'):
variable_b = variables.Variable(1.0)
@ -241,6 +249,7 @@ class MultiWorkersTest(test.TestCase):
with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
@test_util.eager_lazy_remote_copy_on_and_off
def testSimpleParameterServer(self):
with ops.device('/job:worker/task:2/device:CPU:0'):
@ -263,7 +272,7 @@ class MultiWorkersTest(test.TestCase):
_GRPC_PREFIX = 'grpc://'
class MultiJobsTest(test.TestCase):
class MultiJobsTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(MultiJobsTest, self).setUp()
@ -288,6 +297,7 @@ class MultiJobsTest(test.TestCase):
# Reset the context to avoid polluting other test cases.
context._reset_context()
@test_util.eager_lazy_remote_copy_on_and_off
def testSimpleParameterServer(self):
remote.connect_to_cluster(self._cluster)
@ -307,6 +317,7 @@ class MultiJobsTest(test.TestCase):
with ops.device('/job:my_worker/task:1/device:CPU:0'):
self.assertAllEqual(worker_fn(), 8)
@test_util.eager_lazy_remote_copy_on_and_off
def testConnectWithClusterResolver(self):
remote.connect_to_cluster(self._cluster_resolver)
@ -325,10 +336,12 @@ class MultiJobsTest(test.TestCase):
with ops.device('/job:my_worker/task:1/device:CPU:0'):
self.assertAllEqual(worker_fn(), 8)
@test_util.eager_lazy_remote_copy_on_and_off
def testConnectToClusterTwiceOk(self):
remote.connect_to_cluster(self._cluster_resolver)
remote.connect_to_cluster(self._cluster_resolver)
@test_util.eager_lazy_remote_copy_on_and_off
def testConnectToClusterOnMismatchedDevice(self):
remote.connect_to_cluster(self._cluster_resolver)
@ -338,6 +351,7 @@ class MultiJobsTest(test.TestCase):
with self.assertRaises(ValueError):
remote.connect_to_cluster(self._cluster_resolver)
@test_util.eager_lazy_remote_copy_on_and_off
def testConnectToClusterWithLocalMaster(self):
local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
remote.connect_to_cluster(local_resolver)

View File

@ -1019,6 +1019,21 @@ def build_as_function_and_v1_graph(func=None):
return decorator
def eager_lazy_remote_copy_on_and_off(f):
"""Execute the test method w/o lazy tensor copy for function remote inputs."""
@parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)])
@functools.wraps(f)
def decorator(self, lazily_remote_copy, *args, **kwargs):
if lazily_remote_copy:
context.context().lazy_remote_inputs_copy = True
else:
context.context().lazy_remote_inputs_copy = False
f(self, *args, **kwargs)
return decorator
def run_in_graph_and_eager_modes(func=None,
config=None,
use_gpu=True,

View File

@ -110,6 +110,7 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;