Pipe ConfigProto through FLR so that it can be accessed by Ops like PartitionedCallOp.

Also pass the ConfigProto through distributed function calls both in the standard
graph registration mode and in the new eager master setup.

The PFLR stores a std::optional<ConfigProto> instead of a pointer, because it may be created with a pointer that would dangle after its creation.  At the same time, we need to know if a ConfigProto was available at creation time, which is why it's a std::optional.  In contrast, the FLR gets a pointer directly because it is given a valid pointer that will outlast it in all cases.

PiperOrigin-RevId: 272763578
This commit is contained in:
Eugene Brevdo 2019-10-03 16:12:13 -07:00 committed by TensorFlower Gardener
parent 11b69dada6
commit 90f01af49a
43 changed files with 185 additions and 118 deletions

View File

@ -447,6 +447,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) { for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da; *base_request.add_cluster_device_attributes() = da;
} }
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers; std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR( LOG_AND_RETURN_IF_ERROR(

View File

@ -82,7 +82,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
FunctionLibraryRuntime* GetFunctionLibraryRuntime() { FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
OptimizerOptions opts; OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), opts); nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
flib_def_.get(), opts);
return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
} }

View File

@ -1196,11 +1196,12 @@ Status EncapsulateSubgraphsPass::Run(
std::unique_ptr<DeviceMgr> device_mgr = std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<StaticDeviceMgr>(std::move(devices)); absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts; const auto* config = &options.session_options->config;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(device_mgr.get(), new ProcessFunctionLibraryRuntime(
options.session_options->env, device_mgr.get(), options.session_options->env,
TF_GRAPH_DEF_VERSION, library, opts)); /*config=*/config, TF_GRAPH_DEF_VERSION, library,
config->graph_options().optimizer_options()));
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0"); pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
if (flr == nullptr) { if (flr == nullptr) {

View File

@ -513,8 +513,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
OptimizerOptions opts; OptimizerOptions opts;
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices)); auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), device_mgr.get(), Env::Default(), /*config=*/nullptr,
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
std::unique_ptr<Graph> graph_out; std::unique_ptr<Graph> graph_out;

View File

@ -246,8 +246,9 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
bool *has_outside_compilation) { bool *has_outside_compilation) {
OptimizerOptions opts; OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, device_mgr_.get(), Env::Default(), /*config=*/nullptr,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); TF_GRAPH_DEF_VERSION, fld, opts,
/*default_thread_pool=*/nullptr);
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
return ExtractOutsideCompilationForFunction( return ExtractOutsideCompilationForFunction(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,

View File

@ -1074,8 +1074,8 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
Status MarkForCompilationPassImpl::FindCompilationCandidates() { Status MarkForCompilationPassImpl::FindCompilationCandidates() {
OptimizerOptions opts; OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION, new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
flib_def_, opts)); TF_GRAPH_DEF_VERSION, flib_def_, opts));
FunctionLibraryRuntime* lib_runtime = FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false); std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);

View File

@ -295,7 +295,7 @@ Status PartiallyDeclusterGraph(Graph* graph,
std::vector<bool> compile_time_const_nodes(graph->num_node_ids()); std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
OptimizerOptions opts; OptimizerOptions opts;
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts); nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
FunctionLibraryRuntime* lib_runtime = FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr, TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,

View File

@ -148,9 +148,9 @@ xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(), new ProcessFunctionLibraryRuntime(
TF_GRAPH_DEF_VERSION, flib_def, nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
OptimizerOptions{})); flib_def, OptimizerOptions{}));
FunctionLibraryRuntime* lib_runtime = FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -73,8 +73,9 @@ class XlaKernelCreatorTest : public ::testing::Test {
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), /*config=*/nullptr,
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
} }

View File

