Change ReplicaIdForDeviceOrdinal to explicitly take a GlobalDeviceId.
The method name and parameter name were both misleading in a multi-host environment. PiperOrigin-RevId: 344095231 Change-Id: Idf5093126bfcb445a1596ca97d1ab9fc444deca5
This commit is contained in:
parent
c6aaecf54e
commit
256556c132
@ -2882,6 +2882,7 @@ cc_library(
|
||||
srcs = ["computation_placer.cc"],
|
||||
hdrs = ["computation_placer.h"],
|
||||
deps = [
|
||||
":global_device_id",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
||||
@ -63,7 +63,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
|
||||
// Use the DeviceAssignment to figure out our replica-id.
|
||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
||||
device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
|
||||
device_assn.ReplicaIdForDevice(device_id));
|
||||
|
||||
// Figure out the other replicas that go together with this one.
|
||||
absl::optional<ReplicaGroup> replica_group;
|
||||
|
||||
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -39,25 +40,24 @@ using absl::StrCat;
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<int> DeviceAssignment::ReplicaIdForDeviceOrdinal(
|
||||
int device_ordinal) const {
|
||||
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
|
||||
GlobalDeviceId device_id) const {
|
||||
absl::optional<int> replica_id;
|
||||
for (int64 r = 0; r < replica_count(); ++r) {
|
||||
for (int64 c = 0; c < computation_count(); ++c) {
|
||||
if ((*this)(r, c) == device_ordinal) {
|
||||
if ((*this)(r, c) == device_id.value()) {
|
||||
if (replica_id.has_value()) {
|
||||
return InternalError(
|
||||
"Device ordinal %d appears twice in DeviceAssignment? %s",
|
||||
device_ordinal, ToString());
|
||||
"Device %d appears twice in DeviceAssignment: %s",
|
||||
device_id.value(), ToString());
|
||||
}
|
||||
replica_id = r;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!replica_id.has_value()) {
|
||||
return InternalError(
|
||||
"Device ordinal %d doesn't appear in DeviceAssignment %s",
|
||||
device_ordinal, ToString());
|
||||
return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
|
||||
device_id.value(), ToString());
|
||||
}
|
||||
return *replica_id;
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -48,7 +49,7 @@ class DeviceAssignment : public Array2D<int> {
|
||||
int computation_count() const { return width(); }
|
||||
|
||||
// Finds the replica ID for the given device.
|
||||
StatusOr<int> ReplicaIdForDeviceOrdinal(int device_ordinal) const;
|
||||
StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;
|
||||
|
||||
// Protocol buffer serialization and deserialization.
|
||||
Status Serialize(DeviceAssignmentProto* proto) const;
|
||||
|
||||
@ -631,9 +631,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
|
||||
xla::int32 replica_groups_str_size, xla::int32 num_buffers,
|
||||
xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
xla::int32 replica_id = run_options->device_assignment()
|
||||
->ReplicaIdForDeviceOrdinal(device_ordinal)
|
||||
.ValueOrDie();
|
||||
xla::int32 replica_id =
|
||||
run_options->device_assignment()
|
||||
->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
|
||||
.ValueOrDie();
|
||||
absl::string_view replica_groups_serialized(
|
||||
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
|
||||
std::vector<xla::ReplicaGroup> group =
|
||||
@ -720,9 +721,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
|
||||
const xla::ExecutableRunOptions* run_options, void* output_buffer) {
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
xla::int32 replica_id = run_options->device_assignment()
|
||||
->ReplicaIdForDeviceOrdinal(device_ordinal)
|
||||
.ValueOrDie();
|
||||
xla::int32 replica_id =
|
||||
run_options->device_assignment()
|
||||
->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
|
||||
.ValueOrDie();
|
||||
std::memcpy(output_buffer, &replica_id, 4);
|
||||
}
|
||||
|
||||
@ -735,9 +737,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
|
||||
absl::string_view source_target_pairs_serialized(
|
||||
static_cast<const char*>(source_target_pairs), source_target_pairs_size);
|
||||
auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
|
||||
xla::int32 replica_id = run_options->device_assignment()
|
||||
->ReplicaIdForDeviceOrdinal(device_ordinal)
|
||||
.ValueOrDie();
|
||||
xla::int32 replica_id =
|
||||
run_options->device_assignment()
|
||||
->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
|
||||
.ValueOrDie();
|
||||
std::vector<int> copy_to;
|
||||
for (auto& p : pairs) {
|
||||
std::vector<std::string> mapping = absl::StrSplit(p, '=');
|
||||
|
||||
@ -249,9 +249,8 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
||||
params.GetGlobalDeviceId());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
int64 replica_id,
|
||||
params.device_assn->ReplicaIdForDeviceOrdinal(global_device_id.value()));
|
||||
TF_ASSIGN_OR_RETURN(int64 replica_id,
|
||||
params.device_assn->ReplicaIdForDevice(global_device_id));
|
||||
|
||||
// Figure out which replicas our data is copied to.
|
||||
std::vector<int64> dest_replicas;
|
||||
|
||||
@ -30,9 +30,8 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
||||
params.GetGlobalDeviceId());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
int replica_id,
|
||||
params.device_assn->ReplicaIdForDeviceOrdinal(global_device_id.value()));
|
||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
||||
params.device_assn->ReplicaIdForDevice(global_device_id));
|
||||
params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user