Use device-private thread pool to launch ops when gpu_private thread mode is turned on in the eager runtime.
When the environment variables `TF_GPU_THREAD_MODE=gpu_private TF_GPU_THREAD_COUNT=2` are set, TensorFlow allocates separate thread pools for GPU devices to reduce contention on GPU kernel launching. Though this is not an official TF API, it has been important for achieving good performance.

This CL exposes the FunctionLibraryRuntime's internal runner (which is initialized with the correct device-private thread pool when one is available) to the kernel and device ops. When such an instance is used to run, it launches ops using the specified runner instead of the default one in EagerContext.

PiperOrigin-RevId: 237132982
parent 11599b4a8b
commit 3850ab8a0d
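For a concrete picture of the mechanism the message describes, the following is a minimal sketch in plain C++ (not TensorFlow code) of what a device-private runner is: a std::function<void(std::function<void()>)> that schedules closures on a pool reserved for one device. PrivateThreadPool and Runner are illustrative names, not TF types.

// Minimal sketch (not TensorFlow code). PrivateThreadPool stands in for the
// per-device pool created when TF_GPU_THREAD_MODE=gpu_private is set.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

using Runner = std::function<void(std::function<void()>)>;

class PrivateThreadPool {
 public:
  explicit PrivateThreadPool(int num_threads) {
    for (int i = 0; i < num_threads; ++i)
      workers_.emplace_back([this] { WorkerLoop(); });
  }
  ~PrivateThreadPool() {
    {
      std::lock_guard<std::mutex> l(mu_);
      done_ = true;
    }
    cv_.notify_all();
    for (auto& t : workers_) t.join();
  }
  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> l(mu_);
      queue_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  void WorkerLoop() {
    while (true) {
      std::function<void()> fn;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return done_ || !queue_.empty(); });
        if (queue_.empty()) return;  // done_ was set and the queue is drained.
        fn = std::move(queue_.front());
        queue_.pop();
      }
      fn();  // Run the closure on a worker owned exclusively by this pool.
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> queue_;
  std::vector<std::thread> workers_;
  bool done_ = false;
};

int main() {
  // TF_GPU_THREAD_COUNT=2: two threads reserved for launching this device's kernels.
  PrivateThreadPool gpu_private_pool(2);
  Runner device_runner = [&](std::function<void()> fn) {
    gpu_private_pool.Schedule(std::move(fn));
  };
  // An op launch goes through the device-private runner instead of a shared
  // pool, so it does not contend with other host-side work.
  device_runner([] { /* a GPU kernel launch would happen here */ });
}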
@@ -401,6 +401,7 @@ Status EagerLocalExecute(EagerOperation* op,
         "Unable to find a FunctionLibraryRuntime corresponding to device ",
         device->name());
   }
+  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx->runner();
   GraphCollector* graph_collector = nullptr;
   if (ctx->ShouldStoreGraphs()) {
     graph_collector = ctx->GetGraphCollector();
@@ -418,14 +419,14 @@ Status EagerLocalExecute(EagerOperation* op,
             << "compile_with_xla=" << compile_with_xla
             << ". Full node_def=" << ndef.DebugString();
     kernel = new KernelAndDeviceFunc(
-        flr, ctx->pflr(), std::move(input_dev_ptrs), ctx->runner(),
+        flr, ctx->pflr(), std::move(input_dev_ptrs), runner,
         ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
   } else {
     VLOG(2) << "Running " << ndef.op() << " using op kernel. "
             << "compile_with_xla=" << compile_with_xla
             << ". Full node_def=" << ndef.DebugString();
     kernel = new KernelAndDeviceOp(
-        ctx->GetRendezvous(), ctx->LogMemory(), flr, ctx->runner(),
+        ctx->GetRendezvous(), ctx->LogMemory(), flr, runner,
         ctx->GetCollectiveExecutorHandle(), ctx->HostCPU());
   }

@@ -192,6 +192,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {

   Env* env() override;
   Device* device() override;
+  std::function<void(std::function<void()>)>* runner() override;
   const DeviceMgr* device_mgr() const override;

   string DebugString(Handle handle) override;
@@ -266,6 +267,11 @@ Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

 Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

+std::function<void(std::function<void()>)>*
+FunctionLibraryRuntimeOverlay::runner() {
+  return base_flr_->runner();
+}
+
 const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
   return base_flr_->device_mgr();
 }
@@ -333,6 +339,11 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   }

   Device* device() override { return device_; }
+
+  std::function<void(std::function<void()>)>* runner() override {
+    return &default_runner_;
+  }
+
   const DeviceMgr* device_mgr() const override { return device_mgr_; }
   Env* env() override { return env_; }
   int graph_def_version() override { return graph_def_version_; }
@@ -661,6 +661,11 @@ class FunctionLibraryRuntime {
   // Returns the device on which the function executes.
   virtual Device* device() = 0;

+  // Returns the default runner in which the ops should be launched. If the
+  // device on which the function executes has a private thread pool, return
+  // runner on the device local thread pool.
+  virtual std::function<void(std::function<void()>)>* runner() = 0;
+
   // Get the DeviceMgr from which the device was obtained.
   virtual const DeviceMgr* device_mgr() const = 0;

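To show how the runner() contract added above is meant to be consumed, here is a minimal caller-side sketch (again plain C++, not TensorFlow code) of the fallback the commit introduces in EagerLocalExecute: prefer the FunctionLibraryRuntime's runner when it has one, otherwise use the context-wide default. FakeFlr, default_runner, and private_runner are hypothetical stand-ins for FunctionLibraryRuntime, EagerContext::runner(), and the device-private pool's runner.

// Caller-side sketch of the fallback pattern (not TensorFlow code).
#include <functional>
#include <iostream>

using Runner = std::function<void(std::function<void()>)>;

struct FakeFlr {
  // Mirrors the new virtual: returns nullptr when the device has no private
  // thread pool, a non-null runner otherwise.
  Runner* runner() { return private_runner_; }
  Runner* private_runner_ = nullptr;
};

int main() {
  Runner default_runner = [](std::function<void()> fn) { fn(); };  // shared-pool stand-in
  Runner private_runner = [](std::function<void()> fn) {
    std::cout << "scheduling on device-private pool\n";
    fn();
  };

  FakeFlr flr;
  flr.private_runner_ = &private_runner;  // pretend gpu_private mode is on

  // Same selection the commit adds in EagerLocalExecute:
  //   auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx->runner();
  Runner* runner = (flr.runner() != nullptr) ? flr.runner() : &default_runner;
  (*runner)([] { std::cout << "op launched\n"; });
}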