@ -222,10 +222,12 @@ Status FunctionalizeControlFlowForXlaPass::Run(
DumpGraphToFile("functionalize_control_flow_before", *graph, DumpGraphToFile("functionalize_control_flow_before", *graph,
options.flib_def); options.flib_def);
} }
const auto* config = &options.session_options->config;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime( new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, options.session_options->env, /*device_mgr=*/nullptr, options.session_options->env, config,
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); TF_GRAPH_DEF_VERSION, options.flib_def,
config->graph_options().optimizer_options()));
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -55,8 +55,8 @@ void AnalyzeAndVerify(
ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get())); ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get()));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def, nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
OptimizerOptions()); flib_def, OptimizerOptions());
FunctionLibraryRuntime* lib_runtime = FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo, absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,

View File

@ -293,8 +293,8 @@ TEST(CachedFunctionHandles, Basic) {
FunctionLibraryDefinition fld(OpRegistry::Global(), proto); FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime( new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld, /*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr,
OptimizerOptions())); TF_GRAPH_DEF_VERSION, &fld, OptimizerOptions()));
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -517,11 +517,11 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
FunctionDefLibrary{})); FunctionDefLibrary{}));
local_pflr_.reset(new ProcessFunctionLibraryRuntime( local_pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version, &device_mgr_, Env::Default(), /*config=*/nullptr,
local_flib_def_.get(), OptimizerOptions())); options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version, options.flib_def, &device_mgr_, Env::Default(), /*config=*/nullptr,
OptimizerOptions())); options.graph_def_version, options.flib_def, OptimizerOptions()));
local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name());

View File

@ -1292,7 +1292,7 @@ Status DirectSession::CreateExecutors(
? &options_.config.experimental().session_metadata() ? &options_.config.experimental().session_metadata()
: nullptr; : nullptr;
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), options_.env, graph_def_version, device_mgr_.get(), options_.env, &options_.config, graph_def_version,
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first, func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
nullptr, nullptr, session_metadata)); nullptr, nullptr, session_metadata));

View File

@ -80,9 +80,9 @@ EagerContext::EagerContext(
thread_pool_(NewThreadPoolFromSessionOptions(opts)), thread_pool_(NewThreadPoolFromSessionOptions(opts)),
custom_kernel_creator_(custom_kernel_creator), custom_kernel_creator_(custom_kernel_creator),
pflr_(new ProcessFunctionLibraryRuntime( pflr_(new ProcessFunctionLibraryRuntime(
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
opts.config.graph_options().optimizer_options(), thread_pool_.get(), &func_lib_def_, opts.config.graph_options().optimizer_options(),
cluster_flr, custom_kernel_creator_)), thread_pool_.get(), cluster_flr, custom_kernel_creator_)),
log_device_placement_(opts.config.log_device_placement()), log_device_placement_(opts.config.log_device_placement()),
allow_soft_placement_(opts.config.allow_soft_placement()), allow_soft_placement_(opts.config.allow_soft_placement()),
num_active_steps_(0), num_active_steps_(0),
@ -697,9 +697,13 @@ Status EagerContext::StoreCollectiveOpsServer(
} }
} }
const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, local_unowned_device_manager_, env_, /*config=*/config,
{}, thread_pool_.get())); TF_GRAPH_DEF_VERSION, &func_lib_def_,
/*optimizer_options=*/
config ? config->graph_options().optimizer_options() : OptimizerOptions(),
thread_pool_.get()));
// Memory leak! // Memory leak!
if (server_ != nullptr) { if (server_ != nullptr) {
@ -849,9 +853,11 @@ Status EagerContext::SetMasterContextState(
entry.second->ClearError(); entry.second->ClearError();
} }
} }
const auto* config = pflr_->config();
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, local_unowned_device_manager_, env_, config, TF_GRAPH_DEF_VERSION,
{}, thread_pool_.get(), cluster_flr, custom_kernel_creator_)); &func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_));
keep_alive_secs_ = keep_alive_secs; keep_alive_secs_ = keep_alive_secs;
sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2); sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
@ -995,9 +1001,10 @@ Status EagerContext::UpdateRemoteWorker(
} }
SessionOptions options = SessionOptions(); SessionOptions options = SessionOptions();
const auto* config = pflr_->config();
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
worker_session_device_mgr, options.env, TF_GRAPH_DEF_VERSION, worker_session_device_mgr, options.env, config, TF_GRAPH_DEF_VERSION,
FuncLibDef(), options.config.graph_options().optimizer_options(), FuncLibDef(), config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_)); thread_pool_.get(), cluster_flr, custom_kernel_creator_));
return Status::OK(); return Status::OK();
} }

