110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
# Description:
#   Wraps NVIDIA NCCL (https://github.com/NVIDIA/nccl) with TensorFlow ops.
#   APIs are meant to change over time.
|
|
|
|
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
|
|
|
# buildifier: disable=same-origin-load
|
|
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
|
|
|
# buildifier: disable=same-origin-load
|
|
load("//tensorflow:tensorflow.bzl", "filegroup")
|
|
load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm", "tf_copts")
|
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
|
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
|
load(
|
|
"//tensorflow/core/platform:build_config_root.bzl",
|
|
"tf_cuda_tests_tags",
|
|
)
|
|
|
|
# buildifier: disable=same-origin-load
|
|
load("//tensorflow:tensorflow.bzl", "if_nccl")
|
|
|
|
# Package-level defaults: every target in this BUILD file is visible to all
# packages under //tensorflow unless the target sets its own `visibility`.
package(
    default_visibility = ["//tensorflow:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)
|
|
|
|
# NCCL manager library.
#
# All sources, headers and dependencies are supplied through the
# if_cuda / if_rocm / if_cuda_or_rocm macros loaded by this file, so the
# contents of this target depend on the selected GPU configuration.
cc_library(
    name = "nccl_lib",
    srcs = if_cuda_or_rocm([
        "nccl_manager.cc",
        "nccl_rewrite.cc",
    ]),
    hdrs = if_cuda_or_rocm([
        "nccl_manager.h",
    ]),
    copts = tf_copts(),
    deps = if_cuda([
        # Collective library used for CUDA builds.
        "@local_config_nccl//:nccl",
    ]) + if_rocm([
        # Collective library (RCCL) plus GPU runtime used for ROCm builds.
        "@local_config_rocm//rocm:rccl",
        "//tensorflow/core:gpu_runtime",
    ]) + if_cuda_or_rocm([
        # Dependencies shared by both GPU configurations.
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:stream_executor",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/profiler/lib:connected_traceme",
        "//tensorflow/core/profiler/lib:annotated_traceme",
    ]),
    # Link every object file of this library into dependents, even when no
    # symbol is referenced directly.
    alwayslink = 1,
)
|
|
|
|
# Test for the NCCL manager (nccl_manager_test.cc).
#
# The "manual" tag keeps this target out of wildcard patterns, so it must be
# requested explicitly. The remaining tags are consumed by test
# infrastructure outside this file.
tf_cuda_cc_test(
    name = "nccl_manager_test",
    size = "medium",
    srcs = ["nccl_manager_test.cc"],
    tags = tf_cuda_tests_tags() + [
        "noguitar",  # TODO(b/176867216): Flaky.
        "manual",
        "multi_gpu",
        "no_oss",
        # TODO(b/147451637): Replace 'no_rocm' with 'rocm_multi_gpu'.
        "no_rocm",
        "notap",
    ],
    deps = [
        # Test scaffolding, always present.
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ] + if_cuda_or_rocm([
        # Library under test.
        ":nccl_lib",
    ]) + if_cuda([
        "@local_config_nccl//:nccl",
        "//tensorflow/core:cuda",
    ]) + if_rocm([
        "@local_config_rocm//rocm:rccl",
        "//tensorflow/core/common_runtime/gpu:rocm",
    ]),
)
|
|
|
|
# Collective communicator library.
#
# The -DTENSORFLOW_USE_NCCL=1 define and the NCCL-related dependencies are
# both supplied through the if_nccl macro; //tensorflow/core:framework is an
# unconditional dependency.
cc_library(
    name = "collective_communicator",
    srcs = ["collective_communicator.cc"],
    hdrs = ["collective_communicator.h"],
    copts = tf_copts() + if_nccl(["-DTENSORFLOW_USE_NCCL=1"]),
    # Overrides the package default to additionally expose this target to
    # //learning/brain/runtime.
    visibility = [
        "//learning/brain/runtime:__subpackages__",
        "//tensorflow:__subpackages__",
    ],
    deps = [
        "//tensorflow/core:framework",
    ] + if_nccl([
        ":nccl_lib",
        "@com_google_absl//absl/memory",
        "//tensorflow/core/profiler/lib:traceme",
    ]),
)
|
|
|
|
# Bundles the collective communicator source and header under the
# "mobile_srcs" name so other targets can reference both files with a single
# label.
filegroup(
    name = "mobile_srcs",
    srcs = [
        "collective_communicator.cc",
        "collective_communicator.h",
    ],
)
|