Make sure the rendezvous abort check is finished before triggering the callback.

PiperOrigin-RevId: 313204522
Change-Id: I88f38391d9ee2296fac9a6e86bb9f9d2c477f1c8
This commit is contained in:
Haoyu Zhang 2020-05-26 09:21:15 -07:00 committed by TensorFlower Gardener
parent 2e842db3cc
commit 09af9319d9
4 changed files with 125 additions and 12 deletions

View File

@ -462,6 +462,8 @@ tf_cuda_cc_tests(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:test_utils",
"//tensorflow/core/platform:blocking_counter",
],
)

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -136,7 +137,12 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
// Start the main RecvTensor call, checking for an async abort.
void StartRTCall(std::function<void()> recv_done) {
resp_.InitAlloc(dst_device_, alloc_attrs_);
auto cb = [this, recv_done = std::move(recv_done)](const Status& s) {
auto abort_checked = std::make_shared<Notification>();
auto cb = [this, abort_checked,
recv_done = std::move(recv_done)](const Status& s) {
// Make sure the Rendezvous abort checking is finished before running the
// callback, which might destroy the current call object.
abort_checked->WaitForNotification();
if (!s.ok()) {
mutex_lock l(mu_);
status_.Update(s);
@ -158,6 +164,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
if (!s.ok()) {
opts_.StartCancel();
}
// Notify that the abort check has finished.
abort_checked->Notify();
}
string src_worker_;

View File

@ -16,13 +16,16 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -48,13 +51,34 @@ Rendezvous::ParsedKey MakeKey(const string& s) {
}
namespace {
// A dummy worker interface implementation that simply triggers the callback
// with OK status for RecvTensor request.
class DummyWorker : public TestWorkerInterface {
public:
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
TensorResponse* response, StatusCallback done) override {
SchedClosure([done = std::move(done)]() {
// Simulate a random delay for RPC. This is needed to fill the entire
// object buffer in `RpcRecvTensorFreeList` and trigger the destruction of
// RPC call objects.
const int64 t_us = random::New64() % 100 * 1000;
Env::Default()->SleepForMicroseconds(t_us);
done(Status::OK());
});
}
};
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {}
WorkerInterface* GetOrCreateWorker(const string& target) override {
return nullptr;
if (dummy_remote_worker_ == nullptr) {
// Ownership transferred to WorkerFreeList
dummy_remote_worker_ = new DummyWorker;
}
return dummy_remote_worker_;
}
Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
@ -66,7 +90,31 @@ class DummyWorkerCache : public WorkerCacheInterface {
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {}
private:
DummyWorker* dummy_remote_worker_ = nullptr;
};
static Device* CreateDevice(const char* type, const char* name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};
DeviceAttributes attr;
attr.set_name(name);
attr.set_device_type(type);
return new FakeDevice(attr);
}
static DeviceMgr* CreateDeviceMgr() {
std::unique_ptr<Device> d0(
CreateDevice("CPU", "/job:mnist/replica:1/task:2/cpu:1"));
std::vector<std::unique_ptr<Device>> devices;
devices.emplace_back(std::move(d0));
return new StaticDeviceMgr(std::move(devices));
}
} // namespace
class RpcRendezvousMgrTest : public ::testing::Test {
@ -75,7 +123,7 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache),
worker_session_("rpc_session", "/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_),
std::unique_ptr<DeviceMgr>(),
std::unique_ptr<DeviceMgr>(CreateDeviceMgr()),
std::unique_ptr<GraphMgr>(), nullptr),
rmgr_(&env) {
env.env = Env::Default();
@ -193,6 +241,7 @@ TEST_F(RpcRendezvousMgrTest, CancelAfterReceived) {
delete cm;
}
namespace {
class DummyDeviceContext : public DeviceContext {
public:
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
@ -202,6 +251,7 @@ class DummyDeviceContext : public DeviceContext {
private:
const int stream_id_;
};
} // namespace
TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
DummyDeviceContext* dc = new DummyDeviceContext(123);
@ -237,6 +287,59 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
dc->Unref();
}
// NOTE: Remote Send/Recv is better tested in worker_test.cc
TEST_F(RpcRendezvousMgrTest, RemoteRecvOne) {
const int64 step_id = 123;
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:worker/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
Tensor val(DT_STRING);
bool val_dead = false;
TF_ASSERT_OK(rendez->Recv(key, args, &val, &val_dead));
}
rmgr_.Cleanup(step_id);
}
TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) {
const int64 step_id = 123;
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:worker/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
// Send a large number of async RPC requests to fill up the buffer in
// `RpcRecvTensorFreeList`, in order to test deleting RPC call objects.
int num_requests = 10000;
Tensor val(DT_STRING);
mutex mu_;
Status status = Status::OK();
BlockingCounter counter(num_requests);
for (int i = 0; i < num_requests; i++) {
rendez->RecvAsync(
key, args,
[&mu_, &status, &counter](const Status& s, const Rendezvous::Args&,
const Rendezvous::Args&, const Tensor&,
const bool) {
mutex_lock l(mu_);
status.Update(s);
counter.DecrementCount();
});
}
counter.Wait();
TF_ASSERT_OK(status);
}
rmgr_.Cleanup(step_id);
}
} // namespace tensorflow

View File

@ -70,28 +70,28 @@ class TestWorkerInterface : public WorkerInterface {
void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("CleanupGraphAsync"));
}
void CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("CleanupAllAsync"));
}
void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
TensorResponse* response, StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("RecvTensorAsync"));
}
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("LoggingAsync"));
}
void TracingAsync(const TracingRequest* request, TracingResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("TracingAsync"));
}
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
@ -103,20 +103,20 @@ class TestWorkerInterface : public WorkerInterface {
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("CompleteGroupAsync"));
}
void CompleteInstanceAsync(CallOptions* ops,
const CompleteInstanceRequest* request,
CompleteInstanceResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("CompleteInstanceAsync"));
}
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response,
StatusCallback done) override {
done(errors::Unimplemented("RunGraphAsync"));
done(errors::Unimplemented("GetStepSequenceAsync"));
}
};