Parallel device: sync executors after each parallel op
PiperOrigin-RevId: 314396769 Change-Id: I4697f3e488d351d8610cfb80da1701e2fa24848e
This commit is contained in:
parent
123f576851
commit
fb6e0c1cdd
@@ -319,11 +319,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
   std::vector<MaybeParallelTensorOwned> outputs;
   outputs.reserve(t->num_tensors());
   for (int i = 0; i < t->num_tensors(); ++i) {
-    // TODO(b/157523095): Syncing the executor here shouldn't be
-    // necessary. Currently async+remote is missing cross-executor
-    // coordination.
-    TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status);
-    if (TF_GetCode(status) != TF_OK) return result;
     TensorHandlePtr this_output(
         TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
     outputs.emplace_back(std::move(this_output));
@@ -438,6 +433,15 @@ ParallelDevice::ExecuteParallelOperation(
     }
     per_device_output_tensors.push_back(std::move(this_outputs));
   }
+  for (int device_index = 0; device_index < underlying_devices_.size();
+       ++device_index) {
+    TFE_Executor* executor = executors_[device_index].get();
+    // TODO(b/157523095): Syncing the executor here shouldn't be
+    // necessary. Currently async+remote is missing cross-executor
+    // coordination.
+    TFE_ExecutorWaitForAllPendingNodes(executor, status);
+    if (TF_GetCode(status) != TF_OK) return result;
+  }
   // For each output of the original operation, pack the per-device
   // TensorHandles we've computed into a single parallel TensorHandle.
   std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
|
Loading…
Reference in New Issue
Block a user