Fill the `runner` field in FLR options with the device-private runner.
When running multi-device eager functions, pass the device-private runner (if available) to launch internal ops. This ensures that GPU kernel-launching ops are executed in the GPU's private thread pool.

PiperOrigin-RevId: 239103306
This commit is contained in:
parent
6e9cb400d1
commit
e4df992d2e
@ -790,11 +790,16 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
|||||||
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
|
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
|
||||||
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
|
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
|
||||||
opts_copy.remote_execution = false;
|
opts_copy.remote_execution = false;
|
||||||
|
|
||||||
|
FunctionLibraryRuntime* flr = GetFLR(target);
|
||||||
|
// When target device has private thread pool, use the target device runner
|
||||||
|
thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
|
||||||
|
opts_copy.runner = (pool == nullptr) ? opts_copy.runner : flr->runner();
|
||||||
std::vector<Tensor> comp_args =
|
std::vector<Tensor> comp_args =
|
||||||
GetArgsForIndices(comp_data.arg_indices_, args);
|
GetArgsForIndices(comp_data.arg_indices_, args);
|
||||||
std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
|
std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
|
||||||
rets->resize(data->num_outputs_);
|
rets->resize(data->num_outputs_);
|
||||||
GetFLR(target)->Run(
|
flr->Run(
|
||||||
opts_copy, handle, comp_args, comp_rets,
|
opts_copy, handle, comp_args, comp_rets,
|
||||||
[comp_rets, rets, comp_data, refcounted_done](const Status& status) {
|
[comp_rets, rets, comp_data, refcounted_done](const Status& status) {
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user