diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD
index 7740131b7c0..244a84e948f 100644
--- a/tensorflow/core/nccl/BUILD
+++ b/tensorflow/core/nccl/BUILD
@@ -28,6 +28,7 @@ cc_library(
     ]),
     copts = tf_copts(),
     deps = if_cuda([
+        "@com_google_absl//absl/container:flat_hash_map",
        "@local_config_nccl//:nccl",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index d968fac833b..ebb2aab44e0 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -17,7 +17,6 @@ limitations under the License.
 
 #ifdef GOOGLE_CUDA
 
-#include <unordered_map>
 #include <vector>
 
 // TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
@@ -27,6 +26,7 @@ limitations under the License.
 #define gpu_assert(x)
 #endif
 
+#include "absl/container/flat_hash_map.h"
 #include "third_party/nccl/nccl.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -214,12 +214,12 @@ class NcclManager {
   mutex mu_;
 
   // Maps key to collectives currently being assembled or run.
-  std::unordered_map<string, Collective*> collectives_ GUARDED_BY(mu_);
+  absl::flat_hash_map<string, Collective*> collectives_ GUARDED_BY(mu_);
 
   // Maps a device to the communication streams that make up its collective.
   // This is used to share the stream across different communicators that
   // include the same device.
-  std::map<se::StreamExecutor*, std::vector<NcclStream*>>
+  absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
       device_to_comm_streams_ GUARDED_BY(mu_);
 
   std::vector<std::unique_ptr<Communicator>> communicators_;
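
The patch swaps `std::unordered_map`/`std::map` for `absl::flat_hash_map` in `NcclManager`. The sketch below (not part of the patch) illustrates why this is usually a drop-in change for the keyed-lookup pattern used here; the `Collective` struct and the key strings are placeholders, not the real TensorFlow types.

```cpp
// Minimal sketch of the lookup pattern, assuming Abseil is available.
// "Collective" is a stand-in for the TensorFlow-internal struct.
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"

struct Collective {
  int num_devices = 0;  // placeholder field
};

int main() {
  // Keyed by collective name, mirroring the string-keyed collectives_ map.
  absl::flat_hash_map<std::string, std::unique_ptr<Collective>> collectives;

  collectives["all_reduce_0"] = std::make_unique<Collective>();
  collectives["all_reduce_0"]->num_devices = 4;

  // find()/end() and operator[] follow the std::unordered_map API, so
  // existing call sites typically compile unchanged after the swap.
  auto it = collectives.find("all_reduce_0");
  if (it != collectives.end()) {
    it->second->num_devices += 1;
  }
  return 0;
}
```

One behavioral difference worth noting: `absl::flat_hash_map` uses open addressing and may move elements on rehash, so pointers and iterators into the table are not stable the way they are for `std::unordered_map`. Since both maps in `NcclManager` store pointers (or vectors of pointers) as values, the pointed-to objects themselves stay put, which keeps the swap safe.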