diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index e6dfb20166b..bffceab1bc7 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -264,6 +264,7 @@ tensorflow::Status CreateRemoteContexts( tensorflow::uint64 context_view_id, int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + const bool lazy_copy_remote_function_inputs, const tensorflow::eager::CreateContextRequest& base_request) { int num_remote_workers = remote_workers.size(); tensorflow::BlockingCounter counter(num_remote_workers); @@ -300,6 +301,8 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); request.set_keep_alive_secs(keep_alive_secs); + request.set_lazy_copy_remote_function_inputs( + lazy_copy_remote_function_inputs); eager_client->CreateContextAsync( &request, response, @@ -319,7 +322,7 @@ tensorflow::Status CreateRemoteContexts( tensorflow::Status UpdateRemoteContexts( const std::vector& remote_workers, tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::eager::EagerClientCache* remote_eager_workers, const tensorflow::eager::CreateContextRequest& base_request) { int num_remote_workers = remote_workers.size(); tensorflow::BlockingCounter counter(num_remote_workers); @@ -527,7 +530,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( remote_workers, context_id, context_view_id, keep_alive_secs, server_def, remote_eager_workers.get(), - ctx->context->Executor().Async(), base_request)); + ctx->context->Executor().Async(), + ctx->context->LazyCopyFunctionRemoteInputs(), base_request)); } else { // The master's context_view_id will be incremented by one // the UpdateRemoteMaster call later. We want all new workers and @@ -537,7 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( added_workers, context_id, context_view_id + 1, keep_alive_secs, server_def, remote_eager_workers.get(), - ctx->context->Executor().Async(), base_request)); + ctx->context->Executor().Async(), + ctx->context->LazyCopyFunctionRemoteInputs(), base_request)); if (!existing_workers.empty()) { if (VLOG_IS_ON(1)) { for (const string& w : existing_workers) { @@ -546,8 +551,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( } LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts( existing_workers, context_id, context_view_id + 1, server_def, - remote_eager_workers.get(), ctx->context->Executor().Async(), - base_request)); + remote_eager_workers.get(), base_request)); } } @@ -713,7 +717,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { return new TFE_Context(opts->session_options.options, opts->device_placement_policy, opts->mirroring_policy, - opts->async, device_mgr.release(), + opts->async, opts->lazy_remote_inputs_copy, + device_mgr.release(), /*device_mgr_owned*/ true, r, tensorflow::GetDefaultCustomKernelCreator()); } @@ -728,7 +733,8 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, return new TFE_Context(opts->session_options.options, opts->device_placement_policy, opts->mirroring_policy, - opts->async, device_mgr, /*device_mgr_owned*/ false, r, + opts->async, opts->lazy_remote_inputs_copy, device_mgr, + /*device_mgr_owned*/ false, r, tensorflow::GetDefaultCustomKernelCreator()); } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index a40a435065f..b513fcedc59 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -557,6 +557,11 @@ extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( ctx->context->GetMirroringPolicy()); } +void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, + bool lazy_copy) { + options->lazy_remote_inputs_copy = lazy_copy; +} + TFE_CancellationManager* TFE_NewCancellationManager() { return new TFE_CancellationManager; } diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 4da08641907..055f9f9d602 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -336,6 +336,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy( TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context*); +// Sets whether to copy the remote inputs of a function lazily. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( + TFE_ContextOptions*, bool lazy_copy); + // ----------------------------------------------------------------------------- // Cancellation APIs. diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 1841d4846ab..56ee1e01cc9 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -57,12 +57,15 @@ struct TFE_ContextOptions { TFE_ContextDevicePlacementPolicy device_placement_policy{ TFE_DEVICE_PLACEMENT_SILENT}; TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE}; + // If true, lazily copy the remote inputs of a function to the target devices. + bool lazy_remote_inputs_copy = false; }; struct TFE_Context { TFE_Context(const tensorflow::SessionOptions& opts, TFE_ContextDevicePlacementPolicy default_device_placement_policy, TFE_ContextMirroringPolicy default_mirroring_policy, bool async, + const bool lazy_remote_inputs_copy, const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, tensorflow::Rendezvous* rendezvous, const tensorflow::CustomKernelCreator* custom_kernel_creator) @@ -72,8 +75,8 @@ struct TFE_Context { default_device_placement_policy), static_cast( default_mirroring_policy), - async, device_mgr, device_mgr_owned, rendezvous, - custom_kernel_creator)) {} + async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned, + rendezvous, custom_kernel_creator)) {} ~TFE_Context() { // TODO(iga): Add a separate API method to shutdown TFE_Context so that we diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 3a0b24f0b22..083bcf8d85b 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -68,13 +68,12 @@ auto* eager_context_created = } // namespace -// TODO(b/134094971): Make lazily_copy_function_remote_inputs_ configurable once -// it's ready to enable. EagerContext::EagerContext( const SessionOptions& opts, ContextDevicePlacementPolicy default_device_placement_policy, ContextMirroringPolicy default_mirroring_policy, bool async, - const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous, + const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr, + bool device_mgr_owned, Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, DistributedFunctionLibraryRuntime* cluster_flr) : default_device_placement_policy_(default_device_placement_policy), @@ -91,7 +90,7 @@ EagerContext::EagerContext( default_executor_(async), log_memory_(LogMemory::IsEnabled()), env_(opts.env), - lazily_copy_function_remote_inputs_(false), + lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs), use_send_tensor_rpc_(false), pin_small_ops_to_cpu_(ReadBoolFromEnvVar( "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) { @@ -130,7 +129,7 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env, thread::ThreadPool* thread_pool, DistributedFunctionLibraryRuntime* cluster_flr, const CustomKernelCreator* custom_kernel_creator) { - if (lazily_copy_function_remote_inputs_) { + if (lazy_copy_function_remote_inputs_) { pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime( device_mgr, env, config, graph_def_version, lib_def, optimizer_options, thread_pool, cluster_flr, custom_kernel_creator)); @@ -164,7 +163,7 @@ void EagerContext::InitDeviceMapAndAsync() { void EagerContext::ResetClusterFLR( DistributedFunctionLibraryRuntime* cluster_flr) { - cluster_flr_.Reset(cluster_flr, lazily_copy_function_remote_inputs_); + cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_); } EagerExecutor& EagerContext::Executor() { @@ -239,8 +238,8 @@ bool EagerContext::MirrorTensors() const { return GetMirroringPolicy() == MIRRORING_ALL; } -bool EagerContext::LazilyCopyFunctionRemoteInputs() const { - return lazily_copy_function_remote_inputs_; +bool EagerContext::LazyCopyFunctionRemoteInputs() const { + return lazy_copy_function_remote_inputs_; } #if !defined(IS_MOBILE_PLATFORM) diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 1749188ac17..116c6685c27 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -121,6 +121,7 @@ class EagerContext : public core::RefCounted { EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_device_placement_policy, ContextMirroringPolicy default_mirroring_policy, bool async, + const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, @@ -168,7 +169,7 @@ class EagerContext : public core::RefCounted { bool MirrorTensors() const; - bool LazilyCopyFunctionRemoteInputs() const; + bool LazyCopyFunctionRemoteInputs() const; bool FindFunctionByName(const string& name); @@ -461,7 +462,7 @@ class EagerContext : public core::RefCounted { // EagerContext owns the DistributedFunctionLibraryRuntime( // EagerClusterFunctionLibraryRuntime) if using EagerService for remote - // function execution (lazily_copy_function_remote_inputs_=true). + // function execution (lazy_copy_function_remote_inputs_=true). OwnedOrUnownedHelper cluster_flr_; // One FunctionLibraryRuntime per device. // func_libs[i] is the FunctionLibraryRuntime corresponding to @@ -553,7 +554,12 @@ class EagerContext : public core::RefCounted { bool is_master_ GUARDED_BY(remote_state_mu_); #endif // IS_MOBILE_PLATFORM - bool lazily_copy_function_remote_inputs_; + // For a multi device function, the target device of each input is unknown + // until the function is instantiated on the default function device. + // If false, eagerly copy all remote inputs to the default function device; + // if true, lazily copy remote inputs to their target devices to avoid + // redundant copies. + bool lazy_copy_function_remote_inputs_ = false; bool use_send_tensor_rpc_; const bool pin_small_ops_to_cpu_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index f11a37f204d..e783aaefdc6 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -212,7 +212,7 @@ Status ValidateInputTypeAndPlacement( " inputs, got ", n_inputs); } const bool skip_remote_copy = - ctx->LazilyCopyFunctionRemoteInputs() && kernel->IsFunction(); + ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction(); for (int i = 0; i < n_inputs; ++i) { TensorHandle* handle = op->Inputs()[i]; Device* expected_device = kernel->InputDevice(i); @@ -499,14 +499,12 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey", profiler::TraceMeLevel::kInfo); input_dev_ptrs.reserve(op->Inputs().size()); - // When LazilyCopyFunctionRemoteInputs is disabled, all inputs need to be on + // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on // local devices, since we execute a remote function through worker service, // which doesn't accept remote inputs. - // TODO(b/134094971): Make resource_dtypes_and_shapes avaliable without - // remote tensor copy. for (int i = 0; i < op->Inputs().size(); i++) { TensorHandle* input = op->Inputs()[i]; - if (!ctx->LazilyCopyFunctionRemoteInputs() && input->IsRemote()) { + if (!ctx->LazyCopyFunctionRemoteInputs() && input->IsRemote()) { TensorHandle* handle = nullptr; TF_RETURN_IF_ERROR(EagerCopyToDevice( input, ctx, &executor, device == nullptr ? ctx->HostCPU() : device, @@ -603,7 +601,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, << ". Full node_def=" << ndef.DebugString(); std::function get_op_id = nullptr; #if !defined(IS_MOBILE_PLATFORM) - if (ctx->LazilyCopyFunctionRemoteInputs()) { + if (ctx->LazyCopyFunctionRemoteInputs()) { get_op_id = [ctx]() { return ctx->RemoteMgr()->NextOpId(); }; } #endif // IS_MOBILE_PLATFORM @@ -750,7 +748,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, profiler::TraceMe activity("CopyInputToExpectedDevice", profiler::TraceMeLevel::kInfo); const bool eagerly_copy_function_remote_inputs = - !ctx->LazilyCopyFunctionRemoteInputs() || !op->is_function(); + !ctx->LazyCopyFunctionRemoteInputs() || !op->is_function(); for (int i = 0; i < op->Inputs().size(); i++) { tensorflow::TensorHandle* input = op->Inputs()[i]; tensorflow::Device* input_device = input->device(); @@ -834,12 +832,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, } } - if (ctx->LazilyCopyFunctionRemoteInputs()) { + if (ctx->LazyCopyFunctionRemoteInputs()) { // Store the data type and shape of a remote resource variable on the // corresponding remote TensorHandle (output of 'VarHandleOp'). // If the variable is an input of a remote function, the function may need // the type and shape during function instantiation. When - // LazilyCopyFunctionRemoteInputs is enabled, we no longer copy the resource + // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource // handle (contains the type and shape) of the variable to the default // function device. Instead, we store the type and shape on eager master // and sent them to the default function device along with the diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index c221d76aafa..a1cfe5813f1 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -216,7 +216,7 @@ void EagerClusterFunctionLibraryRuntime::CleanUp( DistributedFunctionLibraryRuntime* CreateClusterFLR( const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) { - if (ctx->LazilyCopyFunctionRemoteInputs()) { + if (ctx->LazyCopyFunctionRemoteInputs()) { return new EagerClusterFunctionLibraryRuntime( context_id, ctx, worker_session->remote_device_mgr()); } else { diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 083aeefbf7b..92e3d2fb3cf 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -162,8 +162,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, tensorflow::EagerContext* ctx = new tensorflow::EagerContext( opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), - device_mgr, false, r, GetDefaultCustomKernelCreator(), - worker_session->cluster_flr()); + request->lazy_copy_remote_function_inputs(), device_mgr, false, r, + GetDefaultCustomKernelCreator(), worker_session->cluster_flr()); // Ownership will be transferred to the ServerContext, or else in an error // case ctx will be deleted by this unref. core::ScopedUnref unref_ctx(ctx); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 8b8fe42502a..dbf3c6370bc 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -710,8 +710,9 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) { tensorflow::EagerContext* ctx = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, - device_mgr_.get(), false, rendezvous, GetDefaultCustomKernelCreator()); + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false, + /*lazy_copy_function_remote_inputs=*/false, device_mgr_.get(), false, + rendezvous, GetDefaultCustomKernelCreator()); const uint64 context_id = random::New64(); // Set RemoteMgr to ctx. diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc index 7b68100b543..6bb4943ffee 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc @@ -55,9 +55,9 @@ class RemoteMgrTest : public ::testing::Test { ctx_ = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, - device_mgr.release(), true, rendezvous, GetDefaultCustomKernelCreator(), - nullptr); + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false, + /*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), true, + rendezvous, GetDefaultCustomKernelCreator(), nullptr); } ~RemoteMgrTest() override { ctx_->Unref(); } diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index d90ba548e0e..4335d87309a 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -90,6 +90,11 @@ message CreateContextRequest { // The view ID of the context. fixed64 context_view_id = 8; + // For a multi device function, if false, eagerly copy all remote inputs to + // the default function device; if true, lazily copy remote inputs to their + // target devices after function instantiation to avoid redundant copies. + bool lazy_copy_remote_function_inputs = 9; + reserved 5; } diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc index bed38bdffdd..2be928073ff 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -47,8 +47,8 @@ tensorflow::Status DelegateData::Prepare( session_options, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, - /*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true, - rendezvous, nullptr); + /*async=*/false, /*lazy_copy_function_remote_inputs=*/false, + device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr); return tensorflow::Status(); } diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 4a76bd79513..5bc654c2184 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -775,9 +775,11 @@ cuda_py_test( ":def_function", ":test", ":remote", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:resource_variable_ops", ], grpc_enabled = True, + shard_count = 2, tags = [ "no_oss", # This test launches local server. "optonly", # times out diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 1f13f163e0e..8de73bc35d1 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -408,6 +408,7 @@ class Context(object): if execution_mode is None: execution_mode = SYNC self._default_is_async = execution_mode == ASYNC + self._lazy_remote_inputs_copy = False self._server_def = server_def self._collective_ops_server_def = None self._collective_leader = None @@ -502,6 +503,9 @@ class Context(object): opts, self._mirroring_policy) if self._default_is_async == ASYNC: pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + if self._lazy_remote_inputs_copy: + pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy( + opts, True) context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) @@ -1445,6 +1449,22 @@ class Context(object): pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy( self._handle, self._mirroring_policy) + @property + def lazy_remote_inputs_copy(self): + return self._lazy_remote_inputs_copy + + @lazy_remote_inputs_copy.setter + def lazy_remote_inputs_copy(self, lazy_copy): + """Sets whether to copy remote inputs lazily for functions.""" + if not isinstance(lazy_copy, bool): + raise ValueError("Expecting a boolean but got %s" % type(lazy_copy)) + + if self._lazy_remote_inputs_copy != lazy_copy: + if self._initialized: + raise ValueError( + "lazy_remote_inputs_copy should be set before being initialized.") + self._lazy_remote_inputs_copy = lazy_copy + def enable_run_metadata(self): """Enables tracing of op execution via RunMetadata. diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index 7008b0a124f..d0030774fde 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import random +from absl.testing import parameterized import numpy as np from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver @@ -39,7 +40,7 @@ from tensorflow.python.training import server_lib from tensorflow.python.training.server_lib import ClusterSpec -class SingleWorkerTest(test.TestCase): +class SingleWorkerTest(test.TestCase, parameterized.TestCase): def setUp(self): super(SingleWorkerTest, self).setUp() @@ -55,6 +56,7 @@ class SingleWorkerTest(test.TestCase): # Reset the context to avoid polluting other test cases. context._reset_context() + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionBasic(self): @def_function.function @@ -69,6 +71,7 @@ class SingleWorkerTest(test.TestCase): self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5]) self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionVariable(self): with ops.device('/job:worker/replica:0/task:0/cpu:0'): variable_b = variables.Variable(1) @@ -79,6 +82,7 @@ class SingleWorkerTest(test.TestCase): self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionRemoteOutput(self): with ops.device('/job:worker/replica:0/task:0/cpu:0'): variable_b = variables.Variable(1) @@ -134,6 +138,7 @@ class SingleWorkerTest(test.TestCase): self.assertIn('Dimensions must be equal', cm.exception.message) + @test_util.eager_lazy_remote_copy_on_and_off def testShapeError_Function(self): @def_function.function @@ -150,7 +155,7 @@ class SingleWorkerTest(test.TestCase): self.assertIn('Dimensions must be equal', cm.exception.message) -class MultiWorkersTest(test.TestCase): +class MultiWorkersTest(test.TestCase, parameterized.TestCase): def setUp(self): super(MultiWorkersTest, self).setUp() @@ -167,6 +172,7 @@ class MultiWorkersTest(test.TestCase): # Reset the context to avoid polluting other test cases. context._reset_context() + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionOnLocalDevice(self): with ops.device('/job:worker/replica:0/task:1'): variable_b = variables.Variable(1.0) @@ -180,6 +186,7 @@ class MultiWorkersTest(test.TestCase): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionOnRemoteDevice(self): with ops.device('/job:worker/replica:0/task:1'): variable_b = variables.Variable(1.0) @@ -209,6 +216,7 @@ class MultiWorkersTest(test.TestCase): with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceWhileLoopOnRemoteDevice(self): with ops.device('/job:worker/replica:0/task:1'): variable_b = variables.Variable(1.0) @@ -241,6 +249,7 @@ class MultiWorkersTest(test.TestCase): with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + @test_util.eager_lazy_remote_copy_on_and_off def testSimpleParameterServer(self): with ops.device('/job:worker/task:2/device:CPU:0'): @@ -263,7 +272,7 @@ class MultiWorkersTest(test.TestCase): _GRPC_PREFIX = 'grpc://' -class MultiJobsTest(test.TestCase): +class MultiJobsTest(test.TestCase, parameterized.TestCase): def setUp(self): super(MultiJobsTest, self).setUp() @@ -288,6 +297,7 @@ class MultiJobsTest(test.TestCase): # Reset the context to avoid polluting other test cases. context._reset_context() + @test_util.eager_lazy_remote_copy_on_and_off def testSimpleParameterServer(self): remote.connect_to_cluster(self._cluster) @@ -307,6 +317,7 @@ class MultiJobsTest(test.TestCase): with ops.device('/job:my_worker/task:1/device:CPU:0'): self.assertAllEqual(worker_fn(), 8) + @test_util.eager_lazy_remote_copy_on_and_off def testConnectWithClusterResolver(self): remote.connect_to_cluster(self._cluster_resolver) @@ -325,10 +336,12 @@ class MultiJobsTest(test.TestCase): with ops.device('/job:my_worker/task:1/device:CPU:0'): self.assertAllEqual(worker_fn(), 8) + @test_util.eager_lazy_remote_copy_on_and_off def testConnectToClusterTwiceOk(self): remote.connect_to_cluster(self._cluster_resolver) remote.connect_to_cluster(self._cluster_resolver) + @test_util.eager_lazy_remote_copy_on_and_off def testConnectToClusterOnMismatchedDevice(self): remote.connect_to_cluster(self._cluster_resolver) @@ -338,6 +351,7 @@ class MultiJobsTest(test.TestCase): with self.assertRaises(ValueError): remote.connect_to_cluster(self._cluster_resolver) + @test_util.eager_lazy_remote_copy_on_and_off def testConnectToClusterWithLocalMaster(self): local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local') remote.connect_to_cluster(local_resolver) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index f96a4a58822..cda81a57d29 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1019,6 +1019,21 @@ def build_as_function_and_v1_graph(func=None): return decorator +def eager_lazy_remote_copy_on_and_off(f): + """Execute the test method w/o lazy tensor copy for function remote inputs.""" + + @parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)]) + @functools.wraps(f) + def decorator(self, lazily_remote_copy, *args, **kwargs): + if lazily_remote_copy: + context.context().lazy_remote_inputs_copy = True + else: + context.context().lazy_remote_inputs_copy = False + f(self, *args, **kwargs) + + return decorator + + def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 25106769c15..e3984c37657 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -110,6 +110,7 @@ limitations under the License. %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; %rename("%s") TFE_ContextOptionsSetMirroringPolicy; %rename("%s") TFE_ContextOptionsSetAsync; +%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; %rename("%s") TFE_Py_TensorShapeOnDevice;