From e4df992d2e10a02f3aea0522699e641d4de359b1 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Mon, 18 Mar 2019 18:19:37 -0700 Subject: [PATCH] Fill the `runner` field in FLR options with the device-private runner. When running multi-device eager functions, pass the device-private runner (if available) to launch internal ops. This ensures that GPU kernel-launching ops are executed in the GPU-private thread pool. PiperOrigin-RevId: 239103306 --- .../common_runtime/process_function_library_runtime.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 7cc6bca4afd..3e14305725d 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -790,11 +790,16 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_; opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_; opts_copy.remote_execution = false; + + FunctionLibraryRuntime* flr = GetFLR(target); + // When target device has private thread pool, use the target device runner + thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool(); + opts_copy.runner = (pool == nullptr) ? opts_copy.runner : flr->runner(); std::vector comp_args = GetArgsForIndices(comp_data.arg_indices_, args); std::vector* comp_rets = new std::vector; rets->resize(data->num_outputs_); - GetFLR(target)->Run( + flr->Run( opts_copy, handle, comp_args, comp_rets, [comp_rets, rets, comp_data, refcounted_done](const Status& status) { if (!status.ok()) {