Make sure the rendezvous abort check is finished before triggering the callback.
PiperOrigin-RevId: 313204522 Change-Id: I88f38391d9ee2296fac9a6e86bb9f9d2c477f1c8
This commit is contained in:
parent
2e842db3cc
commit
09af9319d9
@ -462,6 +462,8 @@ tf_cuda_cc_tests(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
|
"//tensorflow/core/distributed_runtime:test_utils",
|
||||||
|
"//tensorflow/core/platform:blocking_counter",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/notification.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -136,7 +137,12 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
|||||||
// Start the main RecvTensor call, checking for an async abort.
|
// Start the main RecvTensor call, checking for an async abort.
|
||||||
void StartRTCall(std::function<void()> recv_done) {
|
void StartRTCall(std::function<void()> recv_done) {
|
||||||
resp_.InitAlloc(dst_device_, alloc_attrs_);
|
resp_.InitAlloc(dst_device_, alloc_attrs_);
|
||||||
auto cb = [this, recv_done = std::move(recv_done)](const Status& s) {
|
auto abort_checked = std::make_shared<Notification>();
|
||||||
|
auto cb = [this, abort_checked,
|
||||||
|
recv_done = std::move(recv_done)](const Status& s) {
|
||||||
|
// Make sure the Rendezvous abort checking is finished before running the
|
||||||
|
// callback, which might destroy the current call object.
|
||||||
|
abort_checked->WaitForNotification();
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
status_.Update(s);
|
status_.Update(s);
|
||||||
@ -158,6 +164,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
|||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
opts_.StartCancel();
|
opts_.StartCancel();
|
||||||
}
|
}
|
||||||
|
// Notify that the abort check has finished.
|
||||||
|
abort_checked->Notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
string src_worker_;
|
string src_worker_;
|
||||||
|
@ -16,13 +16,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
#include "tensorflow/core/framework/control_flow.h"
|
#include "tensorflow/core/framework/control_flow.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/blocking_counter.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/random.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -48,13 +51,34 @@ Rendezvous::ParsedKey MakeKey(const string& s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
// A dummy worker interface implementation that simply triggers the callback
|
||||||
|
// with OK status for RecvTensor request.
|
||||||
|
class DummyWorker : public TestWorkerInterface {
|
||||||
|
public:
|
||||||
|
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
|
||||||
|
TensorResponse* response, StatusCallback done) override {
|
||||||
|
SchedClosure([done = std::move(done)]() {
|
||||||
|
// Simulate a random delay for RPC. This is needed to fill the entire
|
||||||
|
// object buffer in `RpcRecvTensorFreeList` and trigger the destruction of
|
||||||
|
// RPC call objects.
|
||||||
|
const int64 t_us = random::New64() % 100 * 1000;
|
||||||
|
Env::Default()->SleepForMicroseconds(t_us);
|
||||||
|
done(Status::OK());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Fake cache implementation for WorkerEnv.
|
// Fake cache implementation for WorkerEnv.
|
||||||
class DummyWorkerCache : public WorkerCacheInterface {
|
class DummyWorkerCache : public WorkerCacheInterface {
|
||||||
void ListWorkers(std::vector<string>* workers) const override {}
|
void ListWorkers(std::vector<string>* workers) const override {}
|
||||||
void ListWorkersInJob(const string& job_name,
|
void ListWorkersInJob(const string& job_name,
|
||||||
std::vector<string>* workers) const override {}
|
std::vector<string>* workers) const override {}
|
||||||
WorkerInterface* GetOrCreateWorker(const string& target) override {
|
WorkerInterface* GetOrCreateWorker(const string& target) override {
|
||||||
return nullptr;
|
if (dummy_remote_worker_ == nullptr) {
|
||||||
|
// Ownership transferred to WorkerFreeList
|
||||||
|
dummy_remote_worker_ = new DummyWorker;
|
||||||
|
}
|
||||||
|
return dummy_remote_worker_;
|
||||||
}
|
}
|
||||||
Status GetEagerClientCache(
|
Status GetEagerClientCache(
|
||||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||||
@ -66,7 +90,31 @@ class DummyWorkerCache : public WorkerCacheInterface {
|
|||||||
}
|
}
|
||||||
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
|
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
|
||||||
StatusCallback done) override {}
|
StatusCallback done) override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DummyWorker* dummy_remote_worker_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static Device* CreateDevice(const char* type, const char* name) {
|
||||||
|
class FakeDevice : public Device {
|
||||||
|
public:
|
||||||
|
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||||
|
Status Sync() override { return Status::OK(); }
|
||||||
|
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||||
|
};
|
||||||
|
DeviceAttributes attr;
|
||||||
|
attr.set_name(name);
|
||||||
|
attr.set_device_type(type);
|
||||||
|
return new FakeDevice(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
static DeviceMgr* CreateDeviceMgr() {
|
||||||
|
std::unique_ptr<Device> d0(
|
||||||
|
CreateDevice("CPU", "/job:mnist/replica:1/task:2/cpu:1"));
|
||||||
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
|
devices.emplace_back(std::move(d0));
|
||||||
|
return new StaticDeviceMgr(std::move(devices));
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class RpcRendezvousMgrTest : public ::testing::Test {
|
class RpcRendezvousMgrTest : public ::testing::Test {
|
||||||
@ -75,7 +123,7 @@ class RpcRendezvousMgrTest : public ::testing::Test {
|
|||||||
: cache_(new DummyWorkerCache),
|
: cache_(new DummyWorkerCache),
|
||||||
worker_session_("rpc_session", "/job:mnist/replica:1/task:2",
|
worker_session_("rpc_session", "/job:mnist/replica:1/task:2",
|
||||||
std::unique_ptr<WorkerCacheInterface>(cache_),
|
std::unique_ptr<WorkerCacheInterface>(cache_),
|
||||||
std::unique_ptr<DeviceMgr>(),
|
std::unique_ptr<DeviceMgr>(CreateDeviceMgr()),
|
||||||
std::unique_ptr<GraphMgr>(), nullptr),
|
std::unique_ptr<GraphMgr>(), nullptr),
|
||||||
rmgr_(&env) {
|
rmgr_(&env) {
|
||||||
env.env = Env::Default();
|
env.env = Env::Default();
|
||||||
@ -193,6 +241,7 @@ TEST_F(RpcRendezvousMgrTest, CancelAfterReceived) {
|
|||||||
delete cm;
|
delete cm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
class DummyDeviceContext : public DeviceContext {
|
class DummyDeviceContext : public DeviceContext {
|
||||||
public:
|
public:
|
||||||
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
|
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
|
||||||
@ -202,6 +251,7 @@ class DummyDeviceContext : public DeviceContext {
|
|||||||
private:
|
private:
|
||||||
const int stream_id_;
|
const int stream_id_;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||||
DummyDeviceContext* dc = new DummyDeviceContext(123);
|
DummyDeviceContext* dc = new DummyDeviceContext(123);
|
||||||
@ -237,6 +287,59 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
|||||||
dc->Unref();
|
dc->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: Remote Send/Recv is better tested in worker_test.cc
|
TEST_F(RpcRendezvousMgrTest, RemoteRecvOne) {
|
||||||
|
const int64 step_id = 123;
|
||||||
|
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||||
|
"/job:worker/replica:1/task:2/cpu:0", 7890,
|
||||||
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
|
{
|
||||||
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
|
core::ScopedUnref unref(rendez);
|
||||||
|
Rendezvous::Args args;
|
||||||
|
|
||||||
|
Tensor val(DT_STRING);
|
||||||
|
bool val_dead = false;
|
||||||
|
|
||||||
|
TF_ASSERT_OK(rendez->Recv(key, args, &val, &val_dead));
|
||||||
|
}
|
||||||
|
rmgr_.Cleanup(step_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) {
|
||||||
|
const int64 step_id = 123;
|
||||||
|
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||||
|
"/job:worker/replica:1/task:2/cpu:0", 7890,
|
||||||
|
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||||
|
{
|
||||||
|
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||||
|
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||||
|
core::ScopedUnref unref(rendez);
|
||||||
|
Rendezvous::Args args;
|
||||||
|
|
||||||
|
// Send a large number of async RPC requests to fill up the buffer in
|
||||||
|
// `RpcRecvTensorFreeList`, in order to test deleting RPC call objects.
|
||||||
|
int num_requests = 10000;
|
||||||
|
Tensor val(DT_STRING);
|
||||||
|
mutex mu_;
|
||||||
|
Status status = Status::OK();
|
||||||
|
BlockingCounter counter(num_requests);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_requests; i++) {
|
||||||
|
rendez->RecvAsync(
|
||||||
|
key, args,
|
||||||
|
[&mu_, &status, &counter](const Status& s, const Rendezvous::Args&,
|
||||||
|
const Rendezvous::Args&, const Tensor&,
|
||||||
|
const bool) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
status.Update(s);
|
||||||
|
counter.DecrementCount();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
counter.Wait();
|
||||||
|
TF_ASSERT_OK(status);
|
||||||
|
}
|
||||||
|
rmgr_.Cleanup(step_id);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -70,28 +70,28 @@ class TestWorkerInterface : public WorkerInterface {
|
|||||||
void CleanupGraphAsync(const CleanupGraphRequest* request,
|
void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||||
CleanupGraphResponse* response,
|
CleanupGraphResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("CleanupGraphAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CleanupAllAsync(const CleanupAllRequest* request,
|
void CleanupAllAsync(const CleanupAllRequest* request,
|
||||||
CleanupAllResponse* response,
|
CleanupAllResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("CleanupAllAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
|
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
|
||||||
TensorResponse* response, StatusCallback done) override {
|
TensorResponse* response, StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("RecvTensorAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
|
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("LoggingAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TracingAsync(const TracingRequest* request, TracingResponse* response,
|
void TracingAsync(const TracingRequest* request, TracingResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("TracingAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
|
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
|
||||||
@ -103,20 +103,20 @@ class TestWorkerInterface : public WorkerInterface {
|
|||||||
const CompleteGroupRequest* request,
|
const CompleteGroupRequest* request,
|
||||||
CompleteGroupResponse* response,
|
CompleteGroupResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("CompleteGroupAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CompleteInstanceAsync(CallOptions* ops,
|
void CompleteInstanceAsync(CallOptions* ops,
|
||||||
const CompleteInstanceRequest* request,
|
const CompleteInstanceRequest* request,
|
||||||
CompleteInstanceResponse* response,
|
CompleteInstanceResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("CompleteInstanceAsync"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
|
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
|
||||||
GetStepSequenceResponse* response,
|
GetStepSequenceResponse* response,
|
||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
done(errors::Unimplemented("RunGraphAsync"));
|
done(errors::Unimplemented("GetStepSequenceAsync"));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user