Exchange device attributes at group resolution
Previously, CollectiveParamResolver queried device attributes when initializing instance params. That is a problem when the collective leader fails and restarts quickly, between group resolution and instance resolution: all other workers then pick up the incarnation of the restarted leader, so they cannot detect that the leader failed, and the leader deadlocks on group resolution.

This change does not fully fix the issue, because it only exchanges device attributes at group resolution and does not yet populate them into DeviceResolver; that will be done in a follow-up change.

This change also alters the behavior when a non-leader fails and restarts. Previously it would get the cached group resolution from the leader; now it gets an error, because its incarnation no longer matches the one recorded in the cached group parameters. This should have no practical effect, since that worker will always restart again after the leader has restarted.

This change modifies both the client and the server and is not backward compatible; it assumes that client and server run the same version of TensorFlow. That should hold in practice, because the only way to use CollectiveParamResolverDistributed is through MultiWorkerMirroredStrategy (MWMS), and with MWMS all workers are expected to run the same version of the program.

PiperOrigin-RevId: 329735919
Change-Id: I5c29a3ec8462c7737bcbbbf823a95693b0d27dc3
parent 7921c4d517
commit 49220c3679
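The core of the new failure-detection logic is the incarnation comparison that group resolution performs when a device re-joins an existing group. The sketch below is a simplified, self-contained model of that check, using plain std:: containers and a string result in place of TensorFlow's GroupRec, DeviceAttributes and Status types; it only illustrates why exchanging incarnations at group resolution lets workers notice a restarted peer, and is not the resolver's actual code.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified stand-ins for TensorFlow's DeviceAttributes and GroupRec.
struct DeviceAttributes {
  std::string name;
  uint64_t incarnation;  // Changes every time the hosting process restarts.
};

struct GroupRec {
  int group_size = 0;
  // Keyed by device name, as in the real resolver.
  std::unordered_map<std::string, DeviceAttributes> devices;
};

// Models the group-resolution handling of one joining device.
// Returns an empty string on success, or an error message.
std::string JoinGroup(GroupRec* gr, const DeviceAttributes& device) {
  auto it = gr->devices.find(device.name);
  if (it == gr->devices.end()) {
    if (static_cast<int>(gr->devices.size()) == gr->group_size) {
      return "group is already full and does not contain " + device.name;
    }
    gr->devices[device.name] = device;  // First time this device joins.
    return "";
  }
  // The device is already in the group: its incarnation must match the one
  // recorded earlier, otherwise the process behind it has restarted since
  // group resolution.
  if (it->second.incarnation != device.incarnation) {
    return "device " + device.name +
           " incarnation doesn't match the one recorded in the group "
           "(worker restarted, or connected to the wrong cluster)";
  }
  return "";
}

int main() {
  GroupRec gr;
  gr.group_size = 2;
  std::cout << JoinGroup(&gr, {"/job:worker/task:0/device:CPU:0", 111}) << "\n";
  std::cout << JoinGroup(&gr, {"/job:worker/task:1/device:CPU:0", 222}) << "\n";
  // Worker 1 restarts and comes back with a new incarnation: detected.
  std::cout << JoinGroup(&gr, {"/job:worker/task:1/device:CPU:0", 333}) << "\n";
  return 0;
}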
@ -379,12 +379,11 @@ cc_library(
|
|||||||
hdrs = ["collective_param_resolver_local.h"],
|
hdrs = ["collective_param_resolver_local.h"],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
|
":device",
|
||||||
":device_mgr",
|
":device_mgr",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -298,8 +298,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BaseCollectiveExecutor::CompleteParamsAsync(
|
void BaseCollectiveExecutor::CompleteParamsAsync(
|
||||||
const DeviceAttributes& device, CollectiveParams* cp,
|
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
|
||||||
CancellationManager* cancel_mgr, StatusCallback done) {
|
StatusCallback done) {
|
||||||
cp->instance.gpu_ring_order = *gpu_ring_order_;
|
cp->instance.gpu_ring_order = *gpu_ring_order_;
|
||||||
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
|
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
|
||||||
auto done_with_timeout = done;
|
auto done_with_timeout = done;
|
||||||
|
@ -113,7 +113,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
|
|||||||
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
|
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
|
||||||
const string& exec_key, StatusCallback done) override;
|
const string& exec_key, StatusCallback done) override;
|
||||||
|
|
||||||
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
|
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
StatusCallback done) override;
|
StatusCallback done) override;
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
@ -31,7 +30,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/numbers.h"
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
@ -76,21 +74,12 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
|
|||||||
return "undef";
|
return "undef";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string TaskNameFromDeviceName(const string& device_name) {
|
|
||||||
DeviceNameUtils::ParsedName parsed_device;
|
|
||||||
CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
|
|
||||||
string task_name;
|
|
||||||
CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
|
|
||||||
return task_name;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void CollectiveParamResolverLocal::CompleteGroupLocal(
|
void CollectiveParamResolverLocal::CompleteGroupLocal(
|
||||||
const DeviceAttributes& device, CollectiveParams* cp,
|
const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
|
||||||
const GroupRecCallback& done) {
|
VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
|
||||||
VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
|
<< cp->ToString();
|
||||||
<< ": " << cp->ToString();
|
|
||||||
std::vector<StatusCallback> to_be_called;
|
std::vector<StatusCallback> to_be_called;
|
||||||
GroupRec* gr = nullptr;
|
GroupRec* gr = nullptr;
|
||||||
Status status;
|
Status status;
|
||||||
@ -150,13 +139,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
|
|||||||
// status.
|
// status.
|
||||||
VLOG(2) << "gr device_type=" << gr->group.device_type
|
VLOG(2) << "gr device_type=" << gr->group.device_type
|
||||||
<< " cp device_type=" << cp->group.device_type
|
<< " cp device_type=" << cp->group.device_type
|
||||||
<< " current device=" << device.name();
|
<< " current device=" << device;
|
||||||
if (gr->status.ok()) {
|
if (gr->status.ok()) {
|
||||||
// Check for consistency with existing GroupRec.
|
// Check for consistency with existing GroupRec.
|
||||||
if (cp->group.device_type != gr->group.device_type) {
|
if (cp->group.device_type != gr->group.device_type) {
|
||||||
gr->status = errors::Internal(
|
gr->status = errors::Internal(
|
||||||
"Collective Op ", cp->name, " is assigned to device ",
|
"Collective Op ", cp->name, " is assigned to device ", device,
|
||||||
device.name(), " with type ", cp->group.device_type.type_string(),
|
" with type ", cp->group.device_type.type_string(),
|
||||||
" and group_key ", cp->group.group_key, " but that group has type ",
|
" and group_key ", cp->group.group_key, " but that group has type ",
|
||||||
gr->group.device_type.type_string());
|
gr->group.device_type.type_string());
|
||||||
} else if (cp->group.group_size != gr->group.group_size) {
|
} else if (cp->group.group_size != gr->group.group_size) {
|
||||||
@ -168,47 +157,38 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
|
|||||||
}
|
}
|
||||||
if (gr->status.ok()) {
|
if (gr->status.ok()) {
|
||||||
// Insert device if not already present.
|
// Insert device if not already present.
|
||||||
auto it = gr->devices.find(device.name());
|
auto it = gr->device_set.find(device);
|
||||||
if (it == gr->devices.end()) {
|
if (it == gr->device_set.end()) {
|
||||||
if (gr->devices.size() == gr->group.group_size) {
|
if (gr->device_set.size() == gr->group.group_size) {
|
||||||
// The group is already full.
|
// The group is already full.
|
||||||
gr->status = errors::Internal(
|
gr->status = errors::Internal(
|
||||||
"Collective Op ", cp->name, " is assigned to device ",
|
"Collective Op ", cp->name, " is assigned to device ", device,
|
||||||
device.name(), " and group_key ", cp->group.group_key,
|
" and group_key ", cp->group.group_key,
|
||||||
" but that group doesn't contain that device.");
|
" but that group doesn't contain that device.");
|
||||||
} else {
|
} else {
|
||||||
// This is a new device that has not yet joined the group.
|
// This is a new device that has not yet joined the group.
|
||||||
gr->devices[device.name()] = device;
|
gr->device_set.insert(device);
|
||||||
if (gr->devices.size() == gr->group.group_size) {
|
gr->device_list.push_back(device);
|
||||||
// The group is full after adding this device, calculate the number
|
DeviceNameUtils::ParsedName parsed_device;
|
||||||
// of tasks.
|
DeviceNameUtils::ParseFullName(device, &parsed_device);
|
||||||
absl::flat_hash_set<string> tasks;
|
string task_name = strings::StrCat("/job:", parsed_device.job,
|
||||||
for (const auto& item : gr->devices) {
|
"/replica:", parsed_device.replica,
|
||||||
tasks.insert(TaskNameFromDeviceName(item.first));
|
"/task:", parsed_device.task);
|
||||||
}
|
gr->task_set.insert(task_name);
|
||||||
gr->group.num_tasks = static_cast<int32>(tasks.size());
|
gr->task_list.push_back(task_name);
|
||||||
}
|
gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
|
||||||
if (VLOG_IS_ON(1)) {
|
if (VLOG_IS_ON(1)) {
|
||||||
string dev_buf;
|
string dev_buf;
|
||||||
for (const auto& d : gr->devices) {
|
for (const auto& d : gr->device_set) {
|
||||||
strings::StrAppend(&dev_buf, ",", d.first);
|
strings::StrAppend(&dev_buf, ",", d);
|
||||||
}
|
}
|
||||||
VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
|
VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
|
||||||
<< " group_size=" << gr->group.group_size << " (current"
|
<< " group_size=" << gr->group.group_size << " (current"
|
||||||
<< " devices)=(" << dev_buf << ") (number of"
|
<< " devices)=(" << dev_buf << ") (number of"
|
||||||
<< " devices pending)="
|
<< " devices pending)="
|
||||||
<< (gr->group.group_size - gr->devices.size());
|
<< (gr->group.group_size - gr->device_set.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// If the device already exists, check if the incarnation matches.
|
|
||||||
if (it->second.incarnation() != device.incarnation()) {
|
|
||||||
gr->status = errors::FailedPrecondition(
|
|
||||||
"Device ", device.name(),
|
|
||||||
" current incarnation doesn't match with one in the group. This "
|
|
||||||
"usually means this worker has restarted but the collective "
|
|
||||||
"leader hasn't, or this worker connects to a wrong cluster.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,13 +196,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
|
|||||||
cp->group.runtime_details = gr->group.runtime_details;
|
cp->group.runtime_details = gr->group.runtime_details;
|
||||||
// If the group is not yet complete, queue to wait for it.
|
// If the group is not yet complete, queue to wait for it.
|
||||||
VLOG(2) << "group_size " << gr->group.group_size << " set size "
|
VLOG(2) << "group_size " << gr->group.group_size << " set size "
|
||||||
<< gr->devices.size() << " gr " << gr;
|
<< gr->device_set.size() << " gr " << gr;
|
||||||
|
|
||||||
if (gr->devices.size() < gr->group.group_size) {
|
if (gr->device_set.size() < gr->group.group_size) {
|
||||||
gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
|
gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
CHECK_EQ(gr->devices.size(), gr->group.group_size);
|
CHECK_EQ(gr->device_set.size(), gr->group.group_size);
|
||||||
}
|
}
|
||||||
// At this point, we either have a full group, or an error status. Ensure
|
// At this point, we either have a full group, or an error status. Ensure
|
||||||
// that all callbacks are invoked with the appropriate status.
|
// that all callbacks are invoked with the appropriate status.
|
||||||
@ -501,15 +481,10 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
|
|||||||
{
|
{
|
||||||
mutex_lock gl(gr->mu);
|
mutex_lock gl(gr->mu);
|
||||||
ir->shared.group = gr->group;
|
ir->shared.group = gr->group;
|
||||||
ir->shared.instance.device_names.clear();
|
ir->shared.instance.device_names.assign(gr->device_list.begin(),
|
||||||
ir->shared.instance.task_names.clear();
|
gr->device_list.end());
|
||||||
ir->shared.instance.device_names.reserve(gr->devices.size());
|
ir->shared.instance.task_names.assign(gr->task_list.begin(),
|
||||||
ir->shared.instance.task_names.reserve(gr->devices.size());
|
gr->task_list.end());
|
||||||
for (const auto& item : gr->devices) {
|
|
||||||
ir->shared.instance.device_names.push_back(item.first);
|
|
||||||
ir->shared.instance.task_names.push_back(
|
|
||||||
TaskNameFromDeviceName(item.first));
|
|
||||||
}
|
|
||||||
VLOG(2) << "Initialized names for instance: "
|
VLOG(2) << "Initialized names for instance: "
|
||||||
<< ir->shared.instance.ToString();
|
<< ir->shared.instance.ToString();
|
||||||
}
|
}
|
||||||
@ -707,15 +682,15 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CollectiveParamResolverLocal::CompleteParamsAsync(
|
void CollectiveParamResolverLocal::CompleteParamsAsync(
|
||||||
const DeviceAttributes& device, CollectiveParams* cp,
|
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
|
||||||
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
const StatusCallback& done) {
|
||||||
VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
|
VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
|
||||||
<< cp->ToString();
|
<< cp->ToString();
|
||||||
CompleteGroupLocal(
|
CompleteGroupLocal(
|
||||||
device, cp,
|
device, cp,
|
||||||
[this, device, cp, done](const Status& s, const GroupRec* gr) {
|
[this, device, cp, done](const Status& s, const GroupRec* gr) {
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
|
CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
|
||||||
} else {
|
} else {
|
||||||
done(s);
|
done(s);
|
||||||
}
|
}
|
||||||
|
@ -21,9 +21,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
|
||||||
@ -47,7 +45,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
|
|||||||
|
|
||||||
~CollectiveParamResolverLocal() override {}
|
~CollectiveParamResolverLocal() override {}
|
||||||
|
|
||||||
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
|
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
const StatusCallback& done) override;
|
const StatusCallback& done) override;
|
||||||
|
|
||||||
@ -72,7 +70,10 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
|
|||||||
CollGroupParams group;
|
CollGroupParams group;
|
||||||
mutable mutex mu;
|
mutable mutex mu;
|
||||||
Status status TF_GUARDED_BY(mu);
|
Status status TF_GUARDED_BY(mu);
|
||||||
absl::flat_hash_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
|
std::set<string> device_set TF_GUARDED_BY(mu);
|
||||||
|
std::vector<string> device_list TF_GUARDED_BY(mu);
|
||||||
|
std::set<string> task_set TF_GUARDED_BY(mu);
|
||||||
|
std::vector<string> task_list TF_GUARDED_BY(mu);
|
||||||
std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
|
std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -84,7 +85,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
|
|||||||
// callback.
|
// callback.
|
||||||
typedef std::function<void(const Status& s, const GroupRec* gr)>
|
typedef std::function<void(const Status& s, const GroupRec* gr)>
|
||||||
GroupRecCallback;
|
GroupRecCallback;
|
||||||
void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
|
void CompleteGroupLocal(const string& device, CollectiveParams* cp,
|
||||||
const GroupRecCallback& done)
|
const GroupRecCallback& done)
|
||||||
TF_LOCKS_EXCLUDED(group_mu_);
|
TF_LOCKS_EXCLUDED(group_mu_);
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -87,12 +86,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceAttributes GetDeviceAttributes(const string& device_name) {
|
|
||||||
Device* device = nullptr;
|
|
||||||
TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device));
|
|
||||||
return device->attributes();
|
|
||||||
}
|
|
||||||
|
|
||||||
string task_name_;
|
string task_name_;
|
||||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||||
std::unique_ptr<DeviceResolverLocal> drl_;
|
std::unique_ptr<DeviceResolverLocal> drl_;
|
||||||
@ -194,13 +187,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
|
|||||||
cp->instance.impl_details.subdiv_offsets.push_back(0);
|
cp->instance.impl_details.subdiv_offsets.push_back(0);
|
||||||
cp->is_source = false;
|
cp->is_source = false;
|
||||||
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
||||||
prl_->CompleteParamsAsync(
|
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
|
||||||
GetDeviceAttributes(cp->instance.device_names[0]), cp,
|
nullptr /*CancellationManager*/,
|
||||||
nullptr /*CancellationManager*/,
|
[&statuses, ¬e, i](const Status& s) {
|
||||||
[&statuses, ¬e, i](const Status& s) {
|
statuses[i] = s;
|
||||||
statuses[i] = s;
|
note[i].Notify();
|
||||||
note[i].Notify();
|
});
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
for (int i = 0; i < NUM_DEVS; ++i) {
|
for (int i = 0; i < NUM_DEVS; ++i) {
|
||||||
@ -248,13 +240,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
|
|||||||
CollectiveParams* cp = &cps[i];
|
CollectiveParams* cp = &cps[i];
|
||||||
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
|
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
|
||||||
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
||||||
prl_->CompleteParamsAsync(
|
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
|
||||||
GetDeviceAttributes(cp->instance.device_names[0]), cp,
|
nullptr /*CancellationManager*/,
|
||||||
nullptr /*CancellationManager*/,
|
[&statuses, ¬e, i](const Status& s) {
|
||||||
[&statuses, ¬e, i](const Status& s) {
|
statuses[i] = s;
|
||||||
statuses[i] = s;
|
note[i].Notify();
|
||||||
note[i].Notify();
|
});
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
for (int i = 0; i < NUM_DEVS; ++i) {
|
for (int i = 0; i < NUM_DEVS; ++i) {
|
||||||
@ -287,13 +278,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
|
|||||||
CollectiveParams* cp = &cps[i];
|
CollectiveParams* cp = &cps[i];
|
||||||
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
|
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
|
||||||
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
|
||||||
prl_->CompleteParamsAsync(
|
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
|
||||||
GetDeviceAttributes(cp->instance.device_names[0]), cp,
|
nullptr /*CancellationManager*/,
|
||||||
nullptr /*CancellationManager*/,
|
[&statuses, ¬e, i](const Status& s) {
|
||||||
[&statuses, ¬e, i](const Status& s) {
|
statuses[i] = s;
|
||||||
statuses[i] = s;
|
note[i].Notify();
|
||||||
note[i].Notify();
|
});
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
for (int i = 0; i < NUM_DEVS; ++i) {
|
for (int i = 0; i < NUM_DEVS; ++i) {
|
||||||
@ -336,8 +326,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
|
|||||||
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
||||||
cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
|
cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
|
||||||
/*is_source*/ i == 0);
|
/*is_source*/ i == 0);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
|
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
|
||||||
&cancel_mgr, [&done](const Status& s) {
|
[&done](const Status& s) {
|
||||||
EXPECT_EQ(s.code(), error::ABORTED);
|
EXPECT_EQ(s.code(), error::ABORTED);
|
||||||
EXPECT_EQ(s.error_message(), "__aborted__");
|
EXPECT_EQ(s.error_message(), "__aborted__");
|
||||||
done.DecrementCount();
|
done.DecrementCount();
|
||||||
@ -365,8 +355,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
|
|||||||
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
||||||
cp[i] = MakeCollectiveParams(group_key, instance_key,
|
cp[i] = MakeCollectiveParams(group_key, instance_key,
|
||||||
/*is_source*/ i == 0);
|
/*is_source*/ i == 0);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
|
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
|
||||||
&cancel_mgr, [&done](const Status& s) {
|
[&done](const Status& s) {
|
||||||
EXPECT_EQ(s.code(), error::OK);
|
EXPECT_EQ(s.code(), error::OK);
|
||||||
done.DecrementCount();
|
done.DecrementCount();
|
||||||
});
|
});
|
||||||
@ -383,13 +373,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
|
|||||||
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
||||||
cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
|
cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
|
||||||
/*is_source*/ i == 0);
|
/*is_source*/ i == 0);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
|
prl_->CompleteParamsAsync(
|
||||||
&cancel_mgr, [&done](const Status& s) {
|
device, &cp[i], &cancel_mgr, [&done](const Status& s) {
|
||||||
EXPECT_EQ(s.code(), error::ABORTED);
|
EXPECT_EQ(s.code(), error::ABORTED);
|
||||||
EXPECT_EQ(s.error_message(),
|
EXPECT_EQ(s.error_message(), "__aborted__");
|
||||||
"__aborted__");
|
done.DecrementCount();
|
||||||
done.DecrementCount();
|
});
|
||||||
});
|
|
||||||
start.DecrementCount();
|
start.DecrementCount();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -413,8 +402,8 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
|
|||||||
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
|
||||||
cp[i] = MakeCollectiveParams(group_key, instance_key,
|
cp[i] = MakeCollectiveParams(group_key, instance_key,
|
||||||
/*is_source*/ i == 0);
|
/*is_source*/ i == 0);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
|
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
|
||||||
&cancel_mgr, [&done](const Status& s) {
|
[&done](const Status& s) {
|
||||||
EXPECT_EQ(s.code(), error::OK);
|
EXPECT_EQ(s.code(), error::OK);
|
||||||
done.DecrementCount();
|
done.DecrementCount();
|
||||||
});
|
});
|
||||||
@ -429,7 +418,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
|
|||||||
Notification done;
|
Notification done;
|
||||||
auto cp = MakeCollectiveParams(group_key, instance_key,
|
auto cp = MakeCollectiveParams(group_key, instance_key,
|
||||||
/*is_source*/ true);
|
/*is_source*/ true);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
|
prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
|
||||||
[&done](const Status& s) {
|
[&done](const Status& s) {
|
||||||
EXPECT_EQ(s.code(), error::ABORTED);
|
EXPECT_EQ(s.code(), error::ABORTED);
|
||||||
EXPECT_EQ(s.error_message(), "__aborted__");
|
EXPECT_EQ(s.error_message(), "__aborted__");
|
||||||
@ -468,8 +457,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) {
|
|||||||
auto cp =
|
auto cp =
|
||||||
MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
|
MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
|
||||||
/*is_source*/ i == 0);
|
/*is_source*/ i == 0);
|
||||||
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
|
prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
|
||||||
&cancel_mgr,
|
|
||||||
[&status, &n](const Status& s) {
|
[&status, &n](const Status& s) {
|
||||||
status = s;
|
status = s;
|
||||||
n.Notify();
|
n.Notify();
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -36,7 +35,7 @@ class TestCollectiveExecutor : public CollectiveExecutor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
class TestParamResolver : public ParamResolverInterface {
|
class TestParamResolver : public ParamResolverInterface {
|
||||||
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
|
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
const StatusCallback& done) override {
|
const StatusCallback& done) override {
|
||||||
done(errors::Internal("Unimplemented"));
|
done(errors::Internal("Unimplemented"));
|
||||||
|
@ -571,9 +571,7 @@ cc_library(
|
|||||||
":device_resolver_distributed",
|
":device_resolver_distributed",
|
||||||
":worker_cache",
|
":worker_cache",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:errors",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -608,7 +606,6 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/kernels:collective_ops",
|
"//tensorflow/core/kernels:collective_ops",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,18 +18,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
|
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
|
||||||
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class CompleteGroupCall : public CancellableCall {
|
class CompleteGroupCall : public CancellableCall {
|
||||||
public:
|
public:
|
||||||
CompleteGroupCall(const CollGroupParams& group,
|
CompleteGroupCall(const CollGroupParams& group, const string& device_name,
|
||||||
const DeviceAttributes& device,
|
|
||||||
const CollectiveType& collective_type,
|
const CollectiveType& collective_type,
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
const string& remote_worker, WorkerCacheInterface* wc)
|
const string& remote_worker, WorkerCacheInterface* wc)
|
||||||
@ -37,7 +33,7 @@ class CompleteGroupCall : public CancellableCall {
|
|||||||
req_.set_group_key(group.group_key);
|
req_.set_group_key(group.group_key);
|
||||||
req_.set_group_size(group.group_size);
|
req_.set_group_size(group.group_size);
|
||||||
req_.set_device_type(group.device_type.type_string());
|
req_.set_device_type(group.device_type.type_string());
|
||||||
*req_.mutable_device_attributes() = device;
|
req_.add_device_name(device_name);
|
||||||
req_.set_collective_type(collective_type);
|
req_.set_collective_type(collective_type);
|
||||||
}
|
}
|
||||||
~CompleteGroupCall() override {}
|
~CompleteGroupCall() override {}
|
||||||
@ -102,16 +98,16 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CollectiveParamResolverDistributed::CompleteParamsAsync(
|
void CollectiveParamResolverDistributed::CompleteParamsAsync(
|
||||||
const DeviceAttributes& device, CollectiveParams* cp,
|
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
|
||||||
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
const StatusCallback& done) {
|
||||||
VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
|
VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
|
||||||
<< ": " << cp->ToString();
|
<< cp->ToString();
|
||||||
CompleteGroupDistributed(device, cp, cancel_mgr,
|
CompleteGroupDistributed(device, cp, cancel_mgr,
|
||||||
[this, device, cp, cancel_mgr, done](
|
[this, device, cp, cancel_mgr, done](
|
||||||
const Status& s, const GroupRec* gr) {
|
const Status& s, const GroupRec* gr) {
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
CompleteInstanceDistributed(
|
CompleteInstanceDistributed(device, gr, cp,
|
||||||
device.name(), gr, cp, cancel_mgr, done);
|
cancel_mgr, done);
|
||||||
} else {
|
} else {
|
||||||
done(s);
|
done(s);
|
||||||
}
|
}
|
||||||
@ -121,28 +117,28 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync(
|
|||||||
void CollectiveParamResolverDistributed::CompleteGroupAsync(
|
void CollectiveParamResolverDistributed::CompleteGroupAsync(
|
||||||
const CompleteGroupRequest* request, CompleteGroupResponse* response,
|
const CompleteGroupRequest* request, CompleteGroupResponse* response,
|
||||||
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
||||||
if (!request->has_device_attributes()) {
|
|
||||||
done(errors::Internal(
|
|
||||||
"CompleteGroupRequest device_attributes is not set. Make sure you're "
|
|
||||||
"running the same version of Tensorflow on all workers."));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
CollectiveParams cp;
|
CollectiveParams cp;
|
||||||
cp.group.group_key = request->group_key();
|
cp.group.group_key = request->group_key();
|
||||||
cp.group.group_size = request->group_size();
|
cp.group.group_size = request->group_size();
|
||||||
cp.group.device_type = DeviceType(request->device_type());
|
cp.group.device_type = DeviceType(request->device_type());
|
||||||
|
for (const string& dn : request->device_name()) {
|
||||||
|
cp.instance.device_names.push_back(dn);
|
||||||
|
}
|
||||||
cp.instance.type = CollectiveType(request->collective_type());
|
cp.instance.type = CollectiveType(request->collective_type());
|
||||||
CompleteGroupDistributed(
|
CompleteGroupDistributed(
|
||||||
request->device_attributes(), &cp, cancel_mgr,
|
cp.instance.device_names[0], &cp, cancel_mgr,
|
||||||
[response, done](const Status& s, const GroupRec* gr) {
|
[response, done](const Status& s, const GroupRec* gr) {
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
mutex_lock l(gr->mu);
|
mutex_lock l(gr->mu);
|
||||||
response->set_group_key(gr->group.group_key);
|
response->set_group_key(gr->group.group_key);
|
||||||
response->set_group_size(gr->group.group_size);
|
response->set_group_size(gr->group.group_size);
|
||||||
response->set_device_type(gr->group.device_type.type_string());
|
response->set_device_type(gr->group.device_type.type_string());
|
||||||
response->set_num_tasks(gr->group.num_tasks);
|
response->set_num_tasks(gr->task_set.size());
|
||||||
for (const auto& item : gr->devices) {
|
for (const string& dn : gr->device_list) {
|
||||||
*response->add_device_attributes() = item.second;
|
response->add_device_name(dn);
|
||||||
|
}
|
||||||
|
for (const string& tn : gr->task_list) {
|
||||||
|
response->add_task_name(tn);
|
||||||
}
|
}
|
||||||
response->set_communicator_key(
|
response->set_communicator_key(
|
||||||
gr->group.runtime_details.communicator_key);
|
gr->group.runtime_details.communicator_key);
|
||||||
@ -156,22 +152,6 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync(
|
|||||||
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
|
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
|
||||||
const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
|
const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
|
||||||
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
CancellationManager* cancel_mgr, const StatusCallback& done) {
|
||||||
GroupRec* gr = GetCachedGroup(request->group_key());
|
|
||||||
if (gr == nullptr) {
|
|
||||||
done(errors::FailedPrecondition(
|
|
||||||
"group ", request->group_key(),
|
|
||||||
" not found. This normally means the server has restarted"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
{
|
|
||||||
mutex_lock l(gr->mu);
|
|
||||||
if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
|
|
||||||
done(errors::FailedPrecondition(
|
|
||||||
"group ", request->group_key(),
|
|
||||||
" failed to resolve. This normally means the server has restarted"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
CollectiveParams* cp = new CollectiveParams;
|
CollectiveParams* cp = new CollectiveParams;
|
||||||
cp->name = request->name();
|
cp->name = request->name();
|
||||||
cp->group.group_key = request->group_key();
|
cp->group.group_key = request->group_key();
|
||||||
@ -184,44 +164,56 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
|
|||||||
for (int32 offset : request->subdiv_offset()) {
|
for (int32 offset : request->subdiv_offset()) {
|
||||||
cp->instance.impl_details.subdiv_offsets.push_back(offset);
|
cp->instance.impl_details.subdiv_offsets.push_back(offset);
|
||||||
}
|
}
|
||||||
StatusCallback done_and_cleanup = [cp, done](const Status& s) {
|
string* device = new string(request->device());
|
||||||
|
VLOG(1) << "New cp " << cp << " for device " << *device << " : "
|
||||||
|
<< cp->ToString();
|
||||||
|
StatusCallback done_and_cleanup = [cp, device, done](const Status& s) {
|
||||||
done(s);
|
done(s);
|
||||||
delete cp;
|
delete cp;
|
||||||
|
delete device;
|
||||||
};
|
};
|
||||||
CompleteInstanceDistributed(
|
// Start by completing the group.
|
||||||
request->device(), gr, cp, cancel_mgr,
|
CompleteGroupDistributed(
|
||||||
[this, gr, cp, response, done_and_cleanup](const Status& ci_status) {
|
*device, cp, cancel_mgr,
|
||||||
if (ci_status.ok()) {
|
[this, cp, device, response, cancel_mgr, done_and_cleanup](
|
||||||
// Now source_rank should be known, so
|
const Status& cg_status, const GroupRec* gr) {
|
||||||
// retrieve it.
|
if (cg_status.ok()) {
|
||||||
FindInstanceRec(
|
// Then complete the instance.
|
||||||
gr, cp,
|
CompleteInstanceDistributed(
|
||||||
[cp, response, done_and_cleanup](const Status& fi_status,
|
*device, gr, cp, cancel_mgr,
|
||||||
InstanceRec* ir) {
|
[this, gr, cp, response,
|
||||||
if (fi_status.ok()) {
|
done_and_cleanup](const Status& ci_status) {
|
||||||
mutex_lock l(ir->out_mu);
|
if (ci_status.ok()) {
|
||||||
ir->WaitForOutMu(l);
|
// Now source_rank should be known, so
|
||||||
response->set_instance_key(cp->instance.instance_key);
|
// retrieve it.
|
||||||
response->set_source_rank(ir->source_rank);
|
FindInstanceRec(
|
||||||
done_and_cleanup(fi_status);
|
gr, cp,
|
||||||
|
[cp, response, done_and_cleanup](const Status& fi_status,
|
||||||
|
InstanceRec* ir) {
|
||||||
|
if (fi_status.ok()) {
|
||||||
|
mutex_lock l(ir->out_mu);
|
||||||
|
ir->WaitForOutMu(l);
|
||||||
|
response->set_instance_key(cp->instance.instance_key);
|
||||||
|
response->set_source_rank(ir->source_rank);
|
||||||
|
done_and_cleanup(fi_status);
|
||||||
|
} else {
|
||||||
|
done_and_cleanup(fi_status);
|
||||||
|
}
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
done_and_cleanup(fi_status);
|
done_and_cleanup(ci_status);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
done_and_cleanup(ci_status);
|
done_and_cleanup(cg_status);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
CollectiveParamResolverDistributed::GroupRec*
|
bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
|
||||||
CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
|
|
||||||
mutex_lock l(group_mu_);
|
mutex_lock l(group_mu_);
|
||||||
auto it = group_table_.find(group_key);
|
const auto& it = group_table_.find(group_key);
|
||||||
if (it == group_table_.end()) {
|
return it != group_table_.end();
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return it->second.get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
||||||
@ -234,19 +226,26 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
|||||||
gr->group.group_key = resp.group_key();
|
gr->group.group_key = resp.group_key();
|
||||||
gr->group.group_size = resp.group_size();
|
gr->group.group_size = resp.group_size();
|
||||||
gr->group.num_tasks = resp.num_tasks();
|
gr->group.num_tasks = resp.num_tasks();
|
||||||
if (resp.device_attributes().empty()) {
|
if (resp.device_name_size() != gr->group.group_size) {
|
||||||
return errors::Internal(
|
|
||||||
"CompleteGroupResponse device_attributes is empty. Make sure you're "
|
|
||||||
"running the same version of Tensorflow on all workers.");
|
|
||||||
}
|
|
||||||
if (resp.device_attributes_size() != gr->group.group_size) {
|
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"CompleteGroupResponse group_size doesn't match device_name list");
|
"CompleteGroupResponse group_size doesn't match device_name list");
|
||||||
}
|
}
|
||||||
for (const DeviceAttributes& device : resp.device_attributes()) {
|
for (const string& dn : resp.device_name()) {
|
||||||
gr->devices[device.name()] = device;
|
gr->device_set.insert(dn);
|
||||||
|
gr->device_list.push_back(dn);
|
||||||
}
|
}
|
||||||
|
if (resp.task_name_size() != gr->group.group_size) {
|
||||||
|
return errors::Internal(
|
||||||
|
"CompleteGroupResponse group_size doesn't match task_name list");
|
||||||
|
}
|
||||||
|
for (const string& tn : resp.task_name()) {
|
||||||
|
gr->task_list.push_back(tn);
|
||||||
|
gr->task_set.insert(tn);
|
||||||
|
}
|
||||||
|
CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
|
||||||
gr->group.runtime_details.communicator_key = resp.communicator_key();
|
gr->group.runtime_details.communicator_key = resp.communicator_key();
|
||||||
|
VLOG(2) << "Group communicator_key="
|
||||||
|
<< absl::CEscape(gr->group.runtime_details.communicator_key);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Group membership should never change. Once a record is in group_table_
|
// Group membership should never change. Once a record is in group_table_
|
||||||
@ -274,15 +273,14 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
|
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
|
||||||
const DeviceAttributes& device, CollectiveParams* cp,
|
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
|
||||||
CancellationManager* cancel_mgr, const GroupRecCallback& done) {
|
const GroupRecCallback& done) {
|
||||||
VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
|
VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
|
||||||
<< " dev: " << device.name()
|
<< " dev: " << device << " is_leader=" << (group_leader_.empty());
|
||||||
<< " is_leader=" << (group_leader_.empty());
|
|
||||||
if (group_leader_.empty()) {
|
if (group_leader_.empty()) {
|
||||||
// This is the group leader, so resolution is local.
|
// This is the group leader, so resolution is local.
|
||||||
return CompleteGroupLocal(device, cp, done);
|
return CompleteGroupLocal(device, cp, done);
|
||||||
} else if (GetCachedGroup(cp->group.group_key) == nullptr) {
|
} else if (!GroupIsCached(cp->group.group_key)) {
|
||||||
// Need to update Group cache from the leader.
|
// Need to update Group cache from the leader.
|
||||||
CompleteGroupCall* call =
|
CompleteGroupCall* call =
|
||||||
new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
|
new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
class ConfigProto;
|
class ConfigProto;
|
||||||
@ -32,7 +31,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
|
|||||||
WorkerCacheInterface* worker_cache,
|
WorkerCacheInterface* worker_cache,
|
||||||
const string& task_name);
|
const string& task_name);
|
||||||
|
|
||||||
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
|
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
const StatusCallback& done) override;
|
const StatusCallback& done) override;
|
||||||
|
|
||||||
@ -47,9 +46,9 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
|
|||||||
const StatusCallback& done) override;
|
const StatusCallback& done) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Returns the cached group iff there's an entry for this group_key in the
|
// Returns true iff there's an entry for this group_key in the
|
||||||
// local group_table_; returns nullptr otherwise.
|
// local group_table_.
|
||||||
GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
|
bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
|
||||||
|
|
||||||
// Updates group_table_ with contents of resp.
|
// Updates group_table_ with contents of resp.
|
||||||
Status UpdateGroupCache(const CompleteGroupResponse& resp)
|
Status UpdateGroupCache(const CompleteGroupResponse& resp)
|
||||||
@ -60,8 +59,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
|
|||||||
//
|
//
|
||||||
// Semantics are like those of CompleteGroupLocal but will make a
|
// Semantics are like those of CompleteGroupLocal but will make a
|
||||||
// remote call to the group leader if necessary.
|
// remote call to the group leader if necessary.
|
||||||
void CompleteGroupDistributed(const DeviceAttributes& device,
|
void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
|
||||||
CollectiveParams* cp,
|
|
||||||
CancellationManager* cancel_mgr,
|
CancellationManager* cancel_mgr,
|
||||||
const GroupRecCallback& done);
|
const GroupRecCallback& done);
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
|
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
||||||
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
||||||
@ -24,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/random.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
@ -43,7 +41,6 @@ static std::unique_ptr<Device> NewDevice(const string& type,
|
|||||||
attr.set_name(name);
|
attr.set_name(name);
|
||||||
attr.set_device_type(type);
|
attr.set_device_type(type);
|
||||||
attr.mutable_locality()->set_numa_node(3); // a non-default value
|
attr.mutable_locality()->set_numa_node(3); // a non-default value
|
||||||
attr.set_incarnation(random::New64());
|
|
||||||
return absl::make_unique<FakeDevice>(attr);
|
return absl::make_unique<FakeDevice>(attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,110 +125,127 @@ class FakeCache : public TestWorkerCache {
|
|||||||
|
|
||||||
class DeviceResDistTest : public ::testing::Test {
|
class DeviceResDistTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void DefineWorkers(int num_workers, int num_devices,
|
DeviceResDistTest() {}
|
||||||
const string& device_type, bool nccl) {
|
|
||||||
for (int w = 0; w < num_workers; ++w) {
|
~DeviceResDistTest() override {
|
||||||
string name = strings::StrCat("/job:worker/replica:0/task:", w);
|
for (DeviceMgr* dm : device_mgrs_) {
|
||||||
DefineWorker(name, device_type, num_devices, nccl);
|
delete dm;
|
||||||
|
}
|
||||||
|
for (auto it : dev_resolvers_) {
|
||||||
|
delete it.second;
|
||||||
|
}
|
||||||
|
for (auto it : cp_resolvers_) {
|
||||||
|
delete it.second;
|
||||||
|
}
|
||||||
|
for (FakeWorker* w : workers_) {
|
||||||
|
delete w;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DefineWorker(const string& worker_name, const string& device_type,
|
void DefineWorkers(int num_workers, int num_devices,
|
||||||
int num_devices, bool nccl) {
|
const string& device_type, bool nccl) {
|
||||||
ConfigProto config;
|
ConfigProto config;
|
||||||
config.mutable_experimental()->set_collective_group_leader(
|
for (int w = 0; w < num_workers; ++w) {
|
||||||
"/job:worker/replica:0/task:0");
|
string name = strings::StrCat("/job:worker/replica:0/task:", w);
|
||||||
config.mutable_experimental()->set_collective_nccl(nccl);
|
if (w == 0) {
|
||||||
|
config.mutable_experimental()->set_collective_group_leader(name);
|
||||||
|
if (nccl) {
|
||||||
|
config.mutable_experimental()->set_collective_nccl(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DefineWorker(config, name, device_type, num_devices);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DefineWorker(const ConfigProto& config, const string& worker_name,
|
||||||
|
const string& device_type, int num_devices) {
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
for (int i = 0; i < num_devices; ++i) {
|
for (int i = 0; i < num_devices; ++i) {
|
||||||
devices.push_back(NewDevice(
|
devices.push_back(NewDevice(
|
||||||
device_type,
|
device_type,
|
||||||
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
||||||
}
|
}
|
||||||
device_mgrs_[worker_name] =
|
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||||
absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
device_mgrs_.push_back(dev_mgr);
|
||||||
std::vector<string>* dv = &dev_by_task_[worker_name];
|
std::vector<string>* dv = &dev_by_task_[worker_name];
|
||||||
dv->clear();
|
for (auto* d : dev_mgr->ListDevices()) {
|
||||||
for (auto* d : device_mgrs_[worker_name]->ListDevices()) {
|
|
||||||
dv->push_back(d->name());
|
dv->push_back(d->name());
|
||||||
}
|
}
|
||||||
dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
|
DeviceResolverDistributed* dev_res =
|
||||||
device_mgrs_[worker_name].get(), &wc_, worker_name);
|
new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
|
||||||
cp_resolvers_[worker_name] =
|
dev_resolvers_[worker_name] = dev_res;
|
||||||
absl::make_unique<CollectiveParamResolverDistributed>(
|
CollectiveParamResolverDistributed* cp_res =
|
||||||
config, device_mgrs_[worker_name].get(),
|
new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
|
||||||
dev_resolvers_[worker_name].get(), &wc_, worker_name);
|
worker_name);
|
||||||
workers_[worker_name] = absl::make_unique<FakeWorker>(
|
cp_resolvers_[worker_name] = cp_res;
|
||||||
worker_name, device_mgrs_[worker_name].get(),
|
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
|
||||||
cp_resolvers_[worker_name].get());
|
workers_.push_back(fw);
|
||||||
wc_.AddWorker(worker_name, workers_[worker_name].get());
|
wc_.AddWorker(worker_name, fw);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DefineCollectiveParams(int num_workers, int num_devices,
|
void DefineCollectiveParams(int num_workers, int num_devices) {
|
||||||
const string& device_type) {
|
|
||||||
for (int wi = 0; wi < num_workers; ++wi) {
|
|
||||||
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
|
|
||||||
for (int di = 0; di < num_devices; ++di) {
|
|
||||||
string device_name =
|
|
||||||
strings::StrCat(task_name, "/device:", device_type, ":", di);
|
|
||||||
cp_[device_name] =
|
|
||||||
CreateCollectiveParams(num_workers, num_devices, device_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
|
|
||||||
const string& device_type) {
|
|
||||||
const int kGroupKey = 5;
|
const int kGroupKey = 5;
|
||||||
const int kInstanceKey = 3;
|
const int kInstanceKey = 3;
|
||||||
CollectiveParams cp;
|
|
||||||
cp.group.group_key = kGroupKey;
|
|
||||||
cp.group.group_size = num_workers * num_devices;
|
|
||||||
cp.group.device_type = DeviceType(device_type);
|
|
||||||
cp.group.num_tasks = num_workers;
|
|
||||||
cp.instance.instance_key = kInstanceKey;
|
|
||||||
cp.instance.type = REDUCTION_COLLECTIVE;
|
|
||||||
-    cp.instance.data_type = DT_FLOAT;
-    cp.instance.shape = TensorShape({64});
-    cp.instance.impl_details.subdiv_offsets.push_back(0);
-    return cp;
-  }
-
-  void IssueRequests(int num_workers, int num_devices) {
-    {
-      mutex_lock l(mu_);
-      num_done_ = 0;
-    }
-    int group_size = num_workers * num_devices;
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
         string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-        IssueRequest(task_name, device_name, group_size);
+        cp_.push_back(CollectiveParams());
+        CollectiveParams& cp = cp_.back();
+        cp.group.group_key = kGroupKey;
+        cp.group.group_size = num_workers * num_devices;
+        cp.group.device_type = DEVICE_CPU;
+        cp.group.num_tasks = num_workers;
+        cp.instance.instance_key = kInstanceKey;
+        cp.instance.type = REDUCTION_COLLECTIVE;
+        cp.instance.data_type = DT_FLOAT;
+        cp.instance.shape = TensorShape({64});
+        cp.instance.impl_details.subdiv_offsets.push_back(0);
       }
     }
   }
 
-  void IssueRequest(const string& task_name, const string& device_name,
-                    int group_size) {
-    Device* device = nullptr;
-    TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
-    CollectiveParams* cp = &cp_[device_name];
-    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
+  void IssueRequests(int num_workers, int num_devices) {
+    const int device_count = num_workers * num_devices;
+    {
+      mutex_lock l(mu_);
+      num_done_ = 0;
+    }
+    cp_.resize(device_count);
+    status_.resize(device_count);
+    int idx = 0;
+    for (int wi = 0; wi < num_workers; ++wi) {
+      for (int di = 0; di < num_devices; ++di) {
+        IssueRequest(num_workers, num_devices, idx);
+        ++idx;
+      }
+    }
+  }
+
+  void IssueRequest(int num_workers, int num_devices, int idx) {
+    int device_count = num_workers * num_devices;
+    int wi = idx / num_devices;
+    int di = idx % num_devices;
+    string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+    string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+    while (idx >= cp_.size()) {
+      status_.resize(idx + 1);
+      cp_.resize(idx + 1);
+    }
+    CollectiveParams* cp = &cp_[idx];
+    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
     CHECK(cp_res);
-    cp_res->CompleteParamsAsync(
-        device->attributes(), cp, &cm_,
-        [this, device_name, group_size](const Status& s) {
-          status_[device_name] = s;
-          {
-            mutex_lock l(mu_);
-            ++num_done_;
-            if (num_done_ == group_size) {
-              done_.notify_all();
-            }
-          }
-        });
+    cp_res->CompleteParamsAsync(device_name, cp, &cm_,
+                                [this, idx, device_count](const Status& s) {
+                                  status_[idx] = s;
+                                  {
+                                    mutex_lock l(mu_);
+                                    ++num_done_;
+                                    if (num_done_ == device_count) {
+                                      done_.notify_all();
+                                    }
+                                  }
+                                });
   }
 
   void ValidateCollectiveParams(int num_workers, int num_devices) {
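The two request paths above differ mainly in how a device is addressed: one keys per-device state by the fully qualified device name, the other by a dense rank derived from the (worker, device) pair. Below is a minimal standalone sketch of that rank arithmetic and of the device-name format the test helpers build with strings::StrCat. It is plain C++, not TensorFlow code; ToRank, FromRank, and DeviceName are illustrative names only.

#include <cassert>
#include <string>

// Illustrative only: maps (worker, device) coordinates to the dense rank used
// by the index-based IssueRequest, and back again.
int ToRank(int wi, int di, int num_devices) { return wi * num_devices + di; }

void FromRank(int idx, int num_devices, int* wi, int* di) {
  *wi = idx / num_devices;  // which worker/task
  *di = idx % num_devices;  // which device on that worker
}

// Hypothetical device-name builder mirroring the StrCat calls in the test.
std::string DeviceName(int wi, int di) {
  return "/job:worker/replica:0/task:" + std::to_string(wi) +
         "/device:CPU:" + std::to_string(di);
}

int main() {
  int wi = 0, di = 0;
  FromRank(ToRank(1, 2, /*num_devices=*/3), 3, &wi, &di);
  assert(wi == 1 && di == 2);
  assert(DeviceName(1, 2) == "/job:worker/replica:0/task:1/device:CPU:2");
  return 0;
}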
@@ -245,59 +259,39 @@ class DeviceResDistTest : public ::testing::Test {
     // Verify that all cp_ values get the same set of task and device
     // names, with unique default_rank in the expected order.
     const int dev_count = num_workers * num_devices;
-    string dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
         string device_name = strings::StrCat(task_name, "/device:CPU:", di);
         int idx = wi * num_devices + di;
-        TF_ASSERT_OK(status_[device_name]);
-        EXPECT_EQ(cp_[device_name].default_rank, idx);
-        EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
-        EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
-        EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
+        TF_ASSERT_OK(status_[idx]);
+        EXPECT_EQ(cp_[idx].default_rank, idx);
+        EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
+        EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
+        EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
         if (idx > 0) {
-          EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
-                    cp_[device_name].group.runtime_details.communicator_key);
+          EXPECT_EQ(cp_[0].group.runtime_details.communicator_key,
+                    cp_[idx].group.runtime_details.communicator_key);
           for (int i = 0; i < dev_count; ++i) {
-            EXPECT_EQ(cp_[dev0].instance.device_names[i],
-                      cp_[device_name].instance.device_names[i]);
-            EXPECT_EQ(cp_[dev0].instance.task_names[i],
-                      cp_[device_name].instance.task_names[i]);
+            EXPECT_EQ(cp_[0].instance.device_names[i],
+                      cp_[idx].instance.device_names[i]);
+            EXPECT_EQ(cp_[0].instance.task_names[i],
+                      cp_[idx].instance.task_names[i]);
           }
         }
       }
     }
   }
 
-  void RestartWorker(int worker_idx, int num_workers, int num_devices,
-                     const string& device_type, bool nccl) {
-    string worker_name =
-        strings::StrCat("/job:worker/replica:0/task:", worker_idx);
-    DefineWorker(worker_name, device_type, num_devices, nccl);
-    for (int i = 0; i < num_devices; ++i) {
-      string device_name =
-          strings::StrCat(worker_name, "/device:", device_type, ":", i);
-      cp_[device_name] =
-          CreateCollectiveParams(num_workers, num_devices, device_type);
-      status_.erase(device_name);
-    }
-  }
-
   FakeCache wc_;
   CancellationManager cm_;
-  // Below are keyed by task names.
-  absl::flat_hash_map<string, std::unique_ptr<DeviceMgr>> device_mgrs_;
-  absl::flat_hash_map<string, std::unique_ptr<DeviceResolverDistributed>>
-      dev_resolvers_;
-  absl::flat_hash_map<string,
-                      std::unique_ptr<CollectiveParamResolverDistributed>>
-      cp_resolvers_;
-  absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
-  absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
-  // Below are keyed by device names;
-  absl::flat_hash_map<string, CollectiveParams> cp_;
-  absl::flat_hash_map<string, Status> status_;
+  std::vector<DeviceMgr*> device_mgrs_;
+  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+  std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
+  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::vector<FakeWorker*> workers_;
+  std::vector<CollectiveParams> cp_;
+  std::vector<Status> status_;
   mutex mu_;
   int num_done_ TF_GUARDED_BY(mu_);
   condition_variable done_;
@@ -306,8 +300,8 @@ class DeviceResDistTest : public ::testing::Test {
 TEST_F(DeviceResDistTest, Workers1Devices1) {
   const int num_workers = 1;
   const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineWorkers(num_workers, num_devices, "CPU", false);
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
@@ -315,25 +309,12 @@ TEST_F(DeviceResDistTest, Workers1Devices1) {
 TEST_F(DeviceResDistTest, Workers2Devices2) {
   const int num_workers = 2;
   const int num_devices = 2;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineWorkers(num_workers, num_devices, "CPU", false);
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
 
-TEST_F(DeviceResDistTest, DifferentIncarnation) {
-  const int num_workers = 2;
-  const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
-  IssueRequests(num_workers, num_devices);
-  RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false);
-  const string task_name = "/job:worker/replica:0/task:1";
-  const string device_name = absl::StrCat(task_name, "/device:CPU:0");
-  IssueRequest(task_name, device_name, num_workers * num_devices);
-  EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name]));
-}
-
 #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 namespace {
 // A mock NcclReducer for testing group runtime details initialization with CPU
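The DifferentIncarnation test in the hunk above restarts task 1, re-issues a request for its device, and expects errors::IsFailedPrecondition. The toy, self-contained sketch below shows the kind of incarnation comparison such a check implies; the types and the CheckIncarnation helper are stand-ins, not the actual CollectiveParamResolverDistributed logic.

#include <cstdint>
#include <string>
#include <unordered_map>

// Illustrative stand-ins only.
enum class Code { kOk, kFailedPrecondition };

struct CachedGroupMember {
  std::uint64_t incarnation;  // incarnation recorded when the group formed
};

Code CheckIncarnation(
    const std::unordered_map<std::string, CachedGroupMember>& group,
    const std::string& device, std::uint64_t current_incarnation) {
  auto it = group.find(device);
  if (it != group.end() && it->second.incarnation != current_incarnation) {
    // A restarted worker reports a new incarnation, so a request issued
    // against the cached group parameters is rejected.
    return Code::kFailedPrecondition;
  }
  return Code::kOk;
}

int main() {
  std::unordered_map<std::string, CachedGroupMember> group = {
      {"/job:worker/replica:0/task:1/device:CPU:0", {/*incarnation=*/1}}};
  // A restart bumps the incarnation from 1 to 2, so the check fails.
  return CheckIncarnation(group, "/job:worker/replica:0/task:1/device:CPU:0",
                          2) == Code::kFailedPrecondition
             ? 0
             : 1;
}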
@@ -366,7 +347,7 @@ TEST_F(DeviceResDistTest, Workers4Devices3) {
   const int num_workers = 4;
   const int num_devices = 3;
   DefineWorkers(num_workers, num_devices, "CPU", true);
-  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  DefineCollectiveParams(num_workers, num_devices);
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
@@ -180,8 +180,7 @@ class ParamResolverInterface {
   // Called by each collective op at first execution in order to fill out
   // the CollectiveParams structure with data gathered from the full
   // (maybe distributed) collection of peer nodes.
-  virtual void CompleteParamsAsync(const DeviceAttributes& device,
-                                   CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    const StatusCallback& done) = 0;
 
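For readers comparing the two signatures in this hunk, here is a self-contained sketch of the callback-style interface shape on the string-keyed side. Status, CollectiveParams, CancellationManager, and ToyParamResolver are stand-ins, not the TensorFlow declarations.

#include <functional>
#include <string>

// Illustrative stand-ins for the TensorFlow types.
struct Status { bool ok = true; std::string message; };
struct CollectiveParams {};
class CancellationManager {};
using StatusCallback = std::function<void(const Status&)>;

// Sketch of the resolver shape: the caller identifies itself by its device
// name string and receives the filled-in CollectiveParams via the callback.
class ToyParamResolver {
 public:
  virtual ~ToyParamResolver() = default;
  virtual void CompleteParamsAsync(const std::string& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;
};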
@@ -302,8 +301,7 @@ class CollectiveExecutor : public core::RefCounted {
         "a CollectiveExecutor has not been provided."));
   }
 
-  virtual void CompleteParamsAsync(const DeviceAttributes& device,
-                                   CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    StatusCallback done) {
     done(errors::Internal(
@@ -73,7 +73,7 @@ class CollectiveOpKernel : public AsyncOpKernel {
             << " group " << col_params_.group.group_key << " instance "
             << col_params_.instance.instance_key;
     col_exec->CompleteParamsAsync(
-        c->device()->attributes(), &col_params_, c->cancellation_manager(),
+        c->device()->name(), &col_params_, c->cancellation_manager(),
         [this, c, done](const Status& s) {
           if (s.ok()) {
             col_params_.instance.impl_details.dependencies = dependencies_;
@@ -538,8 +538,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
             << " group " << col_params->group.group_key << " instance "
             << col_params->instance.instance_key;
     col_exec->CompleteParamsAsync(
-        c->device()->attributes(), col_params.get(),
-        c->cancellation_manager(),
+        c->device()->name(), col_params.get(), c->cancellation_manager(),
         [c, done = std::move(done), col_params, col_exec](const Status& s) {
           if (s.ok()) {
             auto actual_done = [c, group_key = col_params->group.group_key,
@@ -545,10 +545,8 @@ message CompleteGroupRequest {
   int32 group_key = 1;
   int32 group_size = 2;
   string device_type = 3;
+  repeated string device_name = 4;
   int32 collective_type = 5;
-  DeviceAttributes device_attributes = 6;
-
-  reserved 4;
 }
 
 // Gives the complete membership of the group identified by group_key.
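Assuming the standard C++ bindings protoc generates for worker.proto, a client would fill the repeated device_name field of CompleteGroupRequest roughly as sketched below. BuildRequest is a hypothetical helper, and the include path assumes the generated header sits next to worker.proto in the usual layout.

#include <string>
#include <vector>

#include "tensorflow/core/protobuf/worker.pb.h"  // assumed location of the generated header

// Sketch only: populate a group-resolution request with per-device names.
tensorflow::CompleteGroupRequest BuildRequest(
    int group_key, int group_size, const std::vector<std::string>& devices) {
  tensorflow::CompleteGroupRequest req;
  req.set_group_key(group_key);
  req.set_group_size(group_size);
  req.set_device_type("CPU");
  for (const std::string& name : devices) {
    req.add_device_name(name);  // repeated string device_name = 4;
  }
  return req;
}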
@@ -557,10 +555,9 @@ message CompleteGroupResponse {
   int32 group_size = 2;
   string device_type = 3;
   int32 num_tasks = 4;  // number of distinct tasks hosting the devices
+  repeated string device_name = 5;
+  repeated string task_name = 6;  // task name prefixes of device_names
   bytes communicator_key = 7;
-  repeated DeviceAttributes device_attributes = 8;
-
-  reserved 5, 6;
 }
 
 // Supplies data about one collective op belonging to the instance identified
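On the response side, device_name and task_name are parallel lists: the field comment notes that each task name is the prefix of the corresponding device name. Assuming the protoc-generated accessors again, a minimal sketch of reading them back:

#include <iostream>

#include "tensorflow/core/protobuf/worker.pb.h"  // assumed location of the generated header

// Sketch only: walk the parallel device_name/task_name lists of a response.
// Assumes both repeated fields have the same length, as the comment implies.
void PrintGroupMembers(const tensorflow::CompleteGroupResponse& resp) {
  for (int i = 0; i < resp.device_name_size(); ++i) {
    std::cout << resp.task_name(i) << " hosts " << resp.device_name(i) << "\n";
  }
}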