Change ReplicaIdForDeviceOrdinal to explicitly take a GlobalDeviceId.

In a multi-host environment the method name and parameter name were both misleading: the value being looked up is a global device ID shared across hosts, not a per-host device ordinal.
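Illustration (hypothetical two-host setup, not part of this change): with two
hosts of four devices each, every host has local device ordinals 0-3, but the
DeviceAssignment holds global device IDs 0-7, so a local ordinal passed where a
global ID is expected compiles cleanly and silently looks up the wrong row. A
minimal sketch; the function name and the 8x1 assignment are illustrative:

    #include "tensorflow/compiler/xla/service/computation_placer.h"
    #include "tensorflow/compiler/xla/service/global_device_id.h"

    // Hypothetical two-host job: 8 replicas, one computation per replica.
    xla::StatusOr<int> ReplicaForHost1Device0() {
      xla::DeviceAssignment assignment(/*replica_count=*/8,
                                       /*computation_count=*/1);
      // Cell (r, 0) holds the *global* ID of the device running replica r.
      for (int r = 0; r < 8; ++r) assignment(r, 0) = r;

      // Host 1's device with local ordinal 0 is global device 4. A lookup
      // keyed by the ordinal would return replica 0; the GlobalDeviceId
      // parameter makes the caller say which numbering it means.
      return assignment.ReplicaIdForDevice(xla::GlobalDeviceId(4));
    }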

PiperOrigin-RevId: 344095231
Change-Id: Idf5093126bfcb445a1596ca97d1ab9fc444deca5
Chris Jones authored 2020-11-24 11:12:14 -08:00; committed by TensorFlower Gardener
commit 256556c132 (parent c6aaecf54e)
7 changed files with 28 additions and 25 deletions

View File

@@ -2882,6 +2882,7 @@ cc_library(
     srcs = ["computation_placer.cc"],
    hdrs = ["computation_placer.h"],
     deps = [
+        ":global_device_id",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",

View File

@@ -63,7 +63,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
   // Use the DeviceAssignment to figure out our replica-id.
   TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
+                      device_assn.ReplicaIdForDevice(device_id));
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -39,25 +40,24 @@ using absl::StrCat;
 namespace xla {
-StatusOr<int> DeviceAssignment::ReplicaIdForDeviceOrdinal(
-    int device_ordinal) const {
+StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
+    GlobalDeviceId device_id) const {
   absl::optional<int> replica_id;
   for (int64 r = 0; r < replica_count(); ++r) {
     for (int64 c = 0; c < computation_count(); ++c) {
-      if ((*this)(r, c) == device_ordinal) {
+      if ((*this)(r, c) == device_id.value()) {
         if (replica_id.has_value()) {
           return InternalError(
-              "Device ordinal %d appears twice in DeviceAssignment? %s",
-              device_ordinal, ToString());
+              "Device %d appears twice in DeviceAssignment: %s",
+              device_id.value(), ToString());
         }
         replica_id = r;
       }
     }
   }
   if (!replica_id.has_value()) {
-    return InternalError(
-        "Device ordinal %d doesn't appear in DeviceAssignment %s",
-        device_ordinal, ToString());
+    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
+                         device_id.value(), ToString());
   }
   return *replica_id;
 }
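The lookup above is a full scan of the (replica, computation) grid, so a
device that appears zero times or twice is reported as an InternalError rather
than resolved to an arbitrary row. A minimal caller sketch; the helper name is
illustrative:

    #include "tensorflow/compiler/xla/service/computation_placer.h"
    #include "tensorflow/compiler/xla/service/global_device_id.h"
    #include "tensorflow/compiler/xla/statusor.h"

    // Illustrative helper: surfaces the "doesn't appear" / "appears twice"
    // errors to the caller instead of crashing via ValueOrDie().
    xla::StatusOr<int> LookupReplica(const xla::DeviceAssignment& assignment,
                                     xla::GlobalDeviceId device) {
      xla::StatusOr<int> replica = assignment.ReplicaIdForDevice(device);
      if (!replica.ok()) {
        return replica.status();
      }
      return replica.ValueOrDie();
    }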

View File

@@ -21,6 +21,7 @@ limitations under the License.
 #include <vector>
 #include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -48,7 +49,7 @@ class DeviceAssignment : public Array2D<int> {
   int computation_count() const { return width(); }
   // Finds the replica ID for the given device.
-  StatusOr<int> ReplicaIdForDeviceOrdinal(int device_ordinal) const;
+  StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;
   // Protocol buffer serialization and deserialization.
   Status Serialize(DeviceAssignmentProto* proto) const;
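GlobalDeviceId, pulled in by the new include, is a strongly-typed integer
rather than a bare int, so a local device ordinal no longer converts into the
parameter implicitly. It has roughly the shape sketched below; the real type
is generated by TensorFlow's strong-int machinery in global_device_id.h:

    #include <cstdint>

    // Rough sketch of a strong integer wrapper in the spirit of
    // GlobalDeviceId: an explicit constructor plus a value() accessor are
    // the properties the new signature relies on.
    class GlobalDeviceIdSketch {
     public:
      explicit constexpr GlobalDeviceIdSketch(int64_t id) : id_(id) {}
      constexpr int64_t value() const { return id_; }

     private:
      int64_t id_;
    };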

View File

@@ -631,9 +631,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
     xla::int32 replica_groups_str_size, xla::int32 num_buffers,
     xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
   int device_ordinal = GetDeviceOrdinal(run_options);
-  xla::int32 replica_id = run_options->device_assignment()
-                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
-                              .ValueOrDie();
+  xla::int32 replica_id =
+      run_options->device_assignment()
+          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
+          .ValueOrDie();
   absl::string_view replica_groups_serialized(
       static_cast<const char*>(replica_groups_str), replica_groups_str_size);
   std::vector<xla::ReplicaGroup> group =
@@ -720,9 +721,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer) {
   int device_ordinal = GetDeviceOrdinal(run_options);
-  xla::int32 replica_id = run_options->device_assignment()
-                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
-                              .ValueOrDie();
+  xla::int32 replica_id =
+      run_options->device_assignment()
+          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
+          .ValueOrDie();
   std::memcpy(output_buffer, &replica_id, 4);
 }
@@ -735,9 +737,10 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
   absl::string_view source_target_pairs_serialized(
       static_cast<const char*>(source_target_pairs), source_target_pairs_size);
   auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
-  xla::int32 replica_id = run_options->device_assignment()
-                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
-                              .ValueOrDie();
+  xla::int32 replica_id =
+      run_options->device_assignment()
+          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
+          .ValueOrDie();
   std::vector<int> copy_to;
   for (auto& p : pairs) {
     std::vector<std::string> mapping = absl::StrSplit(p, '=');
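The CPU runtime only has a local device ordinal (from GetDeviceOrdinal), so
the call sites above wrap it in xla::GlobalDeviceId directly. That wrap is an
identity mapping only while every device lives in one process; a sketch of the
assumption, with a hypothetical helper name:

    #include "tensorflow/compiler/xla/service/global_device_id.h"

    // Hypothetical helper: on the single-host CPU backend the local ordinal
    // and the global device ID coincide, so the conversion is a plain wrap.
    // A multi-host setup would need a real ordinal-to-global mapping instead.
    inline xla::GlobalDeviceId GlobalIdForLocalOrdinal(int device_ordinal) {
      return xla::GlobalDeviceId(device_ordinal);
    }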

View File

@@ -249,9 +249,8 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  TF_ASSIGN_OR_RETURN(
-      int64 replica_id,
-      params.device_assn->ReplicaIdForDeviceOrdinal(global_device_id.value()));
+  TF_ASSIGN_OR_RETURN(int64 replica_id,
+                      params.device_assn->ReplicaIdForDevice(global_device_id));
   // Figure out which replicas our data is copied to.
   std::vector<int64> dest_replicas;

View File

@@ -30,9 +30,8 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  TF_ASSIGN_OR_RETURN(
-      int replica_id,
-      params.device_assn->ReplicaIdForDeviceOrdinal(global_device_id.value()));
+  TF_ASSIGN_OR_RETURN(int replica_id,
+                      params.device_assn->ReplicaIdForDevice(global_device_id));
   params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
   return Status::OK();
 }