Rolling back due to:
https://github.com/tensorflow/tensorflow/issues/41539
https://github.com/tensorflow/tensorflow/issues/41980

Resolves #41539, resolves #41980.

PiperOrigin-RevId: 336736742
Change-Id: Ibcc53f3fbf9c798da95d9bb4fdb62b65ead56d4d
This commit is contained in:
parent
0930645d24
commit
da8b395cf7
@ -43,7 +43,6 @@ cc_library(
|
|||||||
]) + if_cuda_or_rocm([
|
]) + if_cuda_or_rocm([
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:gpu_headers_lib",
|
"//tensorflow/core:gpu_headers_lib",
|
||||||
|
|||||||
@ -632,7 +632,7 @@ void NcclManager::RunCollective(Collective* collective) {
|
|||||||
// Wait to ensure that the kernel that produces the data in the input
|
// Wait to ensure that the kernel that produces the data in the input
|
||||||
// tensor has finished running before the nccl kernel runs on the
|
// tensor has finished running before the nccl kernel runs on the
|
||||||
// communication stream.
|
// communication stream.
|
||||||
nccl_stream->stream->ThenWaitFor(p->input_event.get());
|
nccl_stream->stream->ThenWaitFor(p->tensor_stream);
|
||||||
}
|
}
|
||||||
if (p->root) {
|
if (p->root) {
|
||||||
if (collective->root_rank == -1) {
|
if (collective->root_rank == -1) {
|
||||||
|
|||||||
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/nccl/nccl.h"
|
#include "third_party/nccl/nccl.h"
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
@ -77,7 +76,6 @@ class NcclManager {
|
|||||||
context(static_cast<GPUDeviceContext*>(info->default_context)),
|
context(static_cast<GPUDeviceContext*>(info->default_context)),
|
||||||
#endif
|
#endif
|
||||||
input(input),
|
input(input),
|
||||||
input_event(nullptr),
|
|
||||||
output(output),
|
output(output),
|
||||||
global_rank(global_rank),
|
global_rank(global_rank),
|
||||||
done_callback(std::move(done_callback)),
|
done_callback(std::move(done_callback)),
|
||||||
@ -85,11 +83,6 @@ class NcclManager {
|
|||||||
DCHECK(executor != nullptr);
|
DCHECK(executor != nullptr);
|
||||||
DCHECK(event_mgr != nullptr);
|
DCHECK(event_mgr != nullptr);
|
||||||
DCHECK(tensor_stream != nullptr);
|
DCHECK(tensor_stream != nullptr);
|
||||||
if (input != nullptr) {
|
|
||||||
input_event = absl::make_unique<se::Event>(executor);
|
|
||||||
input_event->Init();
|
|
||||||
tensor_stream->ThenRecordEvent(input_event.get());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamExecutor for the device. Expected to be live for process lifetime.
|
// StreamExecutor for the device. Expected to be live for process lifetime.
|
||||||
@ -118,10 +111,6 @@ class NcclManager {
|
|||||||
// called. Is NULL for participants that only receive data.
|
// called. Is NULL for participants that only receive data.
|
||||||
const Tensor* input;
|
const Tensor* input;
|
||||||
|
|
||||||
// Wait on this event rather than synchronizing on the entire stream.
|
|
||||||
// This allows greater concurrency between compute and nccl streams.
|
|
||||||
std::unique_ptr<se::Event> input_event;
|
|
||||||
|
|
||||||
// Owned by the caller, who must keep it live until `done_callback` is
|
// Owned by the caller, who must keep it live until `done_callback` is
|
||||||
// called. Is NULL for participants that only send data.
|
// called. Is NULL for participants that only send data.
|
||||||
Tensor* output;
|
Tensor* output;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user