[NFC] Moved GlobalDeviceId into its own file as it is also used by the CPU runtime.

Moved everything else in `gpu_executable_run_options.h` into the `gpu` namespace.
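To illustrate the effect on call sites, here is a minimal sketch of the intended usage after this change. The type names and set_gpu_executable_run_options come from the diffs below; the ExecutableRunOptions include path and the helper function are assumptions made for the sketch.

    #include "tensorflow/compiler/xla/executable_run_options.h"  // assumed path
    #include "tensorflow/compiler/xla/service/global_device_id.h"
    #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

    // Hypothetical helper, for illustration only.
    void ConfigureRunOptions(xla::ExecutableRunOptions& run_options,
                             xla::gpu::GpuExecutableRunOptions& gpu_options) {
      // GlobalDeviceId stays in the top-level xla namespace, so the CPU runtime
      // can keep using it without depending on the GPU-only options.
      xla::GlobalDeviceId first_device(0);
      (void)first_device;

      // GpuExecutableRunOptions (and NcclCliqueKey) now live under xla::gpu.
      run_options.set_gpu_executable_run_options(&gpu_options);
    }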

PiperOrigin-RevId: 344029444
Change-Id: Id911cb046a176313f916354ff63ee97d2b5e5828
Author:    Chris Jones
Date:      2020-11-24 04:13:12 -08:00
Committed: TensorFlower Gardener
Parent:    6f2db1eea2
Commit:    d39f8a2490
13 changed files with 109 additions and 42 deletions


@@ -103,12 +103,12 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
 }
 
 ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
-    const GpuExecutableRunOptions* gpu_executable_run_options) {
+    const gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   gpu_executable_run_options_ = gpu_executable_run_options;
   return *this;
 }
 
-const GpuExecutableRunOptions*
+const gpu::GpuExecutableRunOptions*
 ExecutableRunOptions::gpu_executable_run_options() const {
   return gpu_executable_run_options_;
 }


@@ -38,7 +38,9 @@ namespace xla {
 
 class DeviceAssignment;
 class ExecutionProfile;
+namespace gpu {
 class GpuExecutableRunOptions;
+}  // namespace gpu
 
 // A unique identifier for a particular "logical execution" of an XLA model.
 //
@@ -150,8 +152,8 @@ class ExecutableRunOptions {
   // GPU-backend specific options. These are kept out-of-line to avoid bloating
   // the size of this dependency for CPU-only AOT builds.
   ExecutableRunOptions& set_gpu_executable_run_options(
-      const GpuExecutableRunOptions* gpu_executable_run_options);
-  const GpuExecutableRunOptions* gpu_executable_run_options() const;
+      const gpu::GpuExecutableRunOptions* gpu_executable_run_options);
+  const gpu::GpuExecutableRunOptions* gpu_executable_run_options() const;
 
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
@@ -165,7 +167,7 @@ class ExecutableRunOptions {
   stream_executor::Stream* host_to_device_stream_ = nullptr;
   ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
-  const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
+  const gpu::GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
 };
 
 }  // namespace xla


@@ -179,7 +179,7 @@ class NcclIdStore {
         client_(std::move(client)),
         device_to_node_(std::move(device_to_node)) {}
 
-  StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key);
+  StatusOr<std::string> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
 
 private:
   const int node_id_;
@@ -187,10 +187,12 @@ class NcclIdStore {
   const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
 
   absl::Mutex mu_;
-  absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_);
+  absl::flat_hash_map<gpu::NcclCliqueKey, std::string> cache_
+      ABSL_GUARDED_BY(mu_);
 };
 
-StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
+StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
+    const gpu::NcclCliqueKey& key) {
   // The caller must ensure that threads calling this method concurrently have
   // unique keys, otherwise the global key-value store may hold the wrong value.
   {
@@ -241,7 +243,7 @@ Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
     std::vector<std::unique_ptr<PjRtDevice>>* devices,
-    GpuExecutableRunOptions* gpu_executable_run_options) {
+    gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
   for (const auto& local_device : local_device_states) {
@@ -292,7 +294,7 @@ Status BuildDistributedDevices(
   auto nccl_id_store = std::make_shared<NcclIdStore>(
       node_id, distributed_client, device_to_node);
   gpu_executable_run_options->set_nccl_unique_id_callback(
-      [nccl_id_store](const NcclCliqueKey& key) {
+      [nccl_id_store](const gpu::NcclCliqueKey& key) {
        return nccl_id_store->GetNcclUniqueId(key);
      });
   return Status::OK();
@@ -320,7 +322,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
       GetGpuHostAllocator(local_device_states.front()->executor());
   std::vector<std::unique_ptr<PjRtDevice>> devices;
 
-  auto gpu_run_options = absl::make_unique<GpuExecutableRunOptions>();
+  auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(
         std::move(local_device_states), std::move(distributed_client), node_id,


@@ -188,7 +188,7 @@ PjRtClient::PjRtClient(
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
-    std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
+    std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
     : platform_id_(tensorflow::Fingerprint64(platform_name)),
       platform_name_(std::move(platform_name)),
       client_(client),


@@ -185,7 +185,7 @@ class PjRtClient {
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
       bool should_stage_host_to_device_transfers,
-      std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;
 
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
@@ -222,7 +222,7 @@ class PjRtClient {
     return should_stage_host_to_device_transfers_;
   }
 
-  GpuExecutableRunOptions* gpu_run_options() const {
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
     return gpu_run_options_.get();
   }
@@ -358,7 +358,7 @@ class PjRtClient {
   // transfer via pinned memory.
   bool should_stage_host_to_device_transfers_;
 
-  std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
+  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
 
   tensorflow::thread::ThreadPool h2d_transfer_pool_;
 };


@@ -5250,3 +5250,15 @@ tf_cc_test(
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "global_device_id",
+    srcs = ["global_device_id.cc"],
+    hdrs = ["global_device_id.h"],
+    deps = [
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)


@@ -165,7 +165,7 @@ struct AllReduceParticipantData : ParticipantData {
     PrimitiveType primitive_type;
   };
 
   std::vector<Buffer> buffers;
-  const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
+  const gpu::NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
 
   ReductionKind reduction_kind;


@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+
+#include "absl/strings/str_join.h"
+
+namespace xla {
+
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
+  std::vector<int64> values;
+  values.reserve(ids.size());
+  for (GlobalDeviceId id : ids) {
+    values.push_back(id.value());
+  }
+  return absl::StrJoin(values, ",");
+}
+
+}  // namespace xla


@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
+
+#include <string>
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace xla {
+
+// Strongly-typed integer type for naming a device globally within a distributed
+// system. XLA doesn't have a strong opinion about what global numbering scheme
+// is applied to GPUs; the user must provide a local -> global mapping via
+// GpuExecutableRunOptions for the local GPUs.
+TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
+
+// Returns a comma-separated string of global device IDs.
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GLOBAL_DEVICE_ID_H_
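As a quick usage sketch of the relocated helper (a hypothetical snippet, not part of the commit; only GlobalDeviceId and GlobalDeviceIdsToString come from the new header above):

    #include <iostream>
    #include <vector>

    #include "tensorflow/compiler/xla/service/global_device_id.h"

    int main() {
      // GlobalDeviceId is a strongly-typed wrapper around an int64 value.
      std::vector<xla::GlobalDeviceId> ids = {xla::GlobalDeviceId(0),
                                              xla::GlobalDeviceId(3)};
      // Prints "0,3".
      std::cout << xla::GlobalDeviceIdsToString(ids) << std::endl;
      return 0;
    }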


@@ -74,11 +74,9 @@ cc_library(
     deps = [
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
-        "//tensorflow/core:lib_internal",
+        "//tensorflow/compiler/xla/service:global_device_id",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
 )


@@ -16,18 +16,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 
 #include "absl/algorithm/container.h"
 #include "absl/strings/str_join.h"
 
 namespace xla {
-
-std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
-  std::vector<int64> values;
-  values.reserve(ids.size());
-  for (GlobalDeviceId id : ids) {
-    values.push_back(id.value());
-  }
-  return absl::StrJoin(values, ",");
-}
+namespace gpu {
 
 NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
     : devices_(std::move(devices)) {
@@ -63,4 +54,5 @@ const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
   return nccl_unique_id_callback_;
 }
 
+}  // namespace gpu
 }  // namespace xla


@@ -21,21 +21,12 @@ limitations under the License.
 #include <vector>
 
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/int_type.h"
 
 namespace xla {
 
-// Strongly-typed integer type for naming a device globally within a distributed
-// system. XLA doesn't have a strong opinion about what global numbering scheme
-// is applied to GPUs; the user must provide a local -> global mapping via
-// GpuExecutableRunOptions for the local GPUs.
-TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
-
-// Returns a comma-separated string of global device IDs.
-std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
+namespace gpu {
 
 // Key for naming up a particular NCCL clique. This is just a set of unique
 // device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
@@ -87,6 +78,7 @@ class GpuExecutableRunOptions {
   NcclUniqueIdCallback nccl_unique_id_callback_;
 };
 
+}  // namespace gpu
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
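For completeness, a hedged sketch of building the (now xla::gpu) NcclCliqueKey and wiring the unique-id callback. Only the constructor, devices(), and set_nccl_unique_id_callback appear in the diffs; GetUniqueIdFromStore and the helper function are hypothetical stand-ins for a distributed key-value lookup:

    #include <string>
    #include <utility>
    #include <vector>

    #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

    // Hypothetical lookup into a distributed key-value store.
    xla::StatusOr<std::string> GetUniqueIdFromStore(
        const xla::gpu::NcclCliqueKey& key);

    void ConfigureNcclCallback(xla::gpu::GpuExecutableRunOptions& gpu_options) {
      // A clique key is just the set of global device IDs taking part in a
      // collective operation.
      std::vector<xla::GlobalDeviceId> devices = {xla::GlobalDeviceId(0),
                                                  xla::GlobalDeviceId(1)};
      xla::gpu::NcclCliqueKey clique_key(std::move(devices));
      (void)clique_key;

      // The callback is invoked with the clique key and must return the NCCL
      // unique id for that clique.
      gpu_options.set_nccl_unique_id_callback(
          [](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
            return GetUniqueIdFromStore(key);
          });
    }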


@@ -322,7 +322,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
     run_options.set_device_assignment(
         &executable->executable()->module_config().static_device_assignment());
   }
-  xla::GpuExecutableRunOptions gpu_options;
+  xla::gpu::GpuExecutableRunOptions gpu_options;
   std::vector<xla::GlobalDeviceId> gpu_global_ids;
   if (config.local_replica_mapping_size() > 0) {
     gpu_global_ids.reserve(config.local_replica_mapping_size());
@@ -334,7 +334,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
   std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
   if (nccl_factory != nullptr) {
     auto uid_callback =
-        [&](const xla::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
+        [&](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
       std::vector<xla::int64> replicas;
       for (auto& device : key.devices()) {
         replicas.push_back(device.value());