Use unordered maps since key ordering is not needed; switch to absl::flat_hash_map.

PiperOrigin-RevId: 252507005
commit 4e7bf7f554
parent a61d342d90
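
For reference, absl::flat_hash_map exposes the same core interface as std::unordered_map (operator[], find, erase, iteration), so the swap below is mechanical; the behavioral differences are that iteration order is unspecified and that references into the table are invalidated on rehash. A minimal standalone sketch of the container, with a hypothetical key rather than anything from NcclManager:

    #include <string>

    #include "absl/container/flat_hash_map.h"

    int main() {
      absl::flat_hash_map<std::string, int> counts;

      // Insertion and lookup mirror std::unordered_map exactly.
      counts["collective_key"] = 1;
      if (auto it = counts.find("collective_key"); it != counts.end()) {
        ++it->second;
      }

      // Iteration works as usual, but the order is unspecified; that is the
      // property this change trades away, since these maps never rely on it.
      for (const auto& [key, value] : counts) {
        (void)key;
        (void)value;
      }
      return 0;
    }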
@@ -28,6 +28,7 @@ cc_library(
     ]),
     copts = tf_copts(),
     deps = if_cuda([
+        "@com_google_absl//absl/container:flat_hash_map",
         "@local_config_nccl//:nccl",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -17,7 +17,6 @@ limitations under the License.
 
 #ifdef GOOGLE_CUDA
 
-#include <unordered_map>
 #include <vector>
 
 // TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
@@ -27,6 +26,7 @@ limitations under the License.
 #define gpu_assert(x)
 #endif
 
+#include "absl/container/flat_hash_map.h"
 #include "third_party/nccl/nccl.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -214,12 +214,12 @@ class NcclManager {
   mutex mu_;
 
   // Maps key to collectives currently being assembled or run.
-  std::unordered_map<string, Collective*> collectives_ GUARDED_BY(mu_);
+  absl::flat_hash_map<string, Collective*> collectives_ GUARDED_BY(mu_);
 
   // Maps a device to the communication streams that make up its collective.
   // This is used to share the stream across different communicators that
   // include the same device.
-  std::map<se::StreamExecutor*, std::vector<NcclStream*>>
+  absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
       device_to_comm_streams_ GUARDED_BY(mu_);
 
   std::vector<std::unique_ptr<Communicator>> communicators_;
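
The more consequential of the two member changes is device_to_comm_streams_, which moves from std::map to absl::flat_hash_map: its keys are StreamExecutor pointers, and ordering by raw pointer value carries no meaning, so a hash table gives the same semantics with cheaper lookups. Below is a self-contained sketch of the guarded-map pattern these members follow, using Abseil's public mutex and annotation macros in place of TensorFlow's internal mutex/GUARDED_BY aliases; the class and the stand-in types are hypothetical, not the real NcclManager:

    #include <vector>

    #include "absl/base/thread_annotations.h"
    #include "absl/container/flat_hash_map.h"
    #include "absl/synchronization/mutex.h"

    struct StreamExecutor {};  // Stand-in for se::StreamExecutor.
    struct NcclStream {};      // Stand-in for the per-stream state.

    class StreamTable {
     public:
      // Records a communication stream for `device`, creating the map entry
      // on first use. Pointer keys hash well and need no ordering.
      void AddStream(StreamExecutor* device, NcclStream* stream) {
        absl::MutexLock lock(&mu_);
        device_to_comm_streams_[device].push_back(stream);
      }

      // Copies the streams out under the lock. flat_hash_map does not keep
      // references stable across rehashes, so callers must not hold
      // references into the map after mu_ is released.
      std::vector<NcclStream*> StreamsFor(StreamExecutor* device) const {
        absl::MutexLock lock(&mu_);
        auto it = device_to_comm_streams_.find(device);
        return it == device_to_comm_streams_.end()
                   ? std::vector<NcclStream*>{}
                   : it->second;
      }

     private:
      mutable absl::Mutex mu_;
      absl::flat_hash_map<StreamExecutor*, std::vector<NcclStream*>>
          device_to_comm_streams_ ABSL_GUARDED_BY(mu_);
    };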