[NFC] Moved GlobalDeviceId into its own file as it is also used by the CPU runtime.
Moved everything else in `gpu_executable_run_options.h` into the `gpu` namespace.

PiperOrigin-RevId: 344029444
Change-Id: Id911cb046a176313f916354ff63ee97d2b5e5828
This commit is contained in:
parent
6f2db1eea2
commit
d39f8a2490
@ -103,12 +103,12 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
|
||||
}
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
const gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
gpu_executable_run_options_ = gpu_executable_run_options;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const GpuExecutableRunOptions*
|
||||
const gpu::GpuExecutableRunOptions*
|
||||
ExecutableRunOptions::gpu_executable_run_options() const {
|
||||
return gpu_executable_run_options_;
|
||||
}
|
||||
|
@ -38,7 +38,9 @@ namespace xla {
|
||||
|
||||
class DeviceAssignment;
|
||||
class ExecutionProfile;
|
||||
namespace gpu {
|
||||
class GpuExecutableRunOptions;
|
||||
} // namespace gpu
|
||||
|
||||
// A unique identifier for a particular "logical execution" of an XLA model.
|
||||
//
|
||||
@ -150,8 +152,8 @@ class ExecutableRunOptions {
|
||||
// GPU-backend specific options. These are kept out-of-line to avoid bloating
|
||||
// the size of this dependency for CPU-only AOT builds.
|
||||
ExecutableRunOptions& set_gpu_executable_run_options(
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options);
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options() const;
|
||||
const gpu::GpuExecutableRunOptions* gpu_executable_run_options);
|
||||
const gpu::GpuExecutableRunOptions* gpu_executable_run_options() const;
|
||||
|
||||
private:
|
||||
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
@ -165,7 +167,7 @@ class ExecutableRunOptions {
|
||||
stream_executor::Stream* host_to_device_stream_ = nullptr;
|
||||
ThenExecuteFunction* then_execute_function_ = nullptr;
|
||||
RunId run_id_;
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
|
||||
const gpu::GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -179,7 +179,7 @@ class NcclIdStore {
|
||||
client_(std::move(client)),
|
||||
device_to_node_(std::move(device_to_node)) {}
|
||||
|
||||
StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key);
|
||||
StatusOr<std::string> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
|
||||
|
||||
private:
|
||||
const int node_id_;
|
||||
@ -187,10 +187,12 @@ class NcclIdStore {
|
||||
const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_);
|
||||
absl::flat_hash_map<gpu::NcclCliqueKey, std::string> cache_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
|
||||
StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
|
||||
const gpu::NcclCliqueKey& key) {
|
||||
// The caller must ensure that threads calling this method concurrently have
|
||||
// unique keys, otherwise the global key-value store may hold the wrong value.
|
||||
{
|
||||
@ -241,7 +243,7 @@ Status BuildDistributedDevices(
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
|
||||
std::vector<std::unique_ptr<PjRtDevice>>* devices,
|
||||
GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
LocalTopologyProto local_topology;
|
||||
local_topology.set_node_id(node_id);
|
||||
for (const auto& local_device : local_device_states) {
|
||||
@ -292,7 +294,7 @@ Status BuildDistributedDevices(
|
||||
auto nccl_id_store = std::make_shared<NcclIdStore>(
|
||||
node_id, distributed_client, device_to_node);
|
||||
gpu_executable_run_options->set_nccl_unique_id_callback(
|
||||
[nccl_id_store](const NcclCliqueKey& key) {
|
||||
[nccl_id_store](const gpu::NcclCliqueKey& key) {
|
||||
return nccl_id_store->GetNcclUniqueId(key);
|
||||
});
|
||||
return Status::OK();
|
||||
@ -320,7 +322,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
GetGpuHostAllocator(local_device_states.front()->executor());
|
||||
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
auto gpu_run_options = absl::make_unique<GpuExecutableRunOptions>();
|
||||
auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
|
||||
if (distributed_client) {
|
||||
TF_RETURN_IF_ERROR(BuildDistributedDevices(
|
||||
std::move(local_device_states), std::move(distributed_client), node_id,
|
||||
|
@ -188,7 +188,7 @@ PjRtClient::PjRtClient(
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
|
||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
|
||||
: platform_id_(tensorflow::Fingerprint64(platform_name)),
|
||||
platform_name_(std::move(platform_name)),
|
||||
client_(client),
|
||||
|
@ -185,7 +185,7 @@ class PjRtClient {
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
|
||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
||||
virtual ~PjRtClient() = default;
|
||||
|
||||
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
@ -222,7 +222,7 @@ class PjRtClient {
|
||||
return should_stage_host_to_device_transfers_;
|
||||
}
|
||||
|
||||
GpuExecutableRunOptions* gpu_run_options() const {
|
||||
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
||||
return gpu_run_options_.get();
|
||||
}
|
||||
|
||||
@ -358,7 +358,7 @@ class PjRtClient {
|
||||
// transfer via pinned memory.
|
||||
bool should_stage_host_to_device_transfers_;
|
||||
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
|
||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
|
||||
|
||||
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
||||
};
|
||||
|
@ -5250,3 +5250,15 @@ tf_cc_test(
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "global_device_id",
|
||||
srcs = ["global_device_id.cc"],
|
||||
hdrs = ["global_device_id.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -165,7 +165,7 @@ struct AllReduceParticipantData : ParticipantData {
|
||||
PrimitiveType primitive_type;
|
||||
};
|
||||
std::vector<Buffer> buffers;
|
||||
const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
|
||||
const gpu::NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
|
||||
|
||||
ReductionKind reduction_kind;
|
||||
|
||||
|
31
tensorflow/compiler/xla/service/global_device_id.cc
Normal file
31
tensorflow/compiler/xla/service/global_device_id.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
|
||||
std::vector<int64> values;
|
||||
values.reserve(ids.size());
|
||||
for (GlobalDeviceId id : ids) {
|
||||
values.push_back(id.value());
|
||||
}
|
||||
return absl::StrJoin(values, ",");
|
||||
}
|
||||
|
||||
} // namespace xla
|
38
tensorflow/compiler/xla/service/global_device_id.h
Normal file
38
tensorflow/compiler/xla/service/global_device_id.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/int_type.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Strongly-typed integer type for naming a device globally within a distributed
|
||||
// system. XLA doesn't have a strong opinion about what global numbering scheme
|
||||
// is applied to GPUs; the user must provide a local -> global mapping via
|
||||
// GpuExecutableRunOptions for the local GPUs.
|
||||
TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
|
||||
|
||||
// Returns a comma-separated string of global device IDs.
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
|
@ -74,11 +74,9 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/compiler/xla/service:global_device_id",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,18 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
|
||||
std::vector<int64> values;
|
||||
values.reserve(ids.size());
|
||||
for (GlobalDeviceId id : ids) {
|
||||
values.push_back(id.value());
|
||||
}
|
||||
return absl::StrJoin(values, ",");
|
||||
}
|
||||
namespace gpu {
|
||||
|
||||
NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
|
||||
: devices_(std::move(devices)) {
|
||||
@ -63,4 +54,5 @@ const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
|
||||
return nccl_unique_id_callback_;
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -21,21 +21,12 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/int_type.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Strongly-typed integer type for naming a device globally within a distributed
|
||||
// system. XLA doesn't have a strong opinion about what global numbering scheme
|
||||
// is applied to GPUs; the user must provide a local -> global mapping via
|
||||
// GpuExecutableRunOptions for the local GPUs.
|
||||
TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
|
||||
|
||||
// Returns a comma-separated string of global device IDs.
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
|
||||
namespace gpu {
|
||||
|
||||
// Key for naming up a particular NCCL clique. This is just a set of unique
|
||||
// device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
|
||||
@ -87,6 +78,7 @@ class GpuExecutableRunOptions {
|
||||
NcclUniqueIdCallback nccl_unique_id_callback_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
|
||||
|
@ -322,7 +322,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
|
||||
run_options.set_device_assignment(
|
||||
&executable->executable()->module_config().static_device_assignment());
|
||||
}
|
||||
xla::GpuExecutableRunOptions gpu_options;
|
||||
xla::gpu::GpuExecutableRunOptions gpu_options;
|
||||
std::vector<xla::GlobalDeviceId> gpu_global_ids;
|
||||
if (config.local_replica_mapping_size() > 0) {
|
||||
gpu_global_ids.reserve(config.local_replica_mapping_size());
|
||||
@ -334,7 +334,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
|
||||
std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
|
||||
if (nccl_factory != nullptr) {
|
||||
auto uid_callback =
|
||||
[&](const xla::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
|
||||
[&](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
|
||||
std::vector<xla::int64> replicas;
|
||||
for (auto& device : key.devices()) {
|
||||
replicas.push_back(device.value());
|
||||
|
Loading…
x
Reference in New Issue
Block a user