Automated rollback of commit c00a1c6dc1

PiperOrigin-RevId: 270144377
This commit is contained in:
Brian Zhao 2019-09-19 15:56:01 -07:00 committed by TensorFlower Gardener
parent 297178914f
commit 8c521f81b1
10 changed files with 75 additions and 88 deletions

View File

@ -281,7 +281,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
std::move(server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr(), std::move(remote_mgr)));
worker_session->cluster_flr.get(), std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.

View File

@ -154,13 +154,13 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
{
mutex_lock l(mu_);
if (session_ != nullptr) {
if (session_->worker_name() == session->worker_name()) {
if (session_->worker_name == session->worker_name) {
LOG(INFO) << "Skipping rendezvous re-initialization.";
return Status::OK();
}
Status s = errors::Internal(
"Double init! Worker names would have changed from: ",
session_->worker_name(), " -> ", session->worker_name());
session_->worker_name, " -> ", session->worker_name);
LOG(WARNING) << s;
return s;
}
@ -191,10 +191,10 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
tf_shared_lock l(mu_);
if (!status_.ok()) return status_;
DCHECK(is_initialized_locked());
if (!IsLocalDevice(session_->worker_name(), parsed.src_device)) {
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name());
session_->worker_name);
}
}
// Buffers "val" and "device_context" in local_.
@ -214,15 +214,13 @@ Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
}
sess = session_;
}
if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
sess->worker_name());
if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", sess->worker_name);
}
if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
sess->worker_name());
if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
parsed.FullKey(), " @ ", sess->worker_name);
}
return Status::OK();
}

View File

@ -160,8 +160,8 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
for (auto& function_data : function_data_) {
worker_session_->worker_cache()->ReleaseWorker(function_data.target,
function_data.wi);
worker_session_->worker_cache->ReleaseWorker(function_data.target,
function_data.wi);
}
}
@ -172,11 +172,11 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
<< " (this: " << this << ")";
WorkerInterface* wi =
worker_session_->worker_cache()->GetOrCreateWorker(options.target);
worker_session_->worker_cache->GetOrCreateWorker(options.target);
if (wi == nullptr) {
std::vector<string> workers;
worker_session_->worker_cache()->ListWorkers(&workers);
worker_session_->worker_cache->ListWorkers(&workers);
return errors::InvalidArgument(
"Could not find worker with target: ", options.target,
" Available workers: ", absl::StrJoin(workers, ", "));
@ -199,7 +199,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
}
RegisterGraphRequest req;
req.set_session_handle(worker_session_->session_name());
req.set_session_handle(worker_session_->session_name);
req.set_create_worker_session_called(create_worker_session_called_);
*req.mutable_graph_def() = std::move(gdef);
req.mutable_graph_options()
@ -237,7 +237,7 @@ void ClusterFunctionLibraryRuntime::Run(
}
RunGraphRequest* req = new RunGraphRequest;
req->set_session_handle(worker_session_->session_name());
req->set_session_handle(worker_session_->session_name);
req->set_create_worker_session_called(create_worker_session_called_);
req->set_graph_handle(function_data->graph_handle);
req->set_step_id(opts.step_id);

View File

@ -132,20 +132,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
device_mgr, false, r, GetDefaultCustomKernelCreator(),
worker_session->cluster_flr());
worker_session->cluster_flr.get());
// Ownership will be transferred to the ServerContext, or else in an error
// case ctx will be deleted by this unref.
core::ScopedUnref unref_ctx(ctx);
std::vector<string> remote_workers;
worker_session->worker_cache()->ListWorkers(&remote_workers);
worker_session->worker_cache->ListWorkers(&remote_workers);
remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
worker_session->worker_name()),
worker_session->worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
&remote_eager_workers));
TF_RETURN_IF_ERROR(
worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers));
auto remote_mgr =
absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);

View File

