From d5b3ec27d1d6bb157588ff3033a3d9bd2e46711f Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Tue, 12 May 2020 23:14:53 -0700 Subject: [PATCH] Allow dynamically configuring device placement Enable setting soft device placement as well as logging dynamically. This required ensuring the device placement policy was part of the cache key. Further, we fix the logging to ensure in eager mode if a kernel is retrieved from the kernel cache, then the execution is still logged. We also log closer to the actual op execution to avoid logging before all checks have been done. PiperOrigin-RevId: 311271808 Change-Id: I9765228894f84a3447cc03332a2559f6d933165b --- tensorflow/c/eager/c_api_experimental.cc | 14 ++++++ tensorflow/c/eager/c_api_experimental.h | 12 +++++ .../core/common_runtime/eager/context.h | 7 +-- .../core/common_runtime/eager/execute.cc | 45 ++++++++++++------- tensorflow/python/client/session_test.py | 11 ++++- tensorflow/python/eager/context.py | 16 +++---- tensorflow/python/eager/core_test.py | 1 - tensorflow/python/framework/config_test.py | 16 +++---- tensorflow/python/tfe_wrapper.cc | 12 +++++ 9 files changed, 95 insertions(+), 39 deletions(-) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index dd9e5e111d9..0d71b11531b 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, std::move(tensor_handles), context, &handle); return tensorflow::wrap(handle); } + +void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetAllowSoftPlacement(enable); +} + +void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetLogDevicePlacement(enable); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 584f7222111..1b8efe61ee0 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, TF_Status* status); +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 683425919d1..d034aaf2f9c 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -300,7 +300,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() const { return log_device_placement_; } + void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; } bool AllowSoftPlacement() const { return allow_soft_placement_; } + void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; } bool LogMemory() const { return log_memory_; } Rendezvous* GetRendezvous() const { return rendezvous_; } @@ -625,9 +627,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { mutex metadata_mu_; RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_); GraphCollector graph_collector_; - // TODO(fishx): Allow update following two bool after context creation. - const bool log_device_placement_; - const bool allow_soft_placement_; + std::atomic log_device_placement_; + std::atomic allow_soft_placement_; // Information related to step containers. std::atomic num_active_steps_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 3036e6d7989..f6b4370bbdc 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -365,6 +365,9 @@ Status GetOrCreateKernelAndDevice( Device* device = absl::get(op->Device()); Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName()); + /// Include soft placement policy in cache key since the placement strategy + // can change and thus affect which kernel is picked. + cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement()); std::vector input_dev_ptrs; absl::flat_hash_map*> composite_devices; @@ -488,13 +491,6 @@ Status GetOrCreateKernelAndDevice( << KernelsRegisteredForOp(op->Name()); op->SetDevice(device); } - if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { - string msg = strings::StrCat("Executing op ", ndef.op(), " in device ", - DeviceNameOrUnspecified(device)); - if (!logging::LogToListeners(msg)) { - LOG(INFO) << msg; - } - } FunctionLibraryRuntime* flr = device == nullptr ? nullptr : ctx.func_lib(device); @@ -607,6 +603,14 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, int num_outputs = kernel->num_outputs(); TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel)); + if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { + string msg = strings::StrCat("Executing op ", op->Name(), " in device ", + kernel->device()->name()); + if (!logging::LogToListeners(msg)) { + LOG(INFO) << msg; + } + } + GraphCollector* graph_collector = nullptr; if (ctx.ShouldStoreGraphs()) { graph_collector = ctx.GetGraphCollector(); @@ -841,6 +845,16 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, ctx.GetContextViewId(), eager_client.get(), op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), op->Inputs(), {retvals, num_outputs})); + + if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { + string msg = strings::StrCat( + "Executing op ", op->Name(), " on task ", + DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName())); + if (!logging::LogToListeners(msg)) { + LOG(INFO) << msg; + } + } + Status s = executor.AddOrExecute(std::move(node)); // Since the operation failed, we need to Unref any outputs that were // allocated. @@ -1119,15 +1133,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals, return EagerLocalExecute(op, retvals, num_retvals); } - if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { - string msg = strings::StrCat( - "Executing op ", op->Name(), " on task ", - DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName())); - if (!logging::LogToListeners(msg)) { - LOG(INFO) << msg; - } - } - #if defined(IS_MOBILE_PLATFORM) return errors::Unimplemented( "Eager's remote execution is not available on mobile devices."); @@ -1428,6 +1433,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, return; } + if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { + string msg = strings::StrCat("Executing op ", op->Name(), " in device ", + kernel->device()->name()); + if (!logging::LogToListeners(msg)) { + LOG(INFO) << msg; + } + } + GraphCollector* graph_collector = nullptr; if (ctx.ShouldStoreGraphs()) { graph_collector = ctx.GetGraphCollector(); diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index dd8e64ac182..1c244c1b297 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1917,6 +1917,9 @@ class SessionTest(test_util.TensorFlowTestCase): a = constant_op.constant(1) b = constant_op.constant(2) c = a + b + # Ensure if the same kernel with the same arguments is executed then its + # execution is logged. + d = a + b else: # Passing the config to the server, but not the session should still # result in logging device placement. @@ -1925,12 +1928,16 @@ class SessionTest(test_util.TensorFlowTestCase): a = constant_op.constant(1) b = constant_op.constant(2) c = a + b + d = a + b with session.Session(server.target) as sess: with CaptureStderr() as log: - sess.run(c) + c, d = sess.run([c, d]) + self.assertEqual(c, 3) + self.assertEqual(d, 3) # Ensure that we did log device placement. - self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log)) + add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] + self.assertEqual(len(add_executions), 2) @test_util.run_v1_only('b/120545219') def testLocalMasterSessionTimeout(self): diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 182b8478420..86b3d5cf95f 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1509,9 +1509,11 @@ class Context(object): return self.config.allow_soft_placement @soft_device_placement.setter - def soft_device_placement(self, enabled): - self._soft_device_placement = enabled + def soft_device_placement(self, enable): + if self._context_handle is not None: + pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable) + self._soft_device_placement = enable self._thread_local_data.function_call_options = None @property @@ -1519,15 +1521,11 @@ class Context(object): return self.config.log_device_placement @log_device_placement.setter - def log_device_placement(self, enabled): - if self._log_device_placement == enabled: - return - + def log_device_placement(self, enable): if self._context_handle is not None: - raise RuntimeError( - "Device placement logging must be set at program startup") + pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable) - self._log_device_placement = enabled + self._log_device_placement = enable self._thread_local_data.function_call_options = None @property diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 47b3966827f..c1401fc56ee 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -1112,5 +1112,4 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase): if __name__ == '__main__': - context.set_log_device_placement(True) test.main() diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index b07bb874385..3051f1d0623 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -159,7 +159,6 @@ class ConfigTest(test.TestCase, parameterized.TestCase): else: self.assertFalse(config.get_soft_device_placement()) - @def_function.function def mod(): with ops.device('/device:GPU:0'): a = constant_op.constant(1.0) @@ -172,8 +171,10 @@ class ConfigTest(test.TestCase, parameterized.TestCase): config.get_soft_device_placement(), context.context().soft_device_placement) - # Since soft placement is enabled, the mod operation should work with CPU + # Since soft placement is enabled, the mod operation should fallback to CPU + # with pure eager execution as well as functions mod() + def_function.function(mod)() config.set_soft_device_placement(False) self.assertEqual(config.get_soft_device_placement(), False) @@ -182,8 +183,11 @@ class ConfigTest(test.TestCase, parameterized.TestCase): context.context().soft_device_placement) # Since soft placement is disabled, the mod operation should fail on GPU + # with pure eager execution as well as functions with self.assertRaises(errors.InvalidArgumentError): mod() + with self.assertRaises(errors.InvalidArgumentError): + def_function.function(mod)() @reset_eager def testLogDevicePlacement(self): @@ -203,12 +207,8 @@ class ConfigTest(test.TestCase, parameterized.TestCase): context.ensure_initialized() - with self.assertRaises(RuntimeError): - context.set_log_device_placement(True) - - # If the setting the device placement is a no-op, do not throw a runtime - # exception. - context.set_log_device_placement(False) + # Changing the device placement should not throw an exception + context.set_log_device_placement(True) @reset_eager def testEnableMlirBridge(self): diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index ec54efa61cf..836cafbd494 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -488,6 +488,18 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // NOTE: different from TFE_ContextSyncExecutors that raises potential // errors, deliberately ignore executor statuses in cleanup. }); + m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, + status.get()); + }); + m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, + status.get()); + }); // TFE_Executor logic m.def(