Fix two memory leaks and enable asan for C API remote tests.
PiperOrigin-RevId: 321801325 Change-Id: Id579f93e167c9665b4ca740eee160da801ca0694
This commit is contained in:
parent
86ba317d72
commit
5c4b8790ca
tensorflow
c
core/distributed_runtime
@ -525,12 +525,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||||
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
||||||
grpc_server->worker_env()->collective_executor_mgr));
|
grpc_server->worker_env()->collective_executor_mgr.get()));
|
||||||
} else {
|
} else {
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||||
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
||||||
grpc_server->worker_env()->collective_executor_mgr));
|
grpc_server->worker_env()->collective_executor_mgr.get()));
|
||||||
}
|
}
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
#undef LOG_AND_RETURN_IF_ERROR
|
#undef LOG_AND_RETURN_IF_ERROR
|
||||||
|
@ -514,7 +514,6 @@ tf_cuda_cc_test(
|
|||||||
extra_copts = tfe_xla_copts(),
|
extra_copts = tfe_xla_copts(),
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
"noasan", # leaks gRPC server instances
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":c_api",
|
":c_api",
|
||||||
@ -581,7 +580,6 @@ tf_cuda_cc_test(
|
|||||||
extra_copts = tfe_xla_copts(),
|
extra_copts = tfe_xla_copts(),
|
||||||
tags = [
|
tags = [
|
||||||
"no_windows",
|
"no_windows",
|
||||||
"noasan", # leaks gRPC server instances
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":c_api",
|
":c_api",
|
||||||
|
@ -62,7 +62,7 @@ struct WorkerCacheFactoryOptions {
|
|||||||
struct MasterEnv {
|
struct MasterEnv {
|
||||||
Env* env = nullptr;
|
Env* env = nullptr;
|
||||||
|
|
||||||
// Object from which WorkerInterface instances can be obtained.
|
// Object from which WorkerInterface instances can be obtained. Not owned.
|
||||||
WorkerCacheInterface* worker_cache = nullptr;
|
WorkerCacheInterface* worker_cache = nullptr;
|
||||||
|
|
||||||
// The operation definitions to use. Must be filled before use.
|
// The operation definitions to use. Must be filled before use.
|
||||||
@ -93,7 +93,7 @@ struct MasterEnv {
|
|||||||
worker_cache_factory;
|
worker_cache_factory;
|
||||||
|
|
||||||
// Generates per-step CollectiveExecutors and has access to utilities
|
// Generates per-step CollectiveExecutors and has access to utilities
|
||||||
// supporting collective operations.
|
// supporting collective operations. Not owned.
|
||||||
CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
|
CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -267,9 +267,9 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||||||
CHECK_NE(nullptr, worker_cache);
|
CHECK_NE(nullptr, worker_cache);
|
||||||
|
|
||||||
if (opts.collective_mgr_func) {
|
if (opts.collective_mgr_func) {
|
||||||
worker_env_.collective_executor_mgr =
|
worker_env_.collective_executor_mgr.reset(
|
||||||
opts.collective_mgr_func(config, &worker_env_, worker_cache);
|
opts.collective_mgr_func(config, &worker_env_, worker_cache));
|
||||||
if (!worker_env_.collective_executor_mgr) {
|
if (worker_env_.collective_executor_mgr == nullptr) {
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"collective_mgr_func did not return CollectiveExecutorMgr");
|
"collective_mgr_func did not return CollectiveExecutorMgr");
|
||||||
}
|
}
|
||||||
@ -281,9 +281,9 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||||||
new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
|
new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
|
||||||
dev_resolver.get(), worker_cache,
|
dev_resolver.get(), worker_cache,
|
||||||
default_worker_name));
|
default_worker_name));
|
||||||
worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
|
worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
|
||||||
config, worker_env_.device_mgr, std::move(dev_resolver),
|
config, worker_env_.device_mgr, std::move(dev_resolver),
|
||||||
std::move(param_resolver), worker_cache, default_worker_name);
|
std::move(param_resolver), worker_cache, default_worker_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up worker environment.
|
// Set up worker environment.
|
||||||
@ -299,7 +299,8 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||||||
// Finish setting up master environment.
|
// Finish setting up master environment.
|
||||||
master_env_.ops = OpRegistry::Global();
|
master_env_.ops = OpRegistry::Global();
|
||||||
master_env_.worker_cache = worker_cache;
|
master_env_.worker_cache = worker_cache;
|
||||||
master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
|
master_env_.collective_executor_mgr =
|
||||||
|
worker_env_.collective_executor_mgr.get();
|
||||||
StatsPublisherFactory stats_factory = opts.stats_factory;
|
StatsPublisherFactory stats_factory = opts.stats_factory;
|
||||||
master_env_.master_session_factory =
|
master_env_.master_session_factory =
|
||||||
[config, stats_factory](
|
[config, stats_factory](
|
||||||
@ -433,6 +434,8 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
|
|||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Failed to build worker cache with the provided server def.");
|
"Failed to build worker cache with the provided server def.");
|
||||||
}
|
}
|
||||||
|
// Transfer ownership of worker_cache to worker_env_.session_mgr.
|
||||||
|
worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
|
||||||
|
|
||||||
string default_worker_name;
|
string default_worker_name;
|
||||||
string unused;
|
string unused;
|
||||||
@ -447,13 +450,14 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
|
|||||||
new CollectiveParamResolverDistributed(
|
new CollectiveParamResolverDistributed(
|
||||||
server_def_.default_session_config(), worker_env_.device_mgr,
|
server_def_.default_session_config(), worker_env_.device_mgr,
|
||||||
dev_resolver.get(), worker_cache, default_worker_name));
|
dev_resolver.get(), worker_cache, default_worker_name));
|
||||||
worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
|
worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
|
||||||
server_def_.default_session_config(), worker_env_.device_mgr,
|
server_def_.default_session_config(), worker_env_.device_mgr,
|
||||||
std::move(dev_resolver), std::move(param_resolver), worker_cache,
|
std::move(dev_resolver), std::move(param_resolver), worker_cache,
|
||||||
default_worker_name);
|
default_worker_name));
|
||||||
|
|
||||||
master_env_.worker_cache = worker_cache;
|
master_env_.worker_cache = worker_cache;
|
||||||
master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
|
master_env_.collective_executor_mgr =
|
||||||
|
worker_env_.collective_executor_mgr.get();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,6 +144,10 @@ Status SessionMgr::CreateSession(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
|
||||||
|
default_worker_cache_.reset(worker_cache);
|
||||||
|
}
|
||||||
|
|
||||||
Status SessionMgr::UpdateSession(
|
Status SessionMgr::UpdateSession(
|
||||||
const string& session, const ServerDef& server_def,
|
const string& session, const ServerDef& server_def,
|
||||||
const protobuf::RepeatedPtrField<DeviceAttributes>&
|
const protobuf::RepeatedPtrField<DeviceAttributes>&
|
||||||
|
@ -53,6 +53,8 @@ class SessionMgr {
|
|||||||
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
|
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
|
||||||
bool isolate_session_state);
|
bool isolate_session_state);
|
||||||
|
|
||||||
|
void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache);
|
||||||
|
|
||||||
// Updates state (worker cache, devices) of worker session identified by
|
// Updates state (worker cache, devices) of worker session identified by
|
||||||
// session name (`session`) based on a new server_def and set of devices.
|
// session name (`session`) based on a new server_def and set of devices.
|
||||||
Status UpdateSession(const string& session, const ServerDef& server_def,
|
Status UpdateSession(const string& session, const ServerDef& server_def,
|
||||||
|
@ -60,7 +60,7 @@ struct WorkerEnv {
|
|||||||
|
|
||||||
// Generates per-step CollectiveExecutors and has access to utilities
|
// Generates per-step CollectiveExecutors and has access to utilities
|
||||||
// supporting collective operations.
|
// supporting collective operations.
|
||||||
CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
|
std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr;
|
||||||
|
|
||||||
// A pool of threads for scheduling compute work.
|
// A pool of threads for scheduling compute work.
|
||||||
thread::ThreadPool* compute_pool = nullptr;
|
thread::ThreadPool* compute_pool = nullptr;
|
||||||
|
Loading…
Reference in New Issue
Block a user