Support async execution in ParallelDevice when remote eager is not in use.

PiperOrigin-RevId: 337335617
Change-Id: I4290de7b7af55cd1ac129073255cd24f5f08dda7
This commit is contained in:
Yujing Zhang 2020-10-15 10:25:48 -07:00 committed by TensorFlower Gardener
parent bfcb19d049
commit bb90cdb01c
2 changed files with 9 additions and 5 deletions
tensorflow/c/eager/parallel_device

View File

@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
explicit DeviceThread(const std::string& device, const bool is_async)
: status_(TF_NewStatus()),
device_(device),
// If the context's default exector is set to async, re-using that in
@ -67,7 +67,7 @@ class DeviceThread {
//
// TODO(allenl): We should have an async API that works with the
// parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)),
executor_(TFE_NewExecutor(is_async)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
const bool is_async)
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
new DeviceThread(devices[device_index].c_str(), is_async));
}
}

View File

@ -49,7 +49,10 @@ class DeviceThread;
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
// Eager async execution is only supported when remote eager is not in use
// (b/157523095).
explicit ParallelDevice(const std::vector<std::string>& devices,
const bool is_async = false);
~ParallelDevice();