View File

@ -48,8 +48,9 @@ class TestEnv {
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts; OptimizerOptions opts;
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>( pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_, device_mgr_.get(), Env::Default(), /*config=*/nullptr,
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); TF_GRAPH_DEF_VERSION, &flib_def_, opts,
/*default_thread_pool=*/nullptr);
flr_ = pflr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0"); flr_ = pflr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
CHECK(flr_ != nullptr); CHECK(flr_ != nullptr);

View File

@ -45,15 +45,15 @@ class EagerProcessFunctionLibraryRuntime
: public ProcessFunctionLibraryRuntime { : public ProcessFunctionLibraryRuntime {
public: public:
EagerProcessFunctionLibraryRuntime( EagerProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version, const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
const FunctionLibraryDefinition* lib_def, int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options, const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr, thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr, DistributedFunctionLibraryRuntime* parent = nullptr,
const CustomKernelCreator* custom_kernel_creator = nullptr) const CustomKernelCreator* custom_kernel_creator = nullptr)
: ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, : ProcessFunctionLibraryRuntime(
lib_def, optimizer_options, thread_pool, device_mgr, env, config, graph_def_version, lib_def,
parent, custom_kernel_creator) {} optimizer_options, thread_pool, parent, custom_kernel_creator) {}
#if !defined(IS_MOBILE_PLATFORM) #if !defined(IS_MOBILE_PLATFORM)
void Run(const FunctionLibraryRuntime::Options& opts, void Run(const FunctionLibraryRuntime::Options& opts,

View File

@ -195,6 +195,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
const override; const override;
Env* env() override; Env* env() override;
const ConfigProto* const config_proto() override;
Device* device() override; Device* device() override;
const Device* device() const override; const Device* device() const override;
std::function<void(std::function<void()>)>* runner() override; std::function<void(std::function<void()>)>* runner() override;
@ -277,6 +278,10 @@ bool FunctionLibraryRuntimeOverlay::IsStateful(
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); } Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
return base_flr_->config_proto();
}
Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); } Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
const Device* FunctionLibraryRuntimeOverlay::device() const { const Device* FunctionLibraryRuntimeOverlay::device() const {
@ -317,7 +322,8 @@ Status FunctionLibraryRuntimeOverlay::Clone(
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
public: public:
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
const ConfigProto* config, Device* device,
int graph_def_version, int graph_def_version,
const FunctionLibraryDefinition* lib_def, const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool, thread::ThreadPool* default_thread_pool,
@ -361,6 +367,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const DeviceMgr* device_mgr() const override { return device_mgr_; } const DeviceMgr* device_mgr() const override { return device_mgr_; }
Env* env() override { return env_; } Env* env() override { return env_; }
const ConfigProto* const config_proto() override { return config_; }
int graph_def_version() const override { return graph_def_version_; } int graph_def_version() const override { return graph_def_version_; }
string DebugString(Handle h) override; string DebugString(Handle h) override;
@ -376,6 +383,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const DeviceMgr* const device_mgr_; const DeviceMgr* const device_mgr_;
Device* const device_; Device* const device_;
Env* const env_; Env* const env_;
const ConfigProto* const config_;
const int graph_def_version_; const int graph_def_version_;
const FunctionLibraryDefinition* const base_lib_def_; const FunctionLibraryDefinition* const base_lib_def_;
GraphOptimizer optimizer_; GraphOptimizer optimizer_;
@ -442,8 +450,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
}; };
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
const FunctionLibraryDefinition* lib_def, int graph_def_version, const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool, thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options, const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator, const CustomKernelCreator* custom_kernel_creator,
@ -452,6 +460,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
: device_mgr_(dmgr), : device_mgr_(dmgr),
device_(device), device_(device),
env_(env), env_(env),
config_(config),
graph_def_version_(graph_def_version), graph_def_version_(graph_def_version),
base_lib_def_(lib_def), base_lib_def_(lib_def),
optimizer_(optimizer_options), optimizer_(optimizer_options),
@ -1283,14 +1292,15 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
} }
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device, const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def, Device* device, int graph_def_version,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator, const CustomKernelCreator* custom_kernel_creator,
const SessionMetadata* session_metadata, const SessionMetadata* session_metadata,
ProcessFunctionLibraryRuntime* parent) { ProcessFunctionLibraryRuntime* parent) {
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
device_mgr, env, device, graph_def_version, lib_def, thread_pool, device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
optimizer_options, custom_kernel_creator, session_metadata, parent)); optimizer_options, custom_kernel_creator, session_metadata, parent));
} }

