diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD
index 92a0fea8578..98278e3ae9d 100644
--- a/tensorflow/core/nccl/BUILD
+++ b/tensorflow/core/nccl/BUILD
@@ -43,7 +43,6 @@ cc_library(
     ]) + if_cuda_or_rocm([
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index 43c9b229450..157c255a316 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -632,7 +632,7 @@ void NcclManager::RunCollective(Collective* collective) {
       // Wait to ensure that the kernel that produces the data in the input
       // tensor has finished running before the nccl kernel runs on the
       // communication stream.
-      nccl_stream->stream->ThenWaitFor(p->input_event.get());
+      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
     }
     if (p->root) {
       if (collective->root_rank == -1) {
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index b91ef86042a..88b8bc85663 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -27,7 +27,6 @@ limitations under the License.
 #endif
 
 #include "absl/container/flat_hash_map.h"
-#include "absl/memory/memory.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
@@ -77,7 +76,6 @@ class NcclManager {
           context(static_cast<GPUDeviceContext*>(info->default_context)),
 #endif
           input(input),
-          input_event(nullptr),
           output(output),
           global_rank(global_rank),
           done_callback(std::move(done_callback)),
@@ -85,11 +83,6 @@
       DCHECK(executor != nullptr);
       DCHECK(event_mgr != nullptr);
       DCHECK(tensor_stream != nullptr);
-      if (input != nullptr) {
-        input_event = absl::make_unique<se::Event>(executor);
-        input_event->Init();
-        tensor_stream->ThenRecordEvent(input_event.get());
-      }
     }
 
     // StreamExecutor for the device. Expected to be live for process lifetime.
@@ -118,10 +111,6 @@
     // called. Is NULL for participants that only receive data.
     const Tensor* input;
 
-    // Wait on this event rather than synchronizing on the entire stream.
-    // This allows greater concurrency between compute and nccl streams.
-    std::unique_ptr<se::Event> input_event;
-
     // Owned by the caller, who must keep it live until `done_callback` is
     // called. Is NULL for participants that only send data.
     Tensor* output;
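
Note: the functional change here is how the nccl communication stream synchronizes with the compute stream that produces the input tensor. Previously, each Participant with an input recorded an se::Event on tensor_stream at construction time, and RunCollective made the nccl stream wait only on that event; this patch drops the per-participant event and waits on the compute stream itself (the absl/memory BUILD and include removals just drop the now-unused absl::make_unique). Below is a minimal sketch contrasting the two patterns, assuming the pre-2023 StreamExecutor API shown in the diff; executor, compute_stream, nccl_stream, and the two function names are illustrative placeholders, not the actual NcclManager wiring.

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

// Old scheme: record an event at the exact point the input tensor is ready,
// then have the nccl stream wait only up to that event. Work enqueued on the
// compute stream after the event is not serialized against the collective.
// The returned event must stay alive until the collective completes; the
// removed code kept it on the Participant for exactly that reason.
std::unique_ptr<se::Event> WaitViaEvent(se::StreamExecutor* executor,
                                        se::Stream* compute_stream,
                                        se::Stream* nccl_stream) {
  auto input_event = absl::make_unique<se::Event>(executor);
  input_event->Init();                                 // allocate the backing GPU event
  compute_stream->ThenRecordEvent(input_event.get());  // mark the "input ready" point
  nccl_stream->ThenWaitFor(input_event.get());         // wait only up to that point
  return input_event;
}

// New scheme (this patch): the nccl stream waits on everything enqueued on
// the compute stream so far. No per-participant event to allocate and keep
// alive, at the cost of the extra compute/nccl concurrency the removed
// comment described.
void WaitViaStream(se::Stream* compute_stream, se::Stream* nccl_stream) {
  nccl_stream->ThenWaitFor(compute_stream);
}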