@ -225,10 +225,10 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
}
WorkerSession* sess = session();
// The worker will be released in a subsequent call to
// `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
// `sess->worker_cache->ReleaseWorker()` (if the call has not yet been
// initialized) or `call->ReleaseWorker()` (if it has been initialized).
WorkerInterface* rwi =
sess->worker_cache()->GetOrCreateWorker(call->src_worker_);
sess->worker_cache->GetOrCreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_);
}
@ -239,7 +239,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
}
if (!s.ok()) {
if (rwi != nullptr) {
sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
}
get_call_freelist()->Release(call);
done(s, Args(), recv_args, Tensor{}, false);
@ -257,7 +257,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// NOTE: `*sess` can potentially be deleted before we return from
// `call->done()(...)`, so we must release the worker before calling the
// callback.
call->ReleaseWorker(sess->worker_cache());
call->ReleaseWorker(sess->worker_cache.get());
call->done()(call->status(), Args(), Args(), Tensor(), false);
get_call_freelist()->Release(call);
return;
@ -274,7 +274,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// NOTE: `*session()` can potentially be deleted before we return from
// `call->done()(...)`, so we must release the worker before calling the
// callback.
call->ReleaseWorker(session()->worker_cache());
call->ReleaseWorker(session()->worker_cache.get());
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
get_call_freelist()->Release(call);
Unref();

View File

@ -71,7 +71,7 @@ Status SessionMgr::CreateSession(
string worker_name;
if (server_def.cluster().job().empty()) {
worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
worker_name = legacy_session_->worker_name();
worker_name = legacy_session_->worker_name;
} else {
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
worker_name = WorkerNameFromServerDef(server_def);
@ -162,7 +162,7 @@ Status SessionMgr::WorkerSessionForSessionLocked(
if (it == sessions_.end()) {
return errors::Aborted("Session handle is not found: ", session_handle,
". Possibly this worker (\"",
legacy_session_->worker_name(),
legacy_session_->worker_name,
"\") just restarted.");
} else {
*out_session = it->second;
@ -186,7 +186,7 @@ void SessionMgr::SetLogging(bool active) {
this->is_logging_active_ = active;
// Legacy Session
if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache();
auto* worker_cache = legacy_session_->worker_cache.get();
if (worker_cache) {
worker_cache->SetLogging(active);
}
@ -195,7 +195,7 @@ void SessionMgr::SetLogging(bool active) {
for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get();
if (session) {
auto* worker_cache = session->worker_cache();
auto* worker_cache = session->worker_cache.get();
if (worker_cache) {
worker_cache->SetLogging(active);
}
@ -208,7 +208,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
mutex_lock l(mu_);
// Legacy Session
if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache();
auto* worker_cache = legacy_session_->worker_cache.get();
if (worker_cache) {
auto step_stats = StepStats();
if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
@ -221,7 +221,7 @@ void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get();
if (session) {
auto* worker_cache = session->worker_cache();
auto* worker_cache = session->worker_cache.get();
if (worker_cache) {
auto step_stats = StepStats();
if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
@ -238,7 +238,7 @@ void SessionMgr::ClearLogs() {
mutex_lock l(mu_);
// Legacy Session
if (legacy_session_) {
auto* worker_cache = legacy_session_->worker_cache();
auto* worker_cache = legacy_session_->worker_cache.get();
if (worker_cache) {
worker_cache->ClearLogs();
}
@ -247,7 +247,7 @@ void SessionMgr::ClearLogs() {
for (const auto& session_kv : sessions_) {
auto session = session_kv.second.get();
if (session) {
auto* worker_cache = session->worker_cache();
auto* worker_cache = session->worker_cache.get();
if (worker_cache) {
worker_cache->ClearLogs();
}

View File

@ -102,7 +102,7 @@ TEST_F(SessionMgrTest, CreateSessionClusterDefWorkerName) {
EXPECT_TRUE(device->IsLocal());
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name());
EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
@ -113,7 +113,7 @@ TEST_F(SessionMgrTest, CreateSessionDefaultWorkerName) {
std::shared_ptr<WorkerSession> session;
TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name());
EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}

View File

@ -77,10 +77,10 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
session = env_->session_mgr->LegacySession();
}
if (s.ok()) {
s = session->graph_mgr()->Register(
s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), session.get(),
request->graph_options(), request->debug_options(),
request->collective_graph_key(), session->cluster_flr(),
request->collective_graph_key(), session->cluster_flr.get(),
response->mutable_graph_handle());
}
done(s);
@ -98,7 +98,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
session = env_->session_mgr->LegacySession();
}
if (s.ok()) {
s = session->graph_mgr()->Deregister(request->graph_handle());
s = session->graph_mgr->Deregister(request->graph_handle());
}
done(s);
@ -218,14 +218,14 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
done(errors::Aborted("Call was aborted"));
return;
}
session->graph_mgr()->ExecuteAsync(
session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, session.get(), request->exec_opts(),
collector, response, cm, in,
[this, step_id, response, session, cm, out, token, collector,
profiler_session, opts, done](const Status& status) {
Status s = status;
if (s.ok()) {
s = session->graph_mgr()->RecvOutputs(step_id, out);
s = session->graph_mgr->RecvOutputs(step_id, out);
}
opts->ClearCancelCallback();
@ -311,7 +311,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
token = cancellation_manager_.get_cancellation_token();
cancellation_manager_.RegisterCallback(token,
[cm]() { cm->StartCancel(); });
session->graph_mgr()->ExecuteAsync(
session->graph_mgr->ExecuteAsync(
graph_handle, step_id, session.get(), request->exec_opts(),
nullptr /* collector */, nullptr /* response */, cm, in,
[this, token, step_id, session](Status s) {
@ -320,14 +320,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
});
} else {
// Send the partial run's new inputs.
s = session->graph_mgr()->SendInputs(step_id, in);
s = session->graph_mgr->SendInputs(step_id, in);
if (!s.ok()) {
finish(s);
return;
}
}
session->graph_mgr()->RecvOutputsAsync(
session->graph_mgr->RecvOutputsAsync(
step_id, out, [this, out, request, response, step_id, finish](Status s) {
if (s.ok()) {
// Construct and return the resp.
@ -442,7 +442,7 @@ Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
"RecvTensor expects a different device incarnation: ",
parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
". Your worker job (\"",
env_->session_mgr->LegacySession()->worker_name(),
env_->session_mgr->LegacySession()->worker_name,
"\") was probably restarted. Check your "
"worker job for the reason why it was restarted.");
}

View File

@ -109,11 +109,11 @@ WorkerSession::WorkerSession(
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
: session_name_(session_name),
worker_name_(worker_name),
worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr_(std::move(graph_mgr)),
cluster_flr_(new ClusterFunctionLibraryRuntime(
: session_name(session_name),
worker_name(worker_name),
worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr(std::move(graph_mgr)),
cluster_flr(new ClusterFunctionLibraryRuntime(
this, !session_name.empty(),
remote_device_mgr ? remote_device_mgr.get() : nullptr)),
device_mgr_(std::move(device_mgr)),
@ -142,12 +142,12 @@ WorkerSession::WorkerSession(
std::unique_ptr<WorkerCacheInterface> worker_cache,
DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
: session_name_(session_name),
worker_name_(worker_name),
worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr_(std::move(graph_mgr)),
cluster_flr_(new ClusterFunctionLibraryRuntime(
this, !session_name.empty(), remote_device_mgr.get())),
: session_name(session_name),
worker_name(worker_name),
worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
graph_mgr(std::move(graph_mgr)),
cluster_flr(new ClusterFunctionLibraryRuntime(this, !session_name.empty(),
remote_device_mgr.get())),
device_mgr_(nullptr),
borrowed_device_mgr_(borrowed_device_mgr),
remote_device_mgr_(std::move(remote_device_mgr)) {
@ -159,8 +159,8 @@ WorkerSession::WorkerSession(
}
WorkerSession::~WorkerSession() {
if (graph_mgr_) {
Status s = graph_mgr_->DeregisterAll();
if (graph_mgr) {
Status s = graph_mgr->DeregisterAll();
if (!s.ok()) {
LOG(WARNING) << "Error during worker session deletion: " << s;
}

View File

@ -30,8 +30,16 @@ class GraphMgr;
class WorkerCacheInterface;
// WorkerSession encapsulates all of the state relating to a given session.
class WorkerSession {
public:
struct WorkerSession {
// The name of the session.
const string session_name;
// The name of the worker. E.g., /job:mnist/replica:0/task:1.
const string worker_name;
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache;
// Collection of local devices. These devices are typically
// RenamedDevices in all except the SessionMgr.legacy_session_ and
// sessions created with `isolate_session_state == false`. In the
@ -43,15 +51,13 @@ class WorkerSession {
DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); }
const string& session_name() const { return session_name_; }
const string& worker_name() const { return worker_name_; }
// graph_mgr keeps track of the registered graphs of this session.
//
// Note: graph_mgr must be deleted before rendezvous_mgr!
// Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr;
WorkerCacheInterface* worker_cache() const { return worker_cache_.get(); }
GraphMgr* graph_mgr() const { return graph_mgr_.get(); }
ClusterFunctionLibraryRuntime* cluster_flr() const {
return cluster_flr_.get();
}
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr;
WorkerSession(const string& session_name, const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache,
@ -74,23 +80,6 @@ class WorkerSession {
std::unique_ptr<GraphMgr> graph_mgr,
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
// The name of the session.
const string session_name_;
// The name of the worker. E.g., /job:mnist/replica:0/task:1.
const string worker_name_;
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache_;
// graph_mgr keeps track of the registered graphs of this session.
//
// Note: graph_mgr must be deleted before rendezvous_mgr!
// Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr_;
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
const std::unique_ptr<DeviceMgr> device_mgr_;
DeviceMgr* const borrowed_device_mgr_; // Not owned.
const std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_;