View File

@ -59,11 +59,12 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c);
// typically owns the created FunctionLibraryRuntime object. The parent pointer // typically owns the created FunctionLibraryRuntime object. The parent pointer
// is not owned by the FunctionLibraryRuntime object. // is not owned by the FunctionLibraryRuntime object.
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device, const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def, Device* device, int graph_def_version,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator, const CustomKernelCreator* custom_kernel_creator,
const SessionMetadata* sesson_metadata, const SessionMetadata* session_metadata,
ProcessFunctionLibraryRuntime* parent); ProcessFunctionLibraryRuntime* parent);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an // FunctionLibraryRuntime::GetFunctionBody returns a description of an

View File

@ -164,8 +164,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), &options.config,
opts)); TF_GRAPH_DEF_VERSION, lib_def_.get(), opts));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");

View File

@ -64,8 +64,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), /*config=*/nullptr,
opts, default_thread_pool)); TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
} }

View File

@ -65,8 +65,8 @@ Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
} }
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version, const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
const FunctionLibraryDefinition* lib_def, int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options, const OptimizerOptions& optimizer_options,
thread::ThreadPool* default_thread_pool, thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent, DistributedFunctionLibraryRuntime* parent,
@ -74,6 +74,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const SessionMetadata* session_metadata) const SessionMetadata* session_metadata)
: parent_(parent), : parent_(parent),
env_(env), env_(env),
config_(config ? absl::make_optional(*config) : absl::nullopt),
device_mgr_(device_mgr), device_mgr_(device_mgr),
lib_def_(lib_def), lib_def_(lib_def),
default_thread_pool_(default_thread_pool), default_thread_pool_(default_thread_pool),
@ -83,14 +84,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
session_metadata_(session_metadata) { session_metadata_(session_metadata) {
if (device_mgr == nullptr) { if (device_mgr == nullptr) {
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime( (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool, nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
optimizer_options, custom_kernel_creator, session_metadata_, this); graph_def_version, lib_def_, default_thread_pool, optimizer_options,
custom_kernel_creator, session_metadata_, this);
return; return;
} }
for (Device* d : device_mgr->ListDevices()) { for (Device* d : device_mgr->ListDevices()) {
(*flr_map_)[d] = NewFunctionLibraryRuntime( (*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool, device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
optimizer_options, custom_kernel_creator, session_metadata_, this); lib_def_, default_thread_pool, optimizer_options, custom_kernel_creator,
session_metadata_, this);
} }
DeviceMgr const* all_devices = device_mgr_; DeviceMgr const* all_devices = device_mgr_;
@ -1330,9 +1333,9 @@ Status ProcessFunctionLibraryRuntime::Clone(
*out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_); *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
} }
*out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( *out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_, env, graph_def_version, out_lib_def->get(), device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
optimizer_options, default_thread_pool_, parent_, custom_kernel_creator, out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
session_metadata_); custom_kernel_creator, session_metadata_);
return Status::OK(); return Status::OK();
} }

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/platform.h"
// clang-format on // clang-format on
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
@ -41,8 +42,8 @@ class ProcessFunctionLibraryRuntime {
// DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
// (if provided) outlive this object. // (if provided) outlive this object.
ProcessFunctionLibraryRuntime( ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version, const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
const FunctionLibraryDefinition* lib_def, int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options, const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr, thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr, DistributedFunctionLibraryRuntime* parent = nullptr,
@ -159,6 +160,8 @@ class ProcessFunctionLibraryRuntime {
const DeviceSet* device_set() { return &device_set_; } const DeviceSet* device_set() { return &device_set_; }
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const { const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
return lib_def_; return lib_def_;
} }
@ -381,6 +384,7 @@ class ProcessFunctionLibraryRuntime {
mutable mutex mu_; mutable mutex mu_;
Env* const env_; Env* const env_;
const absl::optional<const ConfigProto> config_;
const DeviceMgr* const device_mgr_; const DeviceMgr* const device_mgr_;
DeviceSet device_set_; DeviceSet device_set_;
const FunctionLibraryDefinition* lib_def_; const FunctionLibraryDefinition* lib_def_;

View File

@ -121,8 +121,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts; OptimizerOptions opts;
cluster_flr_.reset(new TestClusterFLR(device_mgr_.get())); cluster_flr_.reset(new TestClusterFLR(device_mgr_.get()));
proc_flr_.reset(new ProcessFunctionLibraryRuntime( proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), /*config=*/nullptr,
opts, nullptr, cluster_flr_.get(), nullptr, session_metadata)); TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
nullptr, session_metadata));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
} }
@ -295,8 +296,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
OptimizerOptions opts; OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr( std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
new ProcessFunctionLibraryRuntime( new ProcessFunctionLibraryRuntime(
nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION, nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr,
lib_def.get(), opts, nullptr, nullptr /* cluster_flr */)); TF_GRAPH_DEF_VERSION, lib_def.get(), opts));
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
EXPECT_NE(flr, nullptr); EXPECT_NE(flr, nullptr);

View File

@ -128,9 +128,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
LOG(INFO) << "Creating " << (request->async() ? "async" : "sync") LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
<< " eager service context with rendezvous_id on host " << " eager service context with rendezvous_id on host "
<< port::Hostname() << " " << worker_session->worker_name(); << port::Hostname() << " " << worker_session->worker_name();
SessionOptions opts;
opts.config = request->server_def().default_session_config();
tensorflow::EagerContext* ctx = new tensorflow::EagerContext( tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
SessionOptions(), opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
device_mgr, false, r, GetDefaultCustomKernelCreator(), device_mgr, false, r, GetDefaultCustomKernelCreator(),
worker_session->cluster_flr()); worker_session->cluster_flr());

View File

@ -234,7 +234,6 @@ TEST_F(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
EnqueueRequest remote_enqueue_request; EnqueueRequest remote_enqueue_request;
remote_enqueue_request.set_context_id(context_id); remote_enqueue_request.set_context_id(context_id);
EnqueueResponse remote_enqueue_response; EnqueueResponse remote_enqueue_response;
@ -400,8 +399,9 @@ TEST_F(EagerServiceImplTest, EagerPFLRTest) {
auto device_mgr = absl::make_unique<StaticDeviceMgr>( auto device_mgr = absl::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1")); DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
auto eager_pflr = absl::make_unique<EagerProcessFunctionLibraryRuntime>( auto eager_pflr = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &func_lib_def, device_mgr.get(), Env::Default(), /*config=*/nullptr,
OptimizerOptions(), nullptr, eager_cluster_flr.get(), nullptr); TF_GRAPH_DEF_VERSION, &func_lib_def, OptimizerOptions(), nullptr,
eager_cluster_flr.get(), nullptr);
tensorflow::FunctionDef fdef = MatMulFunction(); tensorflow::FunctionDef fdef = MatMulFunction();
TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef)); TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));

View File

@ -121,13 +121,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
// //
// "executors" are filled with one executor per device if success and // "executors" are filled with one executor per device if success and
// the caller takes the ownership of returned executors. // the caller takes the ownership of returned executors.
Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, Status GraphMgr::InitItem(
WorkerSession* session, const string& handle, const GraphDef& gdef, WorkerSession* session,
const GraphOptions& graph_options, const GraphOptions& graph_options, const DebugOptions& debug_options,
const DebugOptions& debug_options, const ConfigProto& config_proto, int64 collective_graph_key,
int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
DistributedFunctionLibraryRuntime* cluster_flr,
Item* item) {
item->session = handle; item->session = handle;
item->collective_graph_key = collective_graph_key; item->collective_graph_key = collective_graph_key;
item->lib_def.reset( item->lib_def.reset(
@ -139,9 +137,10 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
// does that below. // does that below.
item->proc_flr.reset(new ProcessFunctionLibraryRuntime( item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_, worker_env_->env, gdef.versions().producer(), device_mgr_, worker_env_->env, /*config=*/&config_proto,
item->lib_def.get(), graph_options.optimizer_options(), gdef.versions().producer(), item->lib_def.get(),
worker_env_->compute_pool, cluster_flr)); graph_options.optimizer_options(), worker_env_->compute_pool,
cluster_flr));
// Constructs the graph out of "gdef". // Constructs the graph out of "gdef".
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
@ -287,16 +286,14 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
return Status::OK(); return Status::OK();
} }
Status GraphMgr::Register(const string& handle, const GraphDef& gdef, Status GraphMgr::Register(
WorkerSession* session, const string& handle, const GraphDef& gdef, WorkerSession* session,
const GraphOptions& graph_options, const GraphOptions& graph_options, const DebugOptions& debug_options,
const DebugOptions& debug_options, const ConfigProto& config_proto, int64 collective_graph_key,
int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
DistributedFunctionLibraryRuntime* cluster_flr,
string* graph_handle) {
Item* item = new Item; Item* item = new Item;
Status s = InitItem(handle, gdef, session, graph_options, debug_options, Status s = InitItem(handle, gdef, session, graph_options, debug_options,
collective_graph_key, cluster_flr, item); config_proto, collective_graph_key, cluster_flr, item);
if (!s.ok()) { if (!s.ok()) {
item->Unref(); item->Unref();
return s; return s;

View File

@ -76,7 +76,8 @@ class GraphMgr {
// reference to cluster_flr to do cross process function calls. // reference to cluster_flr to do cross process function calls.
Status Register(const string& handle, const GraphDef& gdef, Status Register(const string& handle, const GraphDef& gdef,
WorkerSession* session, const GraphOptions& graph_options, WorkerSession* session, const GraphOptions& graph_options,
const DebugOptions& debug_options, int64 collective_graph_key, const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, DistributedFunctionLibraryRuntime* cluster_flr,
string* graph_handle); string* graph_handle);
@ -179,7 +180,8 @@ class GraphMgr {
Status InitItem(const string& handle, const GraphDef& gdef, Status InitItem(const string& handle, const GraphDef& gdef,
WorkerSession* session, const GraphOptions& graph_options, WorkerSession* session, const GraphOptions& graph_options,
const DebugOptions& debug_options, int64 collective_graph_key, const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item); DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options, Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,

View File

@ -472,6 +472,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
c->req.set_session_handle(session_handle_); c->req.set_session_handle(session_handle_);
c->req.set_create_worker_session_called(!should_deregister_); c->req.set_create_worker_session_called(!should_deregister_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
*c->req.mutable_config_proto() = session_opts_.config;
*c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() = *c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options(); callable_opts_.run_options().debug_options();

View File

@ -80,8 +80,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
s = session->graph_mgr()->Register( s = session->graph_mgr()->Register(
request->session_handle(), request->graph_def(), session.get(), request->session_handle(), request->graph_def(), session.get(),
request->graph_options(), request->debug_options(), request->graph_options(), request->debug_options(),
request->collective_graph_key(), session->cluster_flr(), request->config_proto(), request->collective_graph_key(),
response->mutable_graph_handle()); session->cluster_flr(), response->mutable_graph_handle());
} }
done(s); done(s);
} }

View File

@ -297,6 +297,7 @@ class IteratorContext {
allocator_getter = [device](AllocatorAttributes attrs) { allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs); return device->GetAllocator(attrs);
}; };
thread::ThreadPool* thread_pool = thread::ThreadPool* thread_pool =
ctx->device()->tensorflow_device_thread_pool(); ctx->device()->tensorflow_device_thread_pool();
if (thread_pool) { if (thread_pool) {

View File

@ -752,6 +752,9 @@ class FunctionLibraryRuntime {
// Returns the environment on which the function executes. // Returns the environment on which the function executes.
virtual Env* env() = 0; virtual Env* env() = 0;
// Returns the ConfigProto passed to the session used to create the function.
virtual const ConfigProto* const config_proto() = 0;
// Returns a debug string showing the definition of the function of // Returns a debug string showing the definition of the function of
// 'handle'. // 'handle'.
virtual string DebugString(Handle handle) = 0; virtual string DebugString(Handle handle) = 0;

View File

@ -256,7 +256,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
// Create the function library runtime. // Create the function library runtime.
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
graph_def.versions().producer(), graph_def.versions().producer(),
&function_library, *optimizer_opts)); &function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name()); FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

View File

@ -529,10 +529,11 @@ Status CapturedFunction::Instantiate(
// The context's runtime will be used for all subsequent calls. // The context's runtime will be used for all subsequent calls.
FunctionLibraryRuntime* lib = ctx->flr(); FunctionLibraryRuntime* lib = ctx->flr();
FunctionLibraryRuntime::InstantiateOptions inst_opts; FunctionLibraryRuntime::InstantiateOptions inst_opts;
// TODO(b/141576771): Propagate ctx ConfigProto into inst_opts.config_proto.
inst_opts.lib_def = metadata_->lib_def(); inst_opts.lib_def = metadata_->lib_def();
inst_opts.create_kernels_eagerly = true; inst_opts.create_kernels_eagerly = true;
inst_opts.default_device_to_target = metadata_->use_default_device(); inst_opts.default_device_to_target = metadata_->use_default_device();
inst_opts.config_proto =
lib->config_proto() ? *lib->config_proto() : ConfigProto();
if (!metadata_->use_inter_op_parallelism()) { if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
} }

View File

@ -419,8 +419,9 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
OptimizerOptions opts; OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), /*config=*/nullptr,
opts, thread_pool_.get(), nullptr /* cluster_flr */); TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
nullptr /* cluster_flr */);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
if (thread_pool_ == nullptr) { if (thread_pool_ == nullptr) {
runner_ = [](std::function<void()> fn) { fn(); }; runner_ = [](std::function<void()> fn) { fn(); };

View File

@ -355,16 +355,18 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the // in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource // IteratorResource, because we are storing the IteratorResource
// in that device's resource manager. // in that device's resource manager.
*device_mgr = *device_mgr =
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice( absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()), ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */)); false /* owns_underlying */, false /* isolate_session_state */));
*flib_def = absl::make_unique<FunctionLibraryDefinition>( *flib_def = absl::make_unique<FunctionLibraryDefinition>(
*ctx->function_library()->GetFunctionLibraryDefinition()); *ctx->function_library()->GetFunctionLibraryDefinition());
const auto* config = ctx->function_library()->config_proto();
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(), device_mgr->get(), ctx->env(),
OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */, /*config=*/config, graph_def_version_, flib_def->get(),
nullptr /* TODO(mrry): ClusterFLR */); config->graph_options().optimizer_options());
return (*pflr)->GetFLR(ctx->device()->name()); return (*pflr)->GetFLR(ctx->device()->name());
} }

View File

@ -325,7 +325,14 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
AttrValueMap attr_values = func_.attr(); AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::InstantiateOptions instantiate_opts; FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto.
const auto* config = (ctx->function_library())
? ctx->function_library()->config_proto()
: nullptr;
if (config) {
instantiate_opts.config_proto = *config;
}
instantiate_opts.target = target_device; instantiate_opts.target = target_device;
FunctionTarget function_target = {target_device, lib}; FunctionTarget function_target = {target_device, lib};

View File

@ -29,8 +29,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
device_ = device.get(); device_ = device.get();
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), device_mgr_.get(), Env::Default(), /*config=*/nullptr,
OptimizerOptions()); TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
device_type_ = device_type; device_type_ = device_type;
#ifdef GOOGLE_CUDA #ifdef GOOGLE_CUDA

View File

@ -84,8 +84,8 @@ class OpsTestBase : public ::testing::Test {
flib_def_ = absl::make_unique<FunctionLibraryDefinition>( flib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), FunctionDefLibrary{}); OpRegistry::Global(), FunctionDefLibrary{});
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, device_mgr_.get(), Env::Default(), /*config=*/nullptr,
flib_def_.get(), OptimizerOptions()); TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
} }
~OpsTestBase() override { ~OpsTestBase() override {

View File

@ -166,7 +166,12 @@ Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib,
std::vector<Tensor>* inputs, std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle) { FunctionLibraryRuntime::Handle* handle) {
FunctionLibraryRuntime::InstantiateOptions opts; FunctionLibraryRuntime::InstantiateOptions opts;
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto. const auto* config = (ctx->function_library())
? ctx->function_library()->config_proto()
: nullptr;
if (config) {
opts.config_proto = *config;
}
#ifndef __ANDROID__ #ifndef __ANDROID__
// Android tf library does not include grappler. // Android tf library does not include grappler.

View File

@ -132,6 +132,11 @@ message RegisterGraphRequest {
// concurrently so that BufRendezvous entries will make the correct // concurrently so that BufRendezvous entries will make the correct
// values accessible. // values accessible.
int64 collective_graph_key = 7; int64 collective_graph_key = 7;
// ConfigProto from the session in which this graph was created.
// Contains additional parameters beyond graph_options, including
// the name of the requested executor.
ConfigProto config_proto = 8;
} }
message RegisterGraphResponse { message RegisterGraphResponse {

View File

@ -2237,8 +2237,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::StaticDeviceMgr device_mgr(std::move(devices)); tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts; tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr( tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld, &device_mgr, tensorflow::Env::Default(), &options.config,
o_opts, nullptr); TF_GRAPH_DEF_VERSION, &fld,
options.config.graph_options().optimizer_options(), nullptr);
tensorflow::FunctionLibraryRuntime* flr; tensorflow::FunctionLibraryRuntime* flr;
flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");

View File

@ -124,8 +124,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
job_def.tasks[0] = "localhost:{}".format(local_port) job_def.tasks[0] = "localhost:{}".format(local_port)
server_def = ServerDef( server_def = ServerDef(
cluster=cluster_def, job_name=job_name, task_index=task_index, cluster=cluster_def,
protocol=protocol) job_name=job_name,
task_index=task_index,
protocol=protocol,
default_session_config=context.context().config)
context.set_server_def(server_def) context.set_server_def(server_def)