diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index dd9e5e111d9..0d71b11531b 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
       std::move(tensor_handles), context, &handle);
   return tensorflow::wrap(handle);
 }
+
+void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                       TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetAllowSoftPlacement(enable);
+}
+
+void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                      TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetLogDevicePlacement(enable);
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 584f7222111..1b8efe61ee0 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
     TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
     TF_Status* status);
 
+// Configure soft device placement policy for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
+                                                      unsigned char enable,
+                                                      TF_Status* status);
+
+// Configure device placement policy logging for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
+                                                     unsigned char enable,
+                                                     TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 683425919d1..d034aaf2f9c 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -300,7 +300,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
 
   bool LogDevicePlacement() const { return log_device_placement_; }
+  void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
   bool AllowSoftPlacement() const { return allow_soft_placement_; }
+  void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
   bool LogMemory() const { return log_memory_; }
 
   Rendezvous* GetRendezvous() const { return rendezvous_; }
@@ -625,9 +627,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   mutex metadata_mu_;
   RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
   GraphCollector graph_collector_;
-  // TODO(fishx): Allow update following two bool after context creation.
-  const bool log_device_placement_;
-  const bool allow_soft_placement_;
+  std::atomic<bool> log_device_placement_;
+  std::atomic<bool> allow_soft_placement_;
 
   // Information related to step containers.
   std::atomic<int> num_active_steps_;
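Taken together, the two new C API entry points and the switch from const bool to std::atomic<bool> in EagerContext make soft device placement and device placement logging mutable after the eager context has been created. Below is a minimal Python sketch of the intended user-facing behavior, assuming the public tf.config / tf.debugging wrappers route through these setters once the context exists (which is what the context.py change further down wires up via pywrap_tfe):

import tensorflow as tf

# Any eager op forces the context to be initialized.
x = tf.constant(1.0) + tf.constant(2.0)

# Previously, changing log device placement after the context existed raised a
# RuntimeError; with the new setters both flags apply to subsequent executions.
tf.config.set_soft_device_placement(True)
tf.debugging.set_log_device_placement(True)

y = tf.constant(3.0) + tf.constant(4.0)  # this execution gets logged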
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 3036e6d7989..f6b4370bbdc 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -365,6 +365,9 @@ Status GetOrCreateKernelAndDevice(
   Device* device = absl::get<Device*>(op->Device());
 
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
+  // Include soft placement policy in cache key since the placement strategy
+  // can change and thus affect which kernel is picked.
+  cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
 
   std::vector<Device*> input_dev_ptrs;
   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
@@ -488,13 +491,6 @@ Status GetOrCreateKernelAndDevice(
                << KernelsRegisteredForOp(op->Name());
       op->SetDevice(device);
     }
-    if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
-      string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
-                                   DeviceNameOrUnspecified(device));
-      if (!logging::LogToListeners(msg)) {
-        LOG(INFO) << msg;
-      }
-    }
 
     FunctionLibraryRuntime* flr =
         device == nullptr ? nullptr : ctx.func_lib(device);
@@ -607,6 +603,14 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   int num_outputs = kernel->num_outputs();
   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
 
+  if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
+                                 kernel->device()->name());
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   GraphCollector* graph_collector = nullptr;
   if (ctx.ShouldStoreGraphs()) {
     graph_collector = ctx.GetGraphCollector();
@@ -841,6 +845,16 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
       ctx.GetContextViewId(), eager_client.get(),
       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
       op->Inputs(), {retvals, num_outputs}));
+
+  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat(
+        "Executing op ", op->Name(), " on task ",
+        DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
@@ -1119,15 +1133,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
     return EagerLocalExecute(op, retvals, num_retvals);
   }
 
-  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
-    string msg = strings::StrCat(
-        "Executing op ", op->Name(), " on task ",
-        DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
-    if (!logging::LogToListeners(msg)) {
-      LOG(INFO) << msg;
-    }
-  }
-
 #if defined(IS_MOBILE_PLATFORM)
   return errors::Unimplemented(
       "Eager's remote execution is not available on mobile devices.");
@@ -1428,6 +1433,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
     return;
   }
 
+  if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
+    string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
+                                 kernel->device()->name());
+    if (!logging::LogToListeners(msg)) {
+      LOG(INFO) << msg;
+    }
+  }
+
   GraphCollector* graph_collector = nullptr;
   if (ctx.ShouldStoreGraphs()) {
     graph_collector = ctx.GetGraphCollector();
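Two behavioral points fall out of the execute.cc changes above: the soft placement flag is folded into the kernel cache key, so toggling it between executions of the same op cannot silently reuse a kernel placed under the other policy, and the "Executing op" message is now emitted on every execution path (local, async, and remote) instead of only when a kernel is first created and cached. A rough sketch of the cache-key effect, mirroring the config_test.py change later in this diff; it assumes a host with a visible GPU and relies, as that test does, on float mod having no GPU kernel:

import tensorflow as tf

def mod_on_gpu():
  with tf.device('/device:GPU:0'):
    return tf.constant(1.0) % tf.constant(1.0)

tf.config.set_soft_device_placement(True)
mod_on_gpu()  # soft placement falls back to a CPU kernel and succeeds

tf.config.set_soft_device_placement(False)
# The changed flag produces a different cache key, so the CPU kernel cached
# above is not reused; the op now fails to place on the GPU.
try:
  mod_on_gpu()
except tf.errors.InvalidArgumentError as e:
  print('placement failed as expected:', e)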
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index dd8e64ac182..1c244c1b297 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1917,6 +1917,9 @@ class SessionTest(test_util.TensorFlowTestCase):
         a = constant_op.constant(1)
         b = constant_op.constant(2)
         c = a + b
+        # Ensure that if the same kernel is executed again with the same
+        # arguments, its execution is still logged.
+        d = a + b
     else:
       # Passing the config to the server, but not the session should still
       # result in logging device placement.
@@ -1925,12 +1928,16 @@ class SessionTest(test_util.TensorFlowTestCase):
       a = constant_op.constant(1)
       b = constant_op.constant(2)
       c = a + b
+      d = a + b
 
       with session.Session(server.target) as sess:
         with CaptureStderr() as log:
-          sess.run(c)
+          c, d = sess.run([c, d])
+          self.assertEqual(c, 3)
+          self.assertEqual(d, 3)
 
     # Ensure that we did log device placement.
-    self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log))
+    add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
+    self.assertEqual(len(add_executions), 2)
 
   @test_util.run_v1_only('b/120545219')
   def testLocalMasterSessionTimeout(self):
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 182b8478420..86b3d5cf95f 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -1509,9 +1509,11 @@ class Context(object):
     return self.config.allow_soft_placement
 
   @soft_device_placement.setter
-  def soft_device_placement(self, enabled):
-    self._soft_device_placement = enabled
+  def soft_device_placement(self, enable):
+    if self._context_handle is not None:
+      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
+
+    self._soft_device_placement = enable
     self._thread_local_data.function_call_options = None
 
   @property
@@ -1519,15 +1521,11 @@ class Context(object):
     return self.config.log_device_placement
 
   @log_device_placement.setter
-  def log_device_placement(self, enabled):
-    if self._log_device_placement == enabled:
-      return
-
+  def log_device_placement(self, enable):
     if self._context_handle is not None:
-      raise RuntimeError(
-          "Device placement logging must be set at program startup")
+      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
 
-    self._log_device_placement = enabled
+    self._log_device_placement = enable
     self._thread_local_data.function_call_options = None
 
   @property
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 47b3966827f..c1401fc56ee 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -1112,5 +1112,4 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase):
 
 
 if __name__ == '__main__':
-  context.set_log_device_placement(True)
   test.main()
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index b07bb874385..3051f1d0623 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -159,7 +159,6 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
     else:
       self.assertFalse(config.get_soft_device_placement())
 
-    @def_function.function
     def mod():
       with ops.device('/device:GPU:0'):
         a = constant_op.constant(1.0)
@@ -172,8 +171,10 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
         config.get_soft_device_placement(),
         context.context().soft_device_placement)
 
-    # Since soft placement is enabled, the mod operation should work with CPU
+    # Since soft placement is enabled, the mod operation should fall back to
+    # the CPU with pure eager execution as well as within functions.
    mod()
+    def_function.function(mod)()
 
     config.set_soft_device_placement(False)
     self.assertEqual(config.get_soft_device_placement(), False)
@@ -182,8 +183,11 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
         context.context().soft_device_placement)
 
     # Since soft placement is disabled, the mod operation should fail on GPU
+    # with pure eager execution as well as within functions.
     with self.assertRaises(errors.InvalidArgumentError):
       mod()
+    with self.assertRaises(errors.InvalidArgumentError):
+      def_function.function(mod)()
 
   @reset_eager
   def testLogDevicePlacement(self):
@@ -203,12 +207,8 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
 
     context.ensure_initialized()
 
-    with self.assertRaises(RuntimeError):
-      context.set_log_device_placement(True)
-
-    # If the setting the device placement is a no-op, do not throw a runtime
-    # exception.
-    context.set_log_device_placement(False)
+    # Changing the device placement logging should not raise an exception.
+    context.set_log_device_placement(True)
 
   @reset_eager
   def testEnableMlirBridge(self):
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index ec54efa61cf..836cafbd494 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -488,6 +488,18 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     // NOTE: different from TFE_ContextSyncExecutors that raises potential
     // errors, deliberately ignore executor statuses in cleanup.
   });
+  m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
+                                      status.get());
+  });
+  m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
+                                     status.get());
+  });
 
   // TFE_Executor logic
   m.def(
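With logging moved to the per-execution paths, repeated executions of an op are each logged even when the kernel is served from the cache, which is what the updated session_test.py assertion on two AddV2 lines checks. A small eager-mode sketch of the same behavior (the exact device string in the log depends on the host and available devices):

import tensorflow as tf

tf.debugging.set_log_device_placement(True)

a = tf.constant(1)
b = tf.constant(2)
# Each line below produces its own log message of the form
#   Executing op AddV2 in device /job:localhost/replica:0/task:0/device:CPU:0
# even though the second execution reuses the cached kernel.
c = a + b
d = a + b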