Make sure the rendezvous abort check is finished before triggering the callback.
PiperOrigin-RevId: 313204522 Change-Id: I88f38391d9ee2296fac9a6e86bb9f9d2c477f1c8
This commit is contained in:
parent
2e842db3cc
commit
09af9319d9
tensorflow/core/distributed_runtime
@ -462,6 +462,8 @@ tf_cuda_cc_tests(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:test_utils",
|
||||
"//tensorflow/core/platform:blocking_counter",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -136,7 +137,12 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
// Start the main RecvTensor call, checking for an async abort.
|
||||
void StartRTCall(std::function<void()> recv_done) {
|
||||
resp_.InitAlloc(dst_device_, alloc_attrs_);
|
||||
auto cb = [this, recv_done = std::move(recv_done)](const Status& s) {
|
||||
auto abort_checked = std::make_shared<Notification>();
|
||||
auto cb = [this, abort_checked,
|
||||
recv_done = std::move(recv_done)](const Status& s) {
|
||||
// Make sure the Rendezvous abort checking is finished before running the
|
||||
// callback, which might destroy the current call object.
|
||||
abort_checked->WaitForNotification();
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
status_.Update(s);
|
||||
@ -158,6 +164,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
if (!s.ok()) {
|
||||
opts_.StartCancel();
|
||||
}
|
||||
// Notify that the abort check has finished.
|
||||
abort_checked->Notify();
|
||||
}
|
||||
|
||||
string src_worker_;
|
||||
|
@ -16,13 +16,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -48,13 +51,34 @@ Rendezvous::ParsedKey MakeKey(const string& s) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
// A dummy worker interface implementation that simply triggers the callback
|
||||
// with OK status for RecvTensor request.
|
||||
class DummyWorker : public TestWorkerInterface {
|
||||
public:
|
||||
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
|
||||
TensorResponse* response, StatusCallback done) override {
|
||||
SchedClosure([done = std::move(done)]() {
|
||||
// Simulate a random delay for RPC. This is needed to fill the entire
|
||||
// object buffer in `RpcRecvTensorFreeList` and trigger the destruction of
|
||||
// RPC call objects.
|
||||
const int64 t_us = random::New64() % 100 * 1000;
|
||||
Env::Default()->SleepForMicroseconds(t_us);
|
||||
done(Status::OK());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Fake cache implementation for WorkerEnv.
|
||||
class DummyWorkerCache : public WorkerCacheInterface {
|
||||
void ListWorkers(std::vector<string>* workers) const override {}
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const override {}
|
||||
WorkerInterface* GetOrCreateWorker(const string& target) override {
|
||||
return nullptr;
|
||||
if (dummy_remote_worker_ == nullptr) {
|
||||
// Ownership transferred to WorkerFreeList
|
||||
dummy_remote_worker_ = new DummyWorker;
|
||||
}
|
||||
return dummy_remote_worker_;
|
||||
}
|
||||
Status GetEagerClientCache(
|
||||
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
|
||||
@ -66,7 +90,31 @@ class DummyWorkerCache : public WorkerCacheInterface {
|
||||
}
|
||||
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
|
||||
StatusCallback done) override {}
|
||||
|
||||
private:
|
||||
DummyWorker* dummy_remote_worker_ = nullptr;
|
||||
};
|
||||
|
||||
static Device* CreateDevice(const char* type, const char* name) {
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||
Status Sync() override { return Status::OK(); }
|
||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||
};
|
||||
DeviceAttributes attr;
|
||||
attr.set_name(name);
|
||||
attr.set_device_type(type);
|
||||
return new FakeDevice(attr);
|
||||
}
|
||||
|
||||
static DeviceMgr* CreateDeviceMgr() {
|
||||
std::unique_ptr<Device> d0(
|
||||
CreateDevice("CPU", "/job:mnist/replica:1/task:2/cpu:1"));
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
devices.emplace_back(std::move(d0));
|
||||
return new StaticDeviceMgr(std::move(devices));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class RpcRendezvousMgrTest : public ::testing::Test {
|
||||
@ -75,7 +123,7 @@ class RpcRendezvousMgrTest : public ::testing::Test {
|
||||
: cache_(new DummyWorkerCache),
|
||||
worker_session_("rpc_session", "/job:mnist/replica:1/task:2",
|
||||
std::unique_ptr<WorkerCacheInterface>(cache_),
|
||||
std::unique_ptr<DeviceMgr>(),
|
||||
std::unique_ptr<DeviceMgr>(CreateDeviceMgr()),
|
||||
std::unique_ptr<GraphMgr>(), nullptr),
|
||||
rmgr_(&env) {
|
||||
env.env = Env::Default();
|
||||
@ -193,6 +241,7 @@ TEST_F(RpcRendezvousMgrTest, CancelAfterReceived) {
|
||||
delete cm;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class DummyDeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
|
||||
@ -202,6 +251,7 @@ class DummyDeviceContext : public DeviceContext {
|
||||
private:
|
||||
const int stream_id_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
DummyDeviceContext* dc = new DummyDeviceContext(123);
|
||||
@ -237,6 +287,59 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
dc->Unref();
|
||||
}
|
||||
|
||||
// NOTE: Remote Send/Recv is better tested in worker_test.cc
|
||||
TEST_F(RpcRendezvousMgrTest, RemoteRecvOne) {
|
||||
const int64 step_id = 123;
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:worker/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||
core::ScopedUnref unref(rendez);
|
||||
Rendezvous::Args args;
|
||||
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
|
||||
TF_ASSERT_OK(rendez->Recv(key, args, &val, &val_dead));
|
||||
}
|
||||
rmgr_.Cleanup(step_id);
|
||||
}
|
||||
|
||||
TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) {
|
||||
const int64 step_id = 123;
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:worker/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
RemoteRendezvous* rendez = rmgr_.Find(step_id);
|
||||
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
|
||||
core::ScopedUnref unref(rendez);
|
||||
Rendezvous::Args args;
|
||||
|
||||
// Send a large number of async RPC requests to fill up the buffer in
|
||||
// `RpcRecvTensorFreeList`, in order to test deleting RPC call objects.
|
||||
int num_requests = 10000;
|
||||
Tensor val(DT_STRING);
|
||||
mutex mu_;
|
||||
Status status = Status::OK();
|
||||
BlockingCounter counter(num_requests);
|
||||
|
||||
for (int i = 0; i < num_requests; i++) {
|
||||
rendez->RecvAsync(
|
||||
key, args,
|
||||
[&mu_, &status, &counter](const Status& s, const Rendezvous::Args&,
|
||||
const Rendezvous::Args&, const Tensor&,
|
||||
const bool) {
|
||||
mutex_lock l(mu_);
|
||||
status.Update(s);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
TF_ASSERT_OK(status);
|
||||
}
|
||||
rmgr_.Cleanup(step_id);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -70,28 +70,28 @@ class TestWorkerInterface : public WorkerInterface {
|
||||
void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||
CleanupGraphResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("CleanupGraphAsync"));
|
||||
}
|
||||
|
||||
void CleanupAllAsync(const CleanupAllRequest* request,
|
||||
CleanupAllResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("CleanupAllAsync"));
|
||||
}
|
||||
|
||||
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
|
||||
TensorResponse* response, StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("RecvTensorAsync"));
|
||||
}
|
||||
|
||||
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("LoggingAsync"));
|
||||
}
|
||||
|
||||
void TracingAsync(const TracingRequest* request, TracingResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("TracingAsync"));
|
||||
}
|
||||
|
||||
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
|
||||
@ -103,20 +103,20 @@ class TestWorkerInterface : public WorkerInterface {
|
||||
const CompleteGroupRequest* request,
|
||||
CompleteGroupResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("CompleteGroupAsync"));
|
||||
}
|
||||
|
||||
void CompleteInstanceAsync(CallOptions* ops,
|
||||
const CompleteInstanceRequest* request,
|
||||
CompleteInstanceResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("CompleteInstanceAsync"));
|
||||
}
|
||||
|
||||
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
|
||||
GetStepSequenceResponse* response,
|
||||
StatusCallback done) override {
|
||||
done(errors::Unimplemented("RunGraphAsync"));
|
||||
done(errors::Unimplemented("GetStepSequenceAsync"));
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user