Move device information from instance params to group params
This removes the need to hold both instance_mu_ and gr->mu at the same time. Each worker still computes this information for itself, just as before; a possible improvement would be to have the leader pass the information to the other workers, which would make it easier to guarantee it is identical everywhere.

PiperOrigin-RevId: 334274263
Change-Id: Ic854b7459b286b7fa3fb1a4f890e0f49aed14d3b
parent 8b893a0ab0
commit 713a9306d8
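For orientation, here is a minimal, self-contained sketch of what moves where. The struct and field names below are assumptions for illustration, not the actual TensorFlow headers (the real definitions live in tensorflow/core/framework/collective.h): the per-member device/task layout that used to sit in the instance params now lives in the group params, so it is established once per group instead of being re-derived and copied into every instance record under a second mutex.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified, hypothetical stand-ins for CollGroupParams / CollInstanceParams,
// showing only the fields this change is about.
struct GroupParamsSketch {
  int32_t group_key = 0;
  int32_t group_size = 0;
  int32_t num_tasks = 0;
  // Moved here from the instance params by this commit:
  std::vector<std::string> device_names;  // fully qualified, default rank order
  std::vector<std::string> task_names;    // task prefix of each device above
  std::unordered_map<std::string, int32_t> num_devices_per_task;
  bool same_num_devices_per_task = false;
  std::string gpu_ring_order;
};

struct InstanceParamsSketch {
  int32_t instance_key = 0;
  // device_names / task_names / num_devices_per_task no longer live here;
  // code that needs them now reads them from the group params instead.
};
```

The effect visible throughout the diff below is largely mechanical: accesses of the form cp->instance.device_names become cp->group.device_names, and the group record can be finished once (FinishGroup) rather than each instance rebuilding the same lists.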
@@ -306,7 +306,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
void BaseCollectiveExecutor::CompleteParamsAsync(
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, StatusCallback done) {
cp->instance.gpu_ring_order = *gpu_ring_order_;
cp->group.gpu_ring_order = *gpu_ring_order_;
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_with_timeout = done;
auto timeout_microseconds =
@@ -402,9 +402,9 @@ void BaseCollectiveExecutor::UnblockDependencies(
mutex_lock l(launch_mu_);
if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
const string& task_name =
col_params.instance.task_names[col_params.default_rank];
col_params.group.task_names[col_params.default_rank];
const int32 num_devices =
col_params.instance.num_devices_per_task.at(task_name);
col_params.group.num_devices_per_task.at(task_name);
launched_[col_params.instance.instance_key] = num_devices;
}
if (--launched_[col_params.instance.instance_key] == 0) {

@@ -97,9 +97,11 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
auto it = group_table_.find(cp->group.group_key);
if (it == group_table_.end()) {
gr = new GroupRec;
mutex_lock grl(gr->mu);
gr->group.group_key = cp->group.group_key;
gr->group.group_size = cp->group.group_size;
gr->group.device_type = cp->group.device_type;
gr->group.gpu_ring_order = cp->group.gpu_ring_order;

// Initialize group runtime details.
CollectiveImplementationInterface* col_impl;
@@ -164,6 +166,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
" but that group has size ", gr->group.group_size);
}
}
bool new_device = false;
if (gr->status.ok()) {
// Insert device if not already present.
auto it = gr->devices.find(device.name());
@@ -177,15 +180,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
} else {
// This is a new device that has not yet joined the group.
gr->devices[device.name()] = device;
if (gr->devices.size() == gr->group.group_size) {
// The group is full after adding this device, calculate the number
// of tasks.
std::unordered_set<string> tasks;
for (const auto& item : gr->devices) {
tasks.insert(TaskNameFromDeviceName(item.first));
}
gr->group.num_tasks = static_cast<int32>(tasks.size());
}
new_device = true;
if (VLOG_IS_ON(1)) {
string dev_buf;
for (const auto& d : gr->devices) {
@@ -211,7 +206,6 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
}

if (gr->status.ok()) {
cp->group.runtime_details = gr->group.runtime_details;
// If the group is not yet complete, queue to wait for it.
VLOG(2) << "group_size " << gr->group.group_size << " set size "
<< gr->devices.size() << " gr " << gr;
@@ -221,6 +215,10 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
return;
}
CHECK_EQ(gr->devices.size(), gr->group.group_size);
// We get a full group. Fill in remaining fields in gr->group.
if (new_device) {
FinishGroup(gr);
}
}
// At this point, we either have a full group, or an error status. Ensure
// that all callbacks are invoked with the appropriate status.
@@ -248,16 +246,16 @@ typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp,
const std::vector<DeviceAttributes>& attributes) {
GlobalDeviceMap gdm;
CHECK_EQ(ip.device_names.size(), ip.task_names.size());
CHECK_EQ(ip.device_names.size(), attributes.size());
for (int i = 0; i < ip.device_names.size(); ++i) {
TaskDeviceMap& tdm = gdm[ip.task_names[i]];
DevRec* dr = &tdm[ip.device_names[i]];
dr->task = ip.task_names[i];
dr->device = ip.device_names[i];
CHECK_EQ(gp.device_names.size(), gp.task_names.size());
CHECK_EQ(gp.device_names.size(), attributes.size());
for (int i = 0; i < gp.device_names.size(); ++i) {
TaskDeviceMap& tdm = gdm[gp.task_names[i]];
DevRec* dr = &tdm[gp.device_names[i]];
dr->task = gp.task_names[i];
dr->device = gp.device_names[i];
dr->original_rank = i;
dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap.
dr->global_rank = 0; // Will be populated later by EstablishGlobalRank.
@@ -378,25 +376,23 @@ void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
}
}

// The first time a shared CollectiveParams is established for a
// shared set of instances we compute a good rank order for all the
// devices in the group, that is appropriate for a ring algorithm.
// This order need not be the same across different instance groups
// sharing the same device group where there is more than one good
// order.
// The first time a CollGroupParams is established for a group we compute a good
// rank order for all the devices in the group, that is appropriate for a ring
// algorithm.
GlobalDeviceMap EstablishGlobalRank(
CollectiveParams* cp, const std::vector<DeviceAttributes>& attributes) {
const CollGroupParams& gp,
const std::vector<DeviceAttributes>& attributes) {
VLOG(1) << "EstablishGlobalRank";
GlobalDeviceMap gdm = BuildDevRecs(cp->instance, attributes);
GlobalDeviceMap gdm = BuildDevRecs(gp, attributes);
for (auto& iter : gdm) {
TaskDeviceMap& tdm = iter.second;
OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm);
OrderTaskDeviceMap(gp.gpu_ring_order, &tdm);
}
// Connect the global rank order by the order in which tasks first appear.
std::set<string> ordered_tasks;
int next_rank = 0;
for (int i = 0; i < cp->instance.task_names.size(); ++i) {
const string& task_name = cp->instance.task_names[i];
for (int i = 0; i < gp.task_names.size(); ++i) {
const string& task_name = gp.task_names[i];
if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
continue;
}
@@ -411,81 +407,104 @@ GlobalDeviceMap EstablishGlobalRank(
}

// Count the devices associated with each task and set
// cp->same_num_devices_per_task. Requires cp->instance.task_names
// gp->same_num_devices_per_task. Requires gp->task_names
// be sorted.
void SetDevPerTask(CollectiveParams* cp) {
cp->instance.num_devices_per_task.clear();
const string* last_task_name = &cp->instance.task_names[0];
void SetDevPerTask(CollGroupParams* gp) {
gp->num_devices_per_task.clear();
const string* last_task_name = &gp->task_names[0];
int count = 0;
for (const string& task_name : cp->instance.task_names) {
for (const string& task_name : gp->task_names) {
if (task_name == *last_task_name) {
++count;
} else {
cp->instance.num_devices_per_task[*last_task_name] = count;
gp->num_devices_per_task[*last_task_name] = count;
count = 1;
last_task_name = &task_name;
}
}
cp->instance.num_devices_per_task[*last_task_name] = count;
gp->num_devices_per_task[*last_task_name] = count;

cp->instance.same_num_devices_per_task = false;
gp->same_num_devices_per_task = false;
int dev_per_task = -1;
for (const auto& task_dev : cp->instance.num_devices_per_task) {
for (const auto& task_dev : gp->num_devices_per_task) {
if (dev_per_task == -1) {
dev_per_task = task_dev.second;
} else if (dev_per_task != task_dev.second) {
return;
}
}
cp->instance.same_num_devices_per_task = true;
CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
gp->same_num_devices_per_task = true;
CHECK_EQ((gp->group_size % gp->num_tasks), 0);
}

// Sort cp->instance.device_names lexicographically, but do by first
// computing a reordering permutation so we can keep cp->instance.task_names
// Sort gp->device_names lexicographically, but do by first
// computing a reordering permutation so we can keep gp->task_names
// in corresponding order.
void SortDevicesAndTasks(CollectiveParams* cp) {
VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
CHECK(cp);
CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
std::vector<int> perm(cp->group.group_size);
void SortDevicesAndTasks(CollGroupParams* gp) {
VLOG(1) << "SortDevicesAndTasks " << gp << " " << gp;
CHECK(gp);
CHECK_EQ(gp->group_size, gp->device_names.size());
CHECK_EQ(gp->group_size, gp->task_names.size());
std::vector<int> perm(gp->group_size);
// TODO(tucker): substitute std::iota when the windows build supports it.
// std::iota(perm.begin(), perm.end(), 0);
for (int i = 0; i < perm.size(); ++i) {
perm[i] = i;
}
std::sort(perm.begin(), perm.end(), [cp](int a, int b) {
return cp->instance.device_names[a] < cp->instance.device_names[b];
std::sort(perm.begin(), perm.end(), [gp](int a, int b) {
return gp->device_names[a] < gp->device_names[b];
});
std::vector<string> new_devs;
std::vector<string> new_tasks;
new_devs.reserve(cp->group.group_size);
new_tasks.reserve(cp->group.group_size);
new_devs.reserve(gp->group_size);
new_tasks.reserve(gp->group_size);
for (int pi : perm) {
new_devs.push_back(cp->instance.device_names[pi]);
new_tasks.push_back(cp->instance.task_names[pi]);
new_devs.push_back(gp->device_names[pi]);
new_tasks.push_back(gp->task_names[pi]);
}
cp->instance.device_names = std::move(new_devs);
cp->instance.task_names = std::move(new_tasks);
VLOG(1) << "Modified device_names on " << cp;
SetDevPerTask(cp);
gp->device_names = std::move(new_devs);
gp->task_names = std::move(new_tasks);
VLOG(1) << "Modified device_names on " << gp;
SetDevPerTask(gp);
}
} // namespace

void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
gr->group.device_names.reserve(gr->devices.size());
gr->group.task_names.reserve(gr->devices.size());
std::vector<DeviceAttributes> attributes;
// Unique tasks. It's used to calculate num_tasks.
std::unordered_set<string> tasks;
attributes.reserve(gr->devices.size());
for (const auto& item : gr->devices) {
gr->group.device_names.push_back(item.first);
string task_name = TaskNameFromDeviceName(item.first);
gr->group.task_names.push_back(task_name);
tasks.insert(task_name);
attributes.push_back(item.second);
}
gr->group.num_tasks = static_cast<int32>(tasks.size());
// Sort device_names lexicographically, keeping task_names in corresponding
// order. Also set number of devices per task.
SortDevicesAndTasks(&gr->group);
// Establish the final order of gp->device_names and gp->task_names by
// considering localities of all devices.
CompleteDefaultRanking(attributes, &gr->group);
}

void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
for (int i = 0; i < cp->group.group_size; ++i) {
cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
}
}

void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
CollectiveParams* cp) {
CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
CHECK_EQ(cp->group.group_size, cp->group.device_names.size()) << cp;
for (int i = 0; i < cp->group.group_size; ++i) {
if (cp->instance.device_names[i] == device) {
if (cp->group.device_names[i] == device) {
cp->default_rank = i;
break;
}
@@ -494,40 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,

void CollectiveParamResolverLocal::InitInstanceSharedParams(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
std::vector<DeviceAttributes> attributes;
ir->shared.instance = cp->instance;
{
mutex_lock gl(gr->mu);
ir->shared.group = gr->group;
ir->shared.instance.device_names.clear();
ir->shared.instance.task_names.clear();
ir->shared.instance.device_names.reserve(gr->devices.size());
ir->shared.instance.task_names.reserve(gr->devices.size());
attributes.reserve(gr->devices.size());
for (const auto& item : gr->devices) {
ir->shared.instance.device_names.push_back(item.first);
ir->shared.instance.task_names.push_back(
TaskNameFromDeviceName(item.first));
attributes.push_back(item.second);
}
VLOG(2) << "Initialized names for instance: "
<< ir->shared.instance.ToString();
}
ir->shared.default_rank = -1;

// Sort device_names lexicographically, keeping task_names in corresponding
// order. Also set number of devices per task.
SortDevicesAndTasks(&ir->shared);

// Get Locality data for all devices.

// Set is_local and task_names in *shared prior to invoking
// GetDeviceAttributesAsync. In a distributed context this function can be
// called by a derived class, some of the devices may be non-local and
// GetDeviceAttributesAsync will use those fields to launch RPCs.
CompleteTaskIsLocal(task_name_, &ir->shared);

CompleteDefaultRanking(gr, cp, ir, attributes);
}

// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
@@ -535,33 +528,31 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const std::vector<DeviceAttributes>& attributes) {
const std::vector<DeviceAttributes>& attributes, CollGroupParams* gp) {
// Establish an instance-specific default rank order for devices
// based on localities. This rank order should be a good ring
// order, if possible.
GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, attributes);
GlobalDeviceMap gdm = EstablishGlobalRank(*gp, attributes);
// Reflect the new global ranking on shared
size_t num_devices = ir->shared.group.group_size;
size_t num_devices = gp->group_size;
std::vector<string> new_device_names(num_devices, "");
std::vector<string> new_task_names(num_devices, "");
for (const auto& git : gdm) {
const TaskDeviceMap& tdm = git.second;
for (const auto& tit : tdm) {
const DevRec& dr = tit.second;
new_device_names[dr.global_rank] =
ir->shared.instance.device_names[dr.original_rank];
new_task_names[dr.global_rank] =
ir->shared.instance.task_names[dr.original_rank];
new_device_names[dr.global_rank] = gp->device_names[dr.original_rank];
new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
}
}

ir->shared.instance.device_names = new_device_names;
ir->shared.instance.task_names = new_task_names;
gp->device_names = new_device_names;
gp->task_names = new_task_names;
if (VLOG_IS_ON(2)) {
string buf;
for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
<< buf;
}
}

@@ -655,10 +646,13 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal(

// Populate the group portion of *cp from *gr. Most of it should already
// match.
DCHECK_EQ(cp->group.group_key, gr->group.group_key);
DCHECK_EQ(cp->group.group_size, gr->group.group_size);
DCHECK_EQ(cp->group.device_type, gr->group.device_type);
cp->group = gr->group;
{
mutex_lock l(gr->mu);
DCHECK_EQ(cp->group.group_key, gr->group.group_key);
DCHECK_EQ(cp->group.group_size, gr->group.group_size);
DCHECK_EQ(cp->group.device_type, gr->group.device_type);
cp->group = gr->group;
}

InstanceRec* ir = GetOrCreateInstanceRec(gr, cp);
CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done);
@@ -756,11 +750,11 @@ void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
}
}
}
if (ir->known_count < ir->shared.group.group_size) {
if (ir->known_count < cp->group.group_size) {
ir->known_waiters.push_back(f);
return;
}
CHECK_EQ(ir->known_count, ir->shared.group.group_size);
CHECK_EQ(ir->known_count, cp->group.group_size);
if (ir->source_rank < 0) {
// NOTE(ayushd): changing the error message below would also require
// updating CompleteParamsBroadcastForgotSend test in

@@ -69,8 +69,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {

// Used to complete/verify CollGroup.
struct GroupRec {
CollGroupParams group;
mutable mutex mu;
CollGroupParams group TF_GUARDED_BY(mu);
Status status TF_GUARDED_BY(mu);
std::unordered_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
@@ -88,6 +88,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
const GroupRecCallback& done)
TF_LOCKS_EXCLUDED(group_mu_);

// Finishes the group parameters once all members of the group are there.
void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu);

// Used to complete/verify CollInstance.
struct InstanceRec;

@@ -131,11 +134,10 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
InstanceRec* ir) TF_LOCKS_EXCLUDED(gr->mu);

// Establishes the final order of ir->shared.instance.device_names and
// ir->shared.instance.task_names by considering localities of all devices.
void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp,
InstanceRec* ir,
const std::vector<DeviceAttributes>& attributes);
// Establishes the final order of gp->device_names and gp->task_names by
// considering localities of all devices.
void CompleteDefaultRanking(const std::vector<DeviceAttributes>& attributes,
CollGroupParams* gp);

// Finish populating *cp.
// Precondition: *gr has been fully populated by CompleteGroupLocal.

@@ -59,32 +59,21 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
}

void RunCompleteDefaultRanking(
const CollectiveParams& shared_cp,
const std::vector<DeviceAttributes>& attributes,
CollGroupParams group, const std::vector<DeviceAttributes>& attributes,
const std::vector<int32>& gpu_ring_order,
const std::vector<string>& expected_device_order) {
CollectiveParams cp;
cp.instance.device_names = shared_cp.instance.device_names;
CollectiveParamResolverLocal::InstanceRec ir;
{
mutex_lock l(ir.mu);
ir.shared.name = shared_cp.name;
ir.shared.group = shared_cp.group;
ir.shared.instance = shared_cp.instance;
if (!gpu_ring_order.empty()) {
ir.shared.instance.gpu_ring_order = "";
for (int i = 0; i < static_cast<int32>(gpu_ring_order.size() - 1);
++i) {
ir.shared.instance.gpu_ring_order = strings::StrCat(
ir.shared.instance.gpu_ring_order, gpu_ring_order[i], ",");
}
ir.shared.instance.gpu_ring_order = strings::StrCat(
ir.shared.instance.gpu_ring_order, gpu_ring_order.back());
if (!gpu_ring_order.empty()) {
group.gpu_ring_order = "";
for (int i = 0; i < static_cast<int32>(gpu_ring_order.size() - 1); ++i) {
group.gpu_ring_order =
strings::StrCat(group.gpu_ring_order, gpu_ring_order[i], ",");
}
VLOG(2) << "gpu_ring_order " << ir.shared.instance.gpu_ring_order;
prl_->CompleteDefaultRanking(nullptr, &cp, &ir, attributes);
EXPECT_EQ(ir.shared.instance.device_names, expected_device_order);
group.gpu_ring_order =
strings::StrCat(group.gpu_ring_order, gpu_ring_order.back());
}
VLOG(2) << "gpu_ring_order " << group.gpu_ring_order;
prl_->CompleteDefaultRanking(attributes, &group);
EXPECT_EQ(group.device_names, expected_device_order);
}

DeviceAttributes GetDeviceAttributes(const string& device_name) {
@@ -101,19 +90,15 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {

TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
constexpr int kNumGpus = 8;
CollectiveParams cp;
CollGroupParams group;
std::vector<DeviceAttributes> attributes(kNumGpus);
cp.name = "PRLTest";
cp.group.device_type = DeviceType("GPU");
cp.group.num_tasks = 1;
cp.group.group_size = kNumGpus;
cp.instance.instance_key = 5;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
group.device_type = DeviceType("GPU");
group.num_tasks = 1;
group.group_size = kNumGpus;
std::unordered_set<int> clique1 = {0, 1, 6, 7};
for (int gpu_idx = 0; gpu_idx < kNumGpus; ++gpu_idx) {
cp.instance.task_names.push_back("/job:localhost/replica:0/task:0");
cp.instance.device_names.push_back(strings::StrCat(
group.task_names.push_back("/job:localhost/replica:0/task:0");
group.device_names.push_back(strings::StrCat(
"/job:localhost/replica:0/task:0/device:GPU:", gpu_idx));
DeviceLocality locality;
// Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected
@@ -138,7 +123,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
}
*attributes[gpu_idx].mutable_locality() = locality;
}
RunCompleteDefaultRanking(cp, attributes, {1, 3, 5, 7, 6, 4, 2, 0},
RunCompleteDefaultRanking(group, attributes, {1, 3, 5, 7, 6, 4, 2, 0},
{
"/job:localhost/replica:0/task:0/device:GPU:1",
"/job:localhost/replica:0/task:0/device:GPU:3",
@@ -149,7 +134,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
"/job:localhost/replica:0/task:0/device:GPU:2",
"/job:localhost/replica:0/task:0/device:GPU:0",
});
RunCompleteDefaultRanking(cp, attributes, {7, 6, 5, 4, 3, 2, 1, 0},
RunCompleteDefaultRanking(group, attributes, {7, 6, 5, 4, 3, 2, 1, 0},
{
"/job:localhost/replica:0/task:0/device:GPU:7",
"/job:localhost/replica:0/task:0/device:GPU:6",
@@ -162,7 +147,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
});
// With no gpu_ring_order passed, automatic link detection should kick in.
// Starting at dev 0, the best order would be: 0,1,6,7,3,2,4,5
RunCompleteDefaultRanking(cp, attributes, {},
RunCompleteDefaultRanking(group, attributes, {},
{
"/job:localhost/replica:0/task:0/device:GPU:0",
"/job:localhost/replica:0/task:0/device:GPU:1",
@@ -189,18 +174,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
cp->instance.type = REDUCTION_COLLECTIVE;
cp->instance.data_type = DataType(DT_FLOAT);
cp->instance.shape = TensorShape({5});
cp->instance.device_names.push_back(
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
cp->instance.impl_details.subdiv_offsets.push_back(0);
cp->is_source = false;
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
string device =
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {
@@ -208,17 +192,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
}
for (int i = 0; i < NUM_DEVS; ++i) {
TF_ASSERT_OK(statuses[i]);
ASSERT_EQ(cps[i].instance.device_names.size(), 3);
ASSERT_EQ(cps[i].group.device_names.size(), 3);
for (int j = 0; j < NUM_DEVS; ++j) {
EXPECT_EQ(
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
cps[i].instance.device_names[j]);
cps[i].group.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i].is_source);
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
}
}

@@ -233,8 +217,6 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx,
cp->instance.type = BROADCAST_COLLECTIVE;
cp->instance.data_type = DataType(DT_FLOAT);
cp->instance.shape = TensorShape({5});
cp->instance.device_names.push_back(strings::StrCat(
"/job:localhost/replica:0/task:0/device:CPU:", device_idx));
cp->instance.impl_details.subdiv_offsets.push_back(0);
cp->is_source = is_source;
}
@@ -248,13 +230,14 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
CollectiveParams* cp = &cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
string device =
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {
@@ -262,16 +245,16 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
}
for (int i = 0; i < NUM_DEVS; ++i) {
TF_ASSERT_OK(statuses[i]);
ASSERT_EQ(cps[i].instance.device_names.size(), 3);
ASSERT_EQ(cps[i].group.device_names.size(), 3);
for (int j = 0; j < NUM_DEVS; ++j) {
EXPECT_EQ(
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
cps[i].instance.device_names[j]);
cps[i].group.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
EXPECT_EQ(cps[i].is_source, (i == 1));
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
}
}

@@ -287,13 +270,14 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
CollectiveParams* cp = &cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
string device =
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {

@@ -60,8 +60,8 @@ string SubdivPermDebugString(const CollectiveParams& col_params) {
for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
int idx = subdiv_perms[sdi][di];
if (idx >= 0) {
CHECK_GT(col_params.instance.device_names.size(), idx);
strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n");
CHECK_GT(col_params.group.device_names.size(), idx);
strings::StrAppend(&buf, col_params.group.device_names[idx], "\n");
}
}
strings::StrAppend(&buf, " subdiv_offsets: ");

@@ -80,20 +80,20 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
CHECK_EQ(col_params->instance.impl_details.collective_name,
"HierarchicalTreeBroadcast");
const string& device_name =
col_params->instance.device_names[col_params->default_rank];
col_params->group.device_names[col_params->default_rank];
// Start by counting the devices in each task.
// Precondition: device_names must be sorted so that all devices in
// the same task are adjacent.
VLOG(2) << "Sorted task names: "
<< absl::StrJoin(col_params->instance.task_names, ", ");
<< absl::StrJoin(col_params->group.task_names, ", ");
std::vector<int> dev_per_task;
const string* prior_task_name = &col_params->instance.task_names[0];
const string* prior_task_name = &col_params->group.task_names[0];
int dev_count = 1;
for (int di = 1; di < col_params->group.group_size; ++di) {
if (col_params->instance.task_names[di] != *prior_task_name) {
if (col_params->group.task_names[di] != *prior_task_name) {
dev_per_task.push_back(dev_count);
dev_count = 1;
prior_task_name = &col_params->instance.task_names[di];
prior_task_name = &col_params->group.task_names[di];
} else {
++dev_count;
}
@@ -135,14 +135,13 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
if (source_task == ti) {
// Source device belongs to this task.
perm.push_back(col_params->source_rank);
participate =
col_params->instance.device_names[col_params->source_rank] ==
device_name;
participate = col_params->group.device_names[col_params->source_rank] ==
device_name;
} else {
// Source does not belong to this task, choose dev 0.
perm.push_back(device_count);
participate =
col_params->instance.device_names[device_count] == device_name;
col_params->group.device_names[device_count] == device_name;
}
if (participate) col_params->subdiv_rank.push_back(ti);
device_count += dev_per_task[ti];
@@ -165,7 +164,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
int subdiv_source = 0;
for (int di = 0; di < dev_per_task[ti]; di++) {
perm.push_back(abs_di);
if (col_params->instance.device_names[abs_di] == device_name) {
if (col_params->group.device_names[abs_di] == device_name) {
participate = true;
col_params->subdiv_rank.push_back(di);
}
@@ -417,11 +416,11 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
<< col_ctx_->device_name << " to_device "
<< col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
<< col_params_->group.device_names[dst_idx] << " subdiv=" << subdiv
<< " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
col_ctx_->col_exec->remote_access()->PostToPeer(
col_params_->instance.device_names[dst_idx],
col_params_->instance.task_names[dst_idx], send_buf_key, col_ctx_->device,
col_params_->group.device_names[dst_idx],
col_params_->group.task_names[dst_idx], send_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), src_tensor,
col_ctx_->device_locality, done);
@@ -435,12 +434,12 @@ void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
int src_idx =
col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
<< col_params_->instance.device_names[src_idx] << " to_device "
<< col_params_->group.device_names[src_idx] << " to_device "
<< col_ctx_->device_name << " subdiv=" << subdiv
<< " src_rank=" << src_rank << " src_idx=" << src_idx;
col_ctx_->col_exec->remote_access()->RecvFromPeer(
col_params_->instance.device_names[src_idx],
col_params_->instance.task_names[src_idx],
col_params_->group.device_names[src_idx],
col_params_->group.task_names[src_idx],
col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,

@@ -328,8 +328,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
dev_name = strings::StrCat(task_name, "/device:CPU:", di);
}
VLOG(2) << "dev=" << dev_name;
col_params_.instance.device_names.push_back(dev_name);
col_params_.instance.task_names.push_back(task_name);
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
col_params_.task.is_local.push_back(true);
}
}
@@ -337,7 +337,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
for (int di = 0; di < num_devices_per_worker; di++) {
int default_rank = wi * num_devices_per_worker + di;
instances_.push_back(new DeviceInstance(
default_rank, col_params_.instance.device_names[default_rank],
default_rank, col_params_.group.device_names[default_rank],
device_type, this));
}
}
@@ -539,8 +539,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
for (int di = 0; di < num_gpus; di++) {
string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
cp->instance.task_names.push_back(task_name);
cp->instance.device_names.push_back(dev_name);
cp->group.task_names.push_back(task_name);
cp->group.device_names.push_back(dev_name);
}
}
}
@@ -557,22 +557,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
col_params_.name = parent_->col_params_.name;
col_params_.instance.data_type = parent_->col_params_.instance.data_type;
col_params_.group.group_key = parent_->col_params_.group.group_key;
col_params_.group = parent_->col_params_.group;
col_params_.instance.instance_key =
parent_->col_params_.instance.instance_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.group.group_size = parent_->col_params_.group.group_size;
col_params_.instance.device_names =
parent_->col_params_.instance.device_names;
col_params_.instance.task_names =
parent_->col_params_.instance.task_names;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.instance.impl_details.subdiv_permutations =
parent_->col_params_.instance.impl_details.subdiv_permutations;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;

int group_size = col_params_.group.group_size;
CHECK_EQ(group_size, col_params_.instance.device_names.size());
CHECK_EQ(group_size, col_params_.group.device_names.size());
// Default rank is order in device_names.
col_params_.default_rank = rank;

@@ -789,8 +783,8 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
for (int di = 0; di < dev_per_task[ti]; di++) {
string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(dev_name);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(dev_name);
cp.group.group_size++;
}
}

@@ -87,7 +87,7 @@ void Permuter::DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
<< " target_rank=" << target_rank << " src_rank=" << src_rank;
col_ctx_->col_exec->remote_access()->PostToPeer(
col_params_->instance.devices[target_rank],
col_params_->instance.task_names[target_rank], send_buf_key,
col_params_->group.task_names[target_rank], send_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
done);
@@ -103,7 +103,7 @@ void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
<< " target_rank=" << target_rank << " src_rank=" << src_rank;
col_ctx_->col_exec->remote_access()->RecvFromPeer(
col_params_->instance.devices[src_rank],
col_params_->instance.task_names[src_rank],
col_params_->group.task_names[src_rank],
col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device,
col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,

@@ -183,11 +183,11 @@ class PermuterTest : public ::testing::Test {
} else {
dev_name = strings::StrCat(task_name, "/device:CPU:", di);
}
col_params_.instance.device_names.push_back(dev_name);
col_params_.group.device_names.push_back(dev_name);
col_params_.instance.devices.push_back(dev_name);
int default_rank = wi * num_devices_per_worker + di;
permutation_.push_back(default_rank);
col_params_.instance.task_names.push_back(task_name);
col_params_.group.task_names.push_back(task_name);
col_params_.task.is_local.push_back(true);
}
}
@@ -212,7 +212,7 @@ class PermuterTest : public ::testing::Test {
for (int di = 0; di < num_devices_per_worker; di++) {
int default_rank = wi * num_devices_per_worker + di;
instances_.push_back(new DeviceInstance(
default_rank, col_params_.instance.device_names[default_rank],
default_rank, col_params_.group.device_names[default_rank],
device_type, this));
}
}
@@ -323,16 +323,14 @@ class PermuterTest : public ::testing::Test {
col_params_.instance.instance_key =
parent_->col_params_.instance.instance_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.instance.device_names =
parent_->col_params_.instance.device_names;
col_params_.group.device_names = parent_->col_params_.group.device_names;
col_params_.instance.devices = parent_->col_params_.instance.devices;
col_params_.instance.permutation =
parent->col_params_.instance.permutation;
col_params_.instance.task_names =
parent_->col_params_.instance.task_names;
col_params_.group.task_names = parent_->col_params_.group.task_names;
col_params_.task.is_local = parent_->col_params_.task.is_local;
CHECK_EQ(col_params_.instance.devices.size(),
col_params_.instance.device_names.size());
col_params_.group.device_names.size());
// Default rank is order in device_names.
col_params_.default_rank = rank;
}

@@ -164,7 +164,7 @@ Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
const string& device_name =
col_params->instance.device_names[col_params->default_rank];
col_params->group.device_names[col_params->default_rank];
// Each subdiv permutation is a ring formed by rotating each
// single-task subsequence of devices by an offset. This makes most
// sense when each task has the same number of devices but we can't
@@ -175,15 +175,15 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
// Precondition: device_names must be sorted so that all devices in
// the same task are adjacent.
VLOG(2) << "Sorted task names: "
<< absl::StrJoin(col_params->instance.task_names, ", ");
<< absl::StrJoin(col_params->group.task_names, ", ");
std::vector<int> dev_per_task;
const string* prior_task_name = &col_params->instance.task_names[0];
const string* prior_task_name = &col_params->group.task_names[0];
int dev_count = 1;
for (int di = 1; di < col_params->group.group_size; ++di) {
if (col_params->instance.task_names[di] != *prior_task_name) {
if (col_params->group.task_names[di] != *prior_task_name) {
dev_per_task.push_back(dev_count);
dev_count = 1;
prior_task_name = &col_params->instance.task_names[di];
prior_task_name = &col_params->group.task_names[di];
} else {
++dev_count;
}
@@ -227,7 +227,7 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
int permuted_di = prior_dev_count + offset_di;
int rank = static_cast<int>(perm.size());
perm.push_back(permuted_di);
if (col_params->instance.device_names[permuted_di] == device_name) {
if (col_params->group.device_names[permuted_di] == device_name) {
DCHECK_EQ(permuted_di, col_params->default_rank);
col_params->subdiv_rank[sdi] = rank;
}
@@ -385,8 +385,8 @@ void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
int send_to_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[rf->subdiv_idx][send_to_rank];
col_ctx_->col_exec->remote_access()->PostToPeer(
col_params_->instance.device_names[send_to_dev_idx],
col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
col_params_->group.device_names[send_to_dev_idx],
col_params_->group.task_names[send_to_dev_idx], send_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
col_ctx_->device_locality, done);
@@ -404,8 +404,8 @@ void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
? &rf->tmp_chunk
: &rf->chunk;
col_ctx_->col_exec->remote_access()->RecvFromPeer(
col_params_->instance.device_names[rf->recv_dev_idx],
col_params_->instance.task_names[rf->recv_dev_idx],
col_params_->group.device_names[rf->recv_dev_idx],
col_params_->group.task_names[rf->recv_dev_idx],
col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,

@@ -71,9 +71,9 @@ void RingGatherer::Run(StatusCallback done) {

if (VLOG_IS_ON(1)) {
string buf;
for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
for (int r = 0; r < col_params_->group.device_names.size(); ++r) {
strings::StrAppend(&buf, "dev ", r, " : ",
col_params_->instance.device_names[r], "\n");
col_params_->group.device_names[r], "\n");
}
for (int sd = 0;
sd < col_params_->instance.impl_details.subdiv_permutations.size();

@@ -222,8 +222,8 @@ class RingGathererTest : public ::testing::Test {
dev_name =
strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
}
col_params_.instance.device_names.push_back(dev_name);
col_params_.instance.task_names.push_back(task_name);
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
// Normally each device would set is_local to its own perspective but
// this test runs in a single process so is_local is always true.
col_params_.task.is_local.push_back(true);
@@ -240,7 +240,7 @@ class RingGathererTest : public ::testing::Test {
for (int di = 0; di < num_devices; ++di) {
int rank = wi * num_devices + di;
instances_.push_back(new DeviceInstance(
rank, col_params_.instance.device_names[rank], device_type_, this));
rank, col_params_.group.device_names[rank], device_type_, this));
}
}
}
@@ -389,9 +389,7 @@ class RingGathererTest : public ::testing::Test {
<< "Couldn't find device " << dev_name
<< " existing devices: " << parent_->dev_mgr_->DebugString();
col_params_.name = parent_->col_params_.name;
col_params_.group.group_key = parent_->col_params_.group.group_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.group.group_size = parent_->col_params_.group.group_size;
col_params_.group = parent_->col_params_.group;
col_params_.instance = parent->col_params_.instance;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
@@ -399,7 +397,7 @@ class RingGathererTest : public ::testing::Test {
int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
int group_size = col_params_.group.group_size;
CHECK_EQ(group_size,
static_cast<int>(col_params_.instance.device_names.size()));
static_cast<int>(col_params_.group.device_names.size()));
// Id of this device is at rank position in first subdiv perm.
int my_device_id =
col_params_.instance.impl_details.subdiv_permutations[0][rank];
@@ -547,8 +545,8 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(device_name);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(device_name);
}
return cp;
}

@@ -67,9 +67,9 @@ void RingReducer::Run(StatusCallback done) {

if (VLOG_IS_ON(1)) {
string buf;
for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
for (int r = 0; r < col_params_->group.device_names.size(); ++r) {
strings::StrAppend(&buf, "dev ", r, " : ",
col_params_->instance.device_names[r], "\n");
col_params_->group.device_names[r], "\n");
}
for (int sd = 0;
sd < col_params_->instance.impl_details.subdiv_permutations.size();

@@ -238,15 +238,15 @@ class RingReducerTest : public ::testing::Test {
// Set up all of the fake device contexts.
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
col_params_.instance.num_devices_per_task[task_name] = num_devices;
col_params_.group.num_devices_per_task[task_name] = num_devices;
for (int di = 0; di < num_devices; ++di) {
string dev_name = strings::StrCat(task_name, "/cpu:", di);
if (device_type == DEVICE_GPU) {
dev_name =
strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
}
col_params_.instance.device_names.push_back(dev_name);
col_params_.instance.task_names.push_back(task_name);
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
// Normally each device would set is_local to its own perspective but
// this test runs in a single process so is_local is always true.
col_params_.task.is_local.push_back(true);
@@ -263,7 +263,7 @@ class RingReducerTest : public ::testing::Test {
for (int di = 0; di < num_devices; ++di) {
int rank = wi * num_devices + di;
instances_.push_back(new DeviceInstance(
rank, col_params_.instance.device_names[rank], device_type_, this));
rank, col_params_.group.device_names[rank], device_type_, this));
}
}
}
@@ -414,9 +414,7 @@ class RingReducerTest : public ::testing::Test {
<< "Couldn't find device " << dev_name
<< " existing devices: " << parent_->dev_mgr_->DebugString();
col_params_.name = parent_->col_params_.name;
col_params_.group.group_key = parent_->col_params_.group.group_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.group.group_size = parent_->col_params_.group.group_size;
col_params_.group = parent_->col_params_.group;
col_params_.instance = parent->col_params_.instance;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
@@ -424,7 +422,7 @@ class RingReducerTest : public ::testing::Test {
int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
int group_size = col_params_.group.group_size;
CHECK_EQ(group_size,
static_cast<int>(col_params_.instance.device_names.size()));
static_cast<int>(col_params_.group.device_names.size()));
// Id of this device is at rank position in first subdiv perm.
int my_device_id =
col_params_.instance.impl_details.subdiv_permutations[0][rank];
@@ -574,8 +572,8 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(device_name);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(device_name);
}
return cp;
}

@@ -250,27 +250,32 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
gr->devices[device.name()] = device;
}
gr->group.runtime_details.communicator_key = resp.communicator_key();
FinishGroup(gr.get());
}
GroupRec* previous_gr = nullptr;
{
// Group membership should never change. Once a record is in group_table_
// it never gets removed.
mutex_lock l(group_mu_);
auto it = group_table_.find(gr->group.group_key);
auto it = group_table_.find(resp.group_key());
if (it == group_table_.end()) {
VLOG(2) << "UpdateGroupCache: communicator_key="
<< absl::CEscape(gr->group.runtime_details.communicator_key);
<< absl::CEscape(resp.communicator_key());
group_table_[gr->group.group_key] = std::move(gr);
} else {
auto& previous_gr = group_table_[gr->group.group_key];
if (previous_gr->group.runtime_details.communicator_key !=
gr->group.runtime_details.communicator_key) {
return errors::Internal(
"UpdateGroupCache: CompleteGroupResponse for group ",
gr->group.group_key, " gives communicator_key=",
absl::CEscape(gr->group.runtime_details.communicator_key),
" but cache already holds communicator_key=",
absl::CEscape(previous_gr->group.runtime_details.communicator_key));
}
previous_gr = it->second.get();
}
}
if (previous_gr != nullptr) {
mutex_lock grl(previous_gr->mu);
if (previous_gr->group.runtime_details.communicator_key !=
resp.communicator_key()) {
return errors::Internal(
"UpdateGroupCache: CompleteGroupResponse for group ",
resp.group_key(),
" gives communicator_key=", absl::CEscape(resp.communicator_key()),
" but cache already holds communicator_key=",
absl::CEscape(previous_gr->group.runtime_details.communicator_key));
}
}
return Status::OK();
@@ -362,7 +367,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
if (group_leader_.empty()) {
// This is the group leader so resolution is local.
return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
} else if (InstanceIsCached(gr->group.group_key, cp->instance.instance_key)) {
} else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
} else {
CompleteInstanceCall* call = new CompleteInstanceCall(

@@ -253,18 +253,18 @@ class DeviceResDistTest : public ::testing::Test {
int idx = wi * num_devices + di;
TF_ASSERT_OK(status_[device_name]);
EXPECT_EQ(cp_[device_name].default_rank, idx);
EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count);
EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name);
EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name);
ValidateDeviceResolver(cp_[device_name], task_name);
if (idx > 0) {
EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
cp_[device_name].group.runtime_details.communicator_key);
for (int i = 0; i < dev_count; ++i) {
EXPECT_EQ(cp_[dev0].instance.device_names[i],
cp_[device_name].instance.device_names[i]);
EXPECT_EQ(cp_[dev0].instance.task_names[i],
cp_[device_name].instance.task_names[i]);
EXPECT_EQ(cp_[dev0].group.device_names[i],
cp_[device_name].group.device_names[i]);
EXPECT_EQ(cp_[dev0].group.task_names[i],
cp_[device_name].group.task_names[i]);
}
}
}
@@ -272,7 +272,7 @@ class DeviceResDistTest : public ::testing::Test {
}

void ValidateDeviceResolver(const CollectiveParams& cp, const string& task) {
for (const string& device_name : cp.instance.device_names) {
for (const string& device_name : cp.group.device_names) {
DeviceAttributes attributes;
TF_ASSERT_OK(
dev_resolvers_[task]->GetDeviceAttributes(device_name, &attributes));

@ -54,10 +54,23 @@ string CollGroupRuntimeDetails::ToString() const {
}

string CollGroupParams::ToString() const {
return strings::StrCat(
string v = strings::StrCat(
"CollGroupParams {group_key=", group_key, " group_size=", group_size,
" device_type=", device_type.type_string(), " num_tasks=", num_tasks,
" runtime_details=", runtime_details.ToString(), "}");
" runtime_details=", runtime_details.ToString(), " devices {");
for (const auto& d : device_names) {
strings::StrAppend(&v, d, ",");
}
strings::StrAppend(&v, "} task_names={");
for (const auto& n : task_names) {
strings::StrAppend(&v, n, ", ");
}
strings::StrAppend(&v, "} num_devices_per_task={");
for (const auto& dpt : num_devices_per_task) {
strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
}
strings::StrAppend(&v, "}");
return v;
}

CollInstanceParams& CollInstanceParams::operator=(
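
Because device names, task names, and per-task device counts now live on the group, CollGroupParams::ToString switches from a single StrCat to incremental StrAppend calls. A rough, self-contained illustration of the same accumulation pattern with plain std::string; the field values below are made up and only indicative of the output shape:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical contents of a single-task, two-GPU group.
  std::vector<std::string> device_names = {
      "/job:worker/replica:0/task:0/device:GPU:0",
      "/job:worker/replica:0/task:0/device:GPU:1"};
  std::vector<std::string> task_names = {"/job:worker/replica:0/task:0",
                                         "/job:worker/replica:0/task:0"};
  std::map<std::string, int> num_devices_per_task = {
      {"/job:worker/replica:0/task:0", 2}};

  // Start with the fixed-size fields, then append the variable-length ones.
  std::string v = "CollGroupParams {group_key=1 group_size=2 devices {";
  for (const auto& d : device_names) v += d + ",";
  v += "} task_names={";
  for (const auto& n : task_names) v += n + ", ";
  v += "} num_devices_per_task={";
  for (const auto& dpt : num_devices_per_task)
    v += dpt.first + ": " + std::to_string(dpt.second) + ", ";
  v += "}";
  std::cout << v << "\n";
}
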
@ -67,12 +80,6 @@ CollInstanceParams& CollInstanceParams::operator=(
type = other.type;
data_type = other.data_type;
shape = other.shape;
device_names.clear();
device_names.assign(other.device_names.begin(), other.device_names.end());
task_names.assign(other.task_names.begin(), other.task_names.end());
same_num_devices_per_task = other.same_num_devices_per_task;
num_devices_per_task = other.num_devices_per_task;
gpu_ring_order = other.gpu_ring_order;
impl_details.subdiv_offsets.assign(
other.impl_details.subdiv_offsets.begin(),
other.impl_details.subdiv_offsets.end());
@ -96,17 +103,6 @@ string CollInstanceParams::ToString() const {
strings::StrCat("CollInstanceParams { instance_key=", instance_key,
" type=", type, " data_type=", DataTypeString(data_type),
" shape=", shape.DebugString(), " devices {");
for (const auto& d : device_names) {
strings::StrAppend(&v, d, ",");
}
strings::StrAppend(&v, "} task_names={");
for (const auto& n : task_names) {
strings::StrAppend(&v, n, ", ");
}
strings::StrAppend(&v, "} num_devices_per_task={");
for (const auto& dpt : num_devices_per_task) {
strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
}
strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
", subdiv_offsets={");
strings::StrAppend(&v, "}, subdiv_offsets={");
@ -186,7 +182,7 @@ CollectiveContext::CollectiveContext(
input(input),
output(output),
device(nullptr),
device_name(col_params.instance.device_names[col_params.default_rank]) {}
device_name(col_params.group.device_names[col_params.default_rank]) {}

/*static*/
int64 CollectiveExecutor::kInvalidId = -1;
@ -63,6 +63,17 @@ struct CollGroupParams {
int32 group_key;
int32 group_size;
DeviceType device_type;
// Fully qualified name of device for each member, in default rank order.
std::vector<string> device_names;
// Task name prefix of corresponding device name.
std::vector<string> task_names;
// True if every task has the same number of devices.
bool same_num_devices_per_task = false;
// Task -> number of devices on that task.
std::unordered_map<string, int32> num_devices_per_task;
// If passed in to GPUOptions in ConfigProto, defines a good ring order for
// GPUs. Assumes same GPU configuration at each worker.
string gpu_ring_order = "";
int32 num_tasks; // number of distinct tasks in group
CollGroupRuntimeDetails runtime_details;
string ToString() const;
@ -98,17 +109,6 @@ struct CollInstanceParams {
CollectiveType type = UNDEFINED_COLLECTIVE;
DataType data_type = DT_FLOAT;
TensorShape shape = {0};
// Fully qualified name of device for each member, in default rank order.
std::vector<string> device_names;
// Task name prefix of corresponding device name.
std::vector<string> task_names;
// True if every task has the same number of devices.
bool same_num_devices_per_task = false;
// Task -> number of devices on that task.
std::unordered_map<string, int32> num_devices_per_task;
// If passed in to GPUOptions in ConfigProto, defines a good ring order for
// GPUs. Assumes same GPU configuration at each worker.
string gpu_ring_order = "";
CollImplDetails impl_details;
string ToString() const;
CollInstanceParams& operator=(const struct CollInstanceParams& other);
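
The header hunks above are the heart of this change: per-member device and task bookkeeping moves from CollInstanceParams to CollGroupParams, so it is stored once per group rather than repeated on every instance. A stripped-down, self-contained sketch of the resulting split, keeping only a subset of fields and substituting plain C++ types for the TensorFlow ones:

#include <string>
#include <unordered_map>
#include <vector>

// After this change: group-level membership data lives on the group params.
struct CollGroupParams {
  int group_key = 0;
  int group_size = 0;
  // Fully qualified device name for each member, in default rank order.
  std::vector<std::string> device_names;
  // Task name prefix of the corresponding device name.
  std::vector<std::string> task_names;
  // Task -> number of devices on that task.
  std::unordered_map<std::string, int> num_devices_per_task;
  std::string gpu_ring_order;
  int num_tasks = 0;
};

// Instance params keep only what varies per collective instance.
struct CollInstanceParams {
  int instance_key = 0;
  // type, data_type, shape, impl_details, ... (unchanged fields elided)
};

// A participant reads its own device through the shared group record.
inline const std::string& DeviceForRank(const CollGroupParams& group, int rank) {
  return group.device_names[rank];
}

int main() {
  CollGroupParams group;
  group.group_key = 1;
  group.group_size = 2;
  group.device_names = {"/job:worker/replica:0/task:0/device:GPU:0",
                        "/job:worker/replica:0/task:0/device:GPU:1"};
  group.task_names = {"/job:worker/replica:0/task:0",
                      "/job:worker/replica:0/task:0"};
  group.num_devices_per_task[group.task_names[0]] = 2;
  group.num_tasks = 1;
  return DeviceForRank(group, 0).empty() ? 1 : 0;
}
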
@ -136,15 +136,14 @@ class NcclTestBase : public ::testing::Test {
col_params_.instance.data_type = DT_FLOAT;
col_params_.instance.impl_details.collective_name = collective_name_;
const string task_name = "/job:worker/replica:0/task:0";
col_params_.instance.num_devices_per_task[task_name] = num_ranks;
col_params_.group.num_devices_per_task[task_name] = num_ranks;
for (int rank = 0; rank < num_ranks; ++rank) {
col_params_.instance.device_names.push_back(
device_names[rank % num_gpus]);
col_params_.instance.task_names.push_back(task_name);
col_params_.group.device_names.push_back(device_names[rank % num_gpus]);
col_params_.group.task_names.push_back(task_name);
}
for (int rank = 0; rank < num_ranks; ++rank) {
instances_.push_back(absl::make_unique<DeviceInstance>(
rank, col_params_.instance.device_names[rank], this));
rank, col_params_.group.device_names[rank], this));
}
}

@ -251,9 +250,7 @@ class NcclTestBase : public ::testing::Test {
<< parent_->dev_mgr_->DebugString();
col_params_.name = parent_->col_params_.name;
col_params_.default_rank = rank;
col_params_.group.group_key = parent_->col_params_.group.group_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.group.group_size = parent_->col_params_.group.group_size;
col_params_.group = parent_->col_params_.group;
col_params_.instance = parent->col_params_.instance;
}

@ -64,8 +64,7 @@ class CollectiveOpKernel : public AsyncOpKernel {
// immediately.
bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
const DoneCallback& done) {
if (col_params_.group.group_size >
col_params_.instance.device_names.size()) {
if (col_params_.group.group_size > col_params_.group.device_names.size()) {
// This is the first invocation: Finish initializing col_params_.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
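
CanProceedWithCompute now detects the first invocation by comparing the declared group size against the number of device names recorded on the group, since that list is only populated once parameter resolution has completed. A small, self-contained sketch of this lazy-initialization guard with hypothetical names:

#include <cstdio>
#include <string>
#include <vector>

struct Group {
  int group_size = 0;
  std::vector<std::string> device_names;  // filled in by parameter resolution
};

// Returns true if group membership still needs to be resolved, i.e. fewer
// devices are recorded than the group declares.
bool NeedsCompleteParams(const Group& g) {
  return g.group_size > static_cast<int>(g.device_names.size());
}

int main() {
  Group g;
  g.group_size = 2;
  std::printf("first call:  needs resolution = %d\n", NeedsCompleteParams(g));
  g.device_names = {"GPU:0", "GPU:1"};  // resolution has populated the group
  std::printf("second call: needs resolution = %d\n", NeedsCompleteParams(g));
}
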
@ -69,8 +69,8 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) {
const CollectiveParams& col_params = col_ctx->col_params;
const int num_global_devices = col_params.group.group_size;
const int num_local_devices = col_params.instance.num_devices_per_task.at(
col_params.instance.task_names[col_params.default_rank]);
const int num_local_devices = col_params.group.num_devices_per_task.at(
col_params.group.task_names[col_params.default_rank]);
const string nccl_collective_key =
NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
@ -84,7 +84,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
col_params.source_rank);
VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type
<< " num_tasks " << col_params.group.num_tasks << " current task "
<< col_params.instance.task_names[col_params.default_rank]
<< col_params.group.task_names[col_params.default_rank]
<< " num local devices " << num_local_devices
<< " num global devices " << num_global_devices << " device "
<< col_ctx->device_name << " instance "
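
NcclCommunicator::Enqueue now derives the local device count from the group record: it looks up the task owning the caller's default rank in task_names and reads that task's entry in num_devices_per_task. A minimal, self-contained sketch of that lookup with illustrative values:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Group {
  int group_size = 0;
  std::vector<std::string> task_names;  // indexed by rank
  std::unordered_map<std::string, int> num_devices_per_task;
};

int main() {
  // Hypothetical two-task group: task 0 owns ranks 0-1, task 1 owns rank 2.
  Group group;
  group.group_size = 3;
  group.task_names = {"/job:worker/task:0", "/job:worker/task:0",
                      "/job:worker/task:1"};
  group.num_devices_per_task = {{"/job:worker/task:0", 2},
                                {"/job:worker/task:1", 1}};

  const int default_rank = 1;  // this participant's rank in the group
  const int num_global_devices = group.group_size;
  const int num_local_devices =
      group.num_devices_per_task.at(group.task_names[default_rank]);
  std::cout << "global=" << num_global_devices
            << " local=" << num_local_devices << "\n";  // global=3 local=2
}
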