diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 60f58c7074b..7eb8a33c727 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -447,6 +447,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   for (const auto& da : local_device_attributes) {
     *base_request.add_cluster_device_attributes() = da;
   }
+  base_request.mutable_server_def()
+      ->mutable_default_session_config()
+      ->MergeFrom(server_def.default_session_config());
 
   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
   LOG_AND_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc
index cdce1e92799..e6e49ae7957 100644
--- a/tensorflow/compiler/jit/compilability_check_util_test.cc
+++ b/tensorflow/compiler/jit/compilability_check_util_test.cc
@@ -82,7 +82,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
   FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
     OptimizerOptions opts;
     pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
-        nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), opts);
+        nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
+        flib_def_.get(), opts);
     return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
   }
 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 114800d87f3..494a8d1613b 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -1196,11 +1196,12 @@ Status EncapsulateSubgraphsPass::Run(
   std::unique_ptr device_mgr = absl::make_unique(std::move(devices));
-  OptimizerOptions opts;
+  const auto* config = &options.session_options->config;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
-      new ProcessFunctionLibraryRuntime(device_mgr.get(),
-                                        options.session_options->env,
-                                        TF_GRAPH_DEF_VERSION, library, opts));
+      new ProcessFunctionLibraryRuntime(
+          device_mgr.get(), options.session_options->env,
+          /*config=*/config, TF_GRAPH_DEF_VERSION, library,
+          config->graph_options().optimizer_options()));
   FunctionLibraryRuntime* flr =
       pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
   if (flr == nullptr) {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 0f7cee518d4..853e6bcfc8c 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -513,8 +513,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
   OptimizerOptions opts;
   auto device_mgr = absl::make_unique(std::move(devices));
   auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
-      device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
-      opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+      device_mgr.get(), Env::Default(), /*config=*/nullptr,
+      TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
+      /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
   auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
 
   std::unique_ptr<Graph> graph_out;
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index fa3ba3bfc4a..049ee8233c7 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -246,8 +246,9 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
                       bool *has_outside_compilation) {
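// ---------------------------------------------------------------------------
// Sketch (not part of the patch): the change repeated across these hunks is
// that ProcessFunctionLibraryRuntime now takes a `const ConfigProto*` right
// after `Env*`. Callers that run inside a session forward the session's
// ConfigProto together with its optimizer options, as
// encapsulate_subgraphs_pass.cc does above; callers with no session at hand
// pass /*config=*/nullptr. The helper names below are illustrative only.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Graph-pass call shape: a real session config is available.
std::unique_ptr<ProcessFunctionLibraryRuntime> PflrForPass(
    const GraphOptimizationPassOptions& options, const DeviceMgr* device_mgr,
    const FunctionLibraryDefinition* library) {
  const ConfigProto* config = &options.session_options->config;
  return absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr, options.session_options->env, /*config=*/config,
      TF_GRAPH_DEF_VERSION, library,
      config->graph_options().optimizer_options());
}

// Test / no-session call shape: a null config is accepted and default
// OptimizerOptions are used.
std::unique_ptr<ProcessFunctionLibraryRuntime> PflrForTest(
    const FunctionLibraryDefinition* library) {
  return absl::make_unique<ProcessFunctionLibraryRuntime>(
      /*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, library, OptimizerOptions());
}

}  // namespace tensorflow
// ---------------------------------------------------------------------------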
OptimizerOptions opts; pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, - /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, fld, opts, + /*default_thread_pool=*/nullptr); auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); return ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 90755a1cb70..a0cf950957b 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1074,8 +1074,8 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { Status MarkForCompilationPassImpl::FindCompilationCandidates() { OptimizerOptions opts; std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION, - flib_def_, opts)); + new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, flib_def_, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); std::vector compile_time_const_nodes(graph_->num_node_ids(), false); diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 30ba5a56efd..d1475ff0c6b 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -295,7 +295,7 @@ Status PartiallyDeclusterGraph(Graph* graph, std::vector compile_time_const_nodes(graph->num_node_ids()); OptimizerOptions opts; auto pflr = absl::make_unique( - nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts); + nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr, diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index aa87a958c1b..acac2f7d055 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -148,9 +148,9 @@ xla::StatusOr> GetNodesRelatedToRefVarsSorted( TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, Env::Default(), - TF_GRAPH_DEF_VERSION, flib_def, - OptimizerOptions{})); + new ProcessFunctionLibraryRuntime( + nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION, + flib_def, OptimizerOptions{})); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index 6ad11c4e028..28606abf2b2 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -73,8 +73,9 @@ class XlaKernelCreatorTest : public ::testing::Test { OptimizerOptions opts; device_mgr_ = absl::make_unique(std::move(devices)); pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, + 
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 41c8c1e9e68..4bc1a5dd688 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -222,10 +222,12 @@ Status FunctionalizeControlFlowForXlaPass::Run( DumpGraphToFile("functionalize_control_flow_before", *graph, options.flib_def); } + const auto* config = &options.session_options->config; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime( - /*device_mgr=*/nullptr, options.session_options->env, - TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + /*device_mgr=*/nullptr, options.session_options->env, config, + TF_GRAPH_DEF_VERSION, options.flib_def, + config->graph_options().optimizer_options())); FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); diff --git a/tensorflow/compiler/tf2xla/resource_util_test.cc b/tensorflow/compiler/tf2xla/resource_util_test.cc index 2ad6274950d..541bf9870e7 100644 --- a/tensorflow/compiler/tf2xla/resource_util_test.cc +++ b/tensorflow/compiler/tf2xla/resource_util_test.cc @@ -55,8 +55,8 @@ void AnalyzeAndVerify( ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get())); auto pflr = absl::make_unique( - nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def, - OptimizerOptions()); + nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION, + flib_def, OptimizerOptions()); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); absl::flat_hash_map pflr( new ProcessFunctionLibraryRuntime( - /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld, - OptimizerOptions())); + /*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, &fld, OptimizerOptions())); FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 986a72fa12c..769ff932480 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -517,11 +517,11 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( - &device_mgr_, Env::Default(), options.graph_def_version, - local_flib_def_.get(), OptimizerOptions())); + &device_mgr_, Env::Default(), /*config=*/nullptr, + options.graph_def_version, local_flib_def_.get(), OptimizerOptions())); pflr_.reset(new ProcessFunctionLibraryRuntime( - &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def, - OptimizerOptions())); + &device_mgr_, Env::Default(), /*config=*/nullptr, + options.graph_def_version, options.flib_def, OptimizerOptions())); local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 7effe581279..ab52aa4aa46 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1292,7 +1292,7 @@ Status DirectSession::CreateExecutors( ? 
&options_.config.experimental().session_metadata() : nullptr; func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), options_.env, graph_def_version, + device_mgr_.get(), options_.env, &options_.config, graph_def_version, func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first, nullptr, nullptr, session_metadata)); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 5c5be163319..7447ff10456 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -80,9 +80,9 @@ EagerContext::EagerContext( thread_pool_(NewThreadPoolFromSessionOptions(opts)), custom_kernel_creator_(custom_kernel_creator), pflr_(new ProcessFunctionLibraryRuntime( - device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, - opts.config.graph_options().optimizer_options(), thread_pool_.get(), - cluster_flr, custom_kernel_creator_)), + device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION, + &func_lib_def_, opts.config.graph_options().optimizer_options(), + thread_pool_.get(), cluster_flr, custom_kernel_creator_)), log_device_placement_(opts.config.log_device_placement()), allow_soft_placement_(opts.config.allow_soft_placement()), num_active_steps_(0), @@ -697,9 +697,13 @@ Status EagerContext::StoreCollectiveOpsServer( } } + const ConfigProto* config = pflr_ ? pflr_->config() : nullptr; pflr_.reset(new ProcessFunctionLibraryRuntime( - local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, - {}, thread_pool_.get())); + local_unowned_device_manager_, env_, /*config=*/config, + TF_GRAPH_DEF_VERSION, &func_lib_def_, + /*optimizer_options=*/ + config ? config->graph_options().optimizer_options() : OptimizerOptions(), + thread_pool_.get())); // Memory leak! 
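// ---------------------------------------------------------------------------
// Sketch (not part of the patch): the rebuild pattern EagerContext uses above.
// When the ProcessFunctionLibraryRuntime has to be recreated (e.g. in
// StoreCollectiveOpsServer or SetMasterContextState), the ConfigProto captured
// by the old runtime is read back through config() and handed to the new one,
// so reconfiguring the cluster does not silently drop the session config.
// `RebuildPflr` is an illustrative helper, assuming only the accessors added
// in this patch.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

std::unique_ptr<ProcessFunctionLibraryRuntime> RebuildPflr(
    const ProcessFunctionLibraryRuntime& old_pflr, const DeviceMgr* device_mgr,
    Env* env, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* thread_pool) {
  // May be nullptr if the original runtime was built without a config.
  const ConfigProto* config = old_pflr.config();
  return absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr, env, /*config=*/config, TF_GRAPH_DEF_VERSION, lib_def,
      /*optimizer_options=*/
      config ? config->graph_options().optimizer_options()
             : OptimizerOptions(),
      thread_pool);
}

}  // namespace tensorflow
// ---------------------------------------------------------------------------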
if (server_ != nullptr) { @@ -849,9 +853,11 @@ Status EagerContext::SetMasterContextState( entry.second->ClearError(); } } + const auto* config = pflr_->config(); pflr_.reset(new ProcessFunctionLibraryRuntime( - local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, - {}, thread_pool_.get(), cluster_flr, custom_kernel_creator_)); + local_unowned_device_manager_, env_, config, TF_GRAPH_DEF_VERSION, + &func_lib_def_, config->graph_options().optimizer_options(), + thread_pool_.get(), cluster_flr, custom_kernel_creator_)); keep_alive_secs_ = keep_alive_secs; sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2); @@ -995,9 +1001,10 @@ Status EagerContext::UpdateRemoteWorker( } SessionOptions options = SessionOptions(); + const auto* config = pflr_->config(); pflr_.reset(new ProcessFunctionLibraryRuntime( - worker_session_device_mgr, options.env, TF_GRAPH_DEF_VERSION, - FuncLibDef(), options.config.graph_options().optimizer_options(), + worker_session_device_mgr, options.env, config, TF_GRAPH_DEF_VERSION, + FuncLibDef(), config->graph_options().optimizer_options(), thread_pool_.get(), cluster_flr, custom_kernel_creator_)); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index cc35780b0e4..779f72361cf 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -48,8 +48,9 @@ class TestEnv { device_mgr_ = absl::make_unique(std::move(devices)); OptimizerOptions opts; pflr_ = tensorflow::MakeUnique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_, - opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, &flib_def_, opts, + /*default_thread_pool=*/nullptr); flr_ = pflr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0"); CHECK(flr_ != nullptr); diff --git a/tensorflow/core/common_runtime/eager/process_function_library_runtime.h b/tensorflow/core/common_runtime/eager/process_function_library_runtime.h index ff6dd997996..1bc586fe961 100644 --- a/tensorflow/core/common_runtime/eager/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/eager/process_function_library_runtime.h @@ -45,15 +45,15 @@ class EagerProcessFunctionLibraryRuntime : public ProcessFunctionLibraryRuntime { public: EagerProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, thread::ThreadPool* thread_pool = nullptr, DistributedFunctionLibraryRuntime* parent = nullptr, const CustomKernelCreator* custom_kernel_creator = nullptr) - : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, - lib_def, optimizer_options, thread_pool, - parent, custom_kernel_creator) {} + : ProcessFunctionLibraryRuntime( + device_mgr, env, config, graph_def_version, lib_def, + optimizer_options, thread_pool, parent, custom_kernel_creator) {} #if !defined(IS_MOBILE_PLATFORM) void Run(const FunctionLibraryRuntime::Options& opts, diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 6362325ddce..609a7c97aa5 100644 --- a/tensorflow/core/common_runtime/function.cc +++ 
b/tensorflow/core/common_runtime/function.cc @@ -195,6 +195,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime { const override; Env* env() override; + const ConfigProto* const config_proto() override; Device* device() override; const Device* device() const override; std::function)>* runner() override; @@ -277,6 +278,10 @@ bool FunctionLibraryRuntimeOverlay::IsStateful( Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); } +const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() { + return base_flr_->config_proto(); +} + Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); } const Device* FunctionLibraryRuntimeOverlay::device() const { @@ -317,7 +322,8 @@ Status FunctionLibraryRuntimeOverlay::Clone( class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: - FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, + FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, + const ConfigProto* config, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, thread::ThreadPool* default_thread_pool, @@ -361,6 +367,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const DeviceMgr* device_mgr() const override { return device_mgr_; } Env* env() override { return env_; } + const ConfigProto* const config_proto() override { return config_; } int graph_def_version() const override { return graph_def_version_; } string DebugString(Handle h) override; @@ -376,6 +383,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const DeviceMgr* const device_mgr_; Device* const device_; Env* const env_; + const ConfigProto* const config_; const int graph_def_version_; const FunctionLibraryDefinition* const base_lib_def_; GraphOptimizer optimizer_; @@ -442,8 +450,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( - const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, - const FunctionLibraryDefinition* lib_def, + const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device, + int graph_def_version, const FunctionLibraryDefinition* lib_def, thread::ThreadPool* default_thread_pool, const OptimizerOptions& optimizer_options, const CustomKernelCreator* custom_kernel_creator, @@ -452,6 +460,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( : device_mgr_(dmgr), device_(device), env_(env), + config_(config), graph_def_version_(graph_def_version), base_lib_def_(lib_def), optimizer_(optimizer_options), @@ -1283,14 +1292,15 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) { } std::unique_ptr NewFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, Device* device, - int graph_def_version, const FunctionLibraryDefinition* lib_def, - thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + Device* device, int graph_def_version, + const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool, + const OptimizerOptions& optimizer_options, const CustomKernelCreator* custom_kernel_creator, const SessionMetadata* session_metadata, ProcessFunctionLibraryRuntime* parent) { return std::unique_ptr(new FunctionLibraryRuntimeImpl( - device_mgr, env, device, graph_def_version, lib_def, thread_pool, + device_mgr, env, config, device, graph_def_version, lib_def, thread_pool, optimizer_options, 
custom_kernel_creator, session_metadata, parent)); } diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 9465f07196a..9ef70d2abf6 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -59,11 +59,12 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c); // typically owns the created FunctionLibraryRuntime object. The parent pointer // is not owned by the FunctionLibraryRuntime object. std::unique_ptr NewFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, Device* device, - int graph_def_version, const FunctionLibraryDefinition* lib_def, - thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + Device* device, int graph_def_version, + const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool, + const OptimizerOptions& optimizer_options, const CustomKernelCreator* custom_kernel_creator, - const SessionMetadata* sesson_metadata, + const SessionMetadata* session_metadata, ProcessFunctionLibraryRuntime* parent); // FunctionLibraryRuntime::GetFunctionBody returns a description of an diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 330ad73aa5c..c3d6e948f1e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -164,8 +164,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { OptimizerOptions opts; device_mgr_ = absl::make_unique(std::move(devices)); pflr_.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts)); + device_mgr_.get(), Env::Default(), &options.config, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts)); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index fdacacc11ce..8f31cda9310 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -64,8 +64,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { OptimizerOptions opts; device_mgr_ = absl::make_unique(std::move(devices)); pflr_.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, default_thread_pool)); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool)); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 68dd2e2ce10..a5d6392542d 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -65,8 +65,8 @@ Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit( } ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + int graph_def_version, 
const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, thread::ThreadPool* default_thread_pool, DistributedFunctionLibraryRuntime* parent, @@ -74,6 +74,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const SessionMetadata* session_metadata) : parent_(parent), env_(env), + config_(config ? absl::make_optional(*config) : absl::nullopt), device_mgr_(device_mgr), lib_def_(lib_def), default_thread_pool_(default_thread_pool), @@ -83,14 +84,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( session_metadata_(session_metadata) { if (device_mgr == nullptr) { (*flr_map_)[nullptr] = NewFunctionLibraryRuntime( - nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool, - optimizer_options, custom_kernel_creator, session_metadata_, this); + nullptr, env, config_ ? &(*config_) : nullptr, nullptr, + graph_def_version, lib_def_, default_thread_pool, optimizer_options, + custom_kernel_creator, session_metadata_, this); return; } for (Device* d : device_mgr->ListDevices()) { (*flr_map_)[d] = NewFunctionLibraryRuntime( - device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool, - optimizer_options, custom_kernel_creator, session_metadata_, this); + device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version, + lib_def_, default_thread_pool, optimizer_options, custom_kernel_creator, + session_metadata_, this); } DeviceMgr const* all_devices = device_mgr_; @@ -1330,9 +1333,9 @@ Status ProcessFunctionLibraryRuntime::Clone( *out_lib_def = absl::make_unique(*lib_def_); } *out_pflr = absl::make_unique( - device_mgr_, env, graph_def_version, out_lib_def->get(), - optimizer_options, default_thread_pool_, parent_, custom_kernel_creator, - session_metadata_); + device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version, + out_lib_def->get(), optimizer_options, default_thread_pool_, parent_, + custom_kernel_creator, session_metadata_); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index eea51be5273..b93d6965991 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" // clang-format on +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/function.h" @@ -41,8 +42,8 @@ class ProcessFunctionLibraryRuntime { // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent // (if provided) outlive this object. ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, thread::ThreadPool* thread_pool = nullptr, DistributedFunctionLibraryRuntime* parent = nullptr, @@ -159,6 +160,8 @@ class ProcessFunctionLibraryRuntime { const DeviceSet* device_set() { return &device_set_; } + const ConfigProto* config() const { return config_ ? 
&(*config_) : nullptr; } + const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const { return lib_def_; } @@ -381,6 +384,7 @@ class ProcessFunctionLibraryRuntime { mutable mutex mu_; Env* const env_; + const absl::optional config_; const DeviceMgr* const device_mgr_; DeviceSet device_set_; const FunctionLibraryDefinition* lib_def_; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 0294168dfa0..4b53a7efaa2 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -121,8 +121,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { OptimizerOptions opts; cluster_flr_.reset(new TestClusterFLR(device_mgr_.get())); proc_flr_.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, nullptr, cluster_flr_.get(), nullptr, session_metadata)); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(), + nullptr, session_metadata)); rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } @@ -295,8 +296,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) { OptimizerOptions opts; std::unique_ptr proc_flr( new ProcessFunctionLibraryRuntime( - nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION, - lib_def.get(), opts, nullptr, nullptr /* cluster_flr */)); + nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def.get(), opts)); FunctionLibraryRuntime* flr = proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); EXPECT_NE(flr, nullptr); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 105963f4047..ef8429c8d2a 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -128,9 +128,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, LOG(INFO) << "Creating " << (request->async() ? 
"async" : "sync") << " eager service context with rendezvous_id on host " << port::Hostname() << " " << worker_session->worker_name(); + SessionOptions opts; + opts.config = request->server_def().default_session_config(); tensorflow::EagerContext* ctx = new tensorflow::EagerContext( - SessionOptions(), - tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), device_mgr, false, r, GetDefaultCustomKernelCreator(), worker_session->cluster_flr()); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 1b7d869e763..dcd4d261ae3 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -234,7 +234,6 @@ TEST_F(EagerServiceImplTest, BasicTest) { TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); - EnqueueRequest remote_enqueue_request; remote_enqueue_request.set_context_id(context_id); EnqueueResponse remote_enqueue_response; @@ -400,8 +399,9 @@ TEST_F(EagerServiceImplTest, EagerPFLRTest) { auto device_mgr = absl::make_unique( DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1")); auto eager_pflr = absl::make_unique( - device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &func_lib_def, - OptimizerOptions(), nullptr, eager_cluster_flr.get(), nullptr); + device_mgr.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, &func_lib_def, OptimizerOptions(), nullptr, + eager_cluster_flr.get(), nullptr); tensorflow::FunctionDef fdef = MatMulFunction(); TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef)); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 316e98c22d2..2f7f604ae46 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -121,13 +121,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug( // // "executors" are filled with one executor per device if success and // the caller takes the ownership of returned executors. -Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, - WorkerSession* session, - const GraphOptions& graph_options, - const DebugOptions& debug_options, - int64 collective_graph_key, - DistributedFunctionLibraryRuntime* cluster_flr, - Item* item) { +Status GraphMgr::InitItem( + const string& handle, const GraphDef& gdef, WorkerSession* session, + const GraphOptions& graph_options, const DebugOptions& debug_options, + const ConfigProto& config_proto, int64 collective_graph_key, + DistributedFunctionLibraryRuntime* cluster_flr, Item* item) { item->session = handle; item->collective_graph_key = collective_graph_key; item->lib_def.reset( @@ -139,9 +137,10 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, // does that below. item->proc_flr.reset(new ProcessFunctionLibraryRuntime( - device_mgr_, worker_env_->env, gdef.versions().producer(), - item->lib_def.get(), graph_options.optimizer_options(), - worker_env_->compute_pool, cluster_flr)); + device_mgr_, worker_env_->env, /*config=*/&config_proto, + gdef.versions().producer(), item->lib_def.get(), + graph_options.optimizer_options(), worker_env_->compute_pool, + cluster_flr)); // Constructs the graph out of "gdef". 
Graph graph(OpRegistry::Global()); @@ -287,16 +286,14 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, return Status::OK(); } -Status GraphMgr::Register(const string& handle, const GraphDef& gdef, - WorkerSession* session, - const GraphOptions& graph_options, - const DebugOptions& debug_options, - int64 collective_graph_key, - DistributedFunctionLibraryRuntime* cluster_flr, - string* graph_handle) { +Status GraphMgr::Register( + const string& handle, const GraphDef& gdef, WorkerSession* session, + const GraphOptions& graph_options, const DebugOptions& debug_options, + const ConfigProto& config_proto, int64 collective_graph_key, + DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) { Item* item = new Item; Status s = InitItem(handle, gdef, session, graph_options, debug_options, - collective_graph_key, cluster_flr, item); + config_proto, collective_graph_key, cluster_flr, item); if (!s.ok()) { item->Unref(); return s; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 6ac7b7c3a51..e043c82f927 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -76,7 +76,8 @@ class GraphMgr { // reference to cluster_flr to do cross process function calls. Status Register(const string& handle, const GraphDef& gdef, WorkerSession* session, const GraphOptions& graph_options, - const DebugOptions& debug_options, int64 collective_graph_key, + const DebugOptions& debug_options, + const ConfigProto& config_proto, int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle); @@ -179,7 +180,8 @@ class GraphMgr { Status InitItem(const string& handle, const GraphDef& gdef, WorkerSession* session, const GraphOptions& graph_options, - const DebugOptions& debug_options, int64 collective_graph_key, + const DebugOptions& debug_options, + const ConfigProto& config_proto, int64 collective_graph_key, DistributedFunctionLibraryRuntime* cluster_flr, Item* item); Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options, diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 8d1532c28e4..e5d3c6ae354 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -472,6 +472,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( c->req.set_session_handle(session_handle_); c->req.set_create_worker_session_called(!should_deregister_); c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); + *c->req.mutable_config_proto() = session_opts_.config; *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = callable_opts_.run_options().debug_options(); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 908b3b9ff6f..7850ecc46b2 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -80,8 +80,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, s = session->graph_mgr()->Register( request->session_handle(), request->graph_def(), session.get(), request->graph_options(), request->debug_options(), - request->collective_graph_key(), session->cluster_flr(), - response->mutable_graph_handle()); + request->config_proto(), request->collective_graph_key(), + session->cluster_flr(), 
response->mutable_graph_handle()); } done(s); } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 71531dd2fae..406dd00c85e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -297,6 +297,7 @@ class IteratorContext { allocator_getter = [device](AllocatorAttributes attrs) { return device->GetAllocator(attrs); }; + thread::ThreadPool* thread_pool = ctx->device()->tensorflow_device_thread_pool(); if (thread_pool) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 1ce05beed32..55a92059205 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -752,6 +752,9 @@ class FunctionLibraryRuntime { // Returns the environment on which the function executes. virtual Env* env() = 0; + // Returns the ConfigProto passed to the session used to create the function. + virtual const ConfigProto* const config_proto() = 0; + // Returns a debug string showing the definition of the function of // 'handle'. virtual string DebugString(Handle handle) = 0; diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 49958bd15b7..baf063eea74 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -256,7 +256,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, // Create the function library runtime. std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, + new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config, graph_def.versions().producer(), &function_library, *optimizer_opts)); FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name()); diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index e43c6826379..cf92ee5f3af 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -529,10 +529,11 @@ Status CapturedFunction::Instantiate( // The context's runtime will be used for all subsequent calls. FunctionLibraryRuntime* lib = ctx->flr(); FunctionLibraryRuntime::InstantiateOptions inst_opts; - // TODO(b/141576771): Propagate ctx ConfigProto into inst_opts.config_proto. inst_opts.lib_def = metadata_->lib_def(); inst_opts.create_kernels_eagerly = true; inst_opts.default_device_to_target = metadata_->use_default_device(); + inst_opts.config_proto = + lib->config_proto() ? 
*lib->config_proto() : ConfigProto(); if (!metadata_->use_inter_op_parallelism()) { inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; } diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 0f6e1c9f6bc..48d0ef6a86b 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -419,8 +419,9 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime( OptimizerOptions opts; pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, thread_pool_.get(), nullptr /* cluster_flr */); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(), + nullptr /* cluster_flr */); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); if (thread_pool_ == nullptr) { runner_ = [](std::function fn) { fn(); }; diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 053026c5c08..d989997660a 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -355,16 +355,18 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR( // in its resource manager. The existing device will outlive the // IteratorResource, because we are storing the IteratorResource // in that device's resource manager. + *device_mgr = absl::make_unique(RenamedDevice::NewRenamedDevice( ctx->device()->name(), down_cast(ctx->device()), false /* owns_underlying */, false /* isolate_session_state */)); *flib_def = absl::make_unique( *ctx->function_library()->GetFunctionLibraryDefinition()); + const auto* config = ctx->function_library()->config_proto(); *pflr = absl::make_unique( - device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(), - OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */, - nullptr /* TODO(mrry): ClusterFLR */); + device_mgr->get(), ctx->env(), + /*config=*/config, graph_def_version_, flib_def->get(), + config->graph_options().optimizer_options()); return (*pflr)->GetFLR(ctx->device()->name()); } diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 1b2862d3aef..03665989238 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -325,7 +325,14 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { AttrValueMap attr_values = func_.attr(); FunctionLibraryRuntime::InstantiateOptions instantiate_opts; - // TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto. + + const auto* config = (ctx->function_library()) + ? 
ctx->function_library()->config_proto() + : nullptr; + if (config) { + instantiate_opts.config_proto = *config; + } + instantiate_opts.target = target_device; FunctionTarget function_target = {target_device, lib}; diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc index 59e695159f3..62171dbaa7f 100644 --- a/tensorflow/core/kernels/ops_testutil.cc +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -29,8 +29,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type, device_ = device.get(); device_mgr_ = absl::make_unique(std::move(device)); pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), - OptimizerOptions()); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions()); device_type_ = device_type; #ifdef GOOGLE_CUDA diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index 81089d8328a..2b4a1a7ccab 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -84,8 +84,8 @@ class OpsTestBase : public ::testing::Test { flib_def_ = absl::make_unique( OpRegistry::Global(), FunctionDefLibrary{}); pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, - flib_def_.get(), OptimizerOptions()); + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions()); } ~OpsTestBase() override { diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 50f94e2364b..a65d3c685db 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -166,7 +166,12 @@ Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib, std::vector* inputs, FunctionLibraryRuntime::Handle* handle) { FunctionLibraryRuntime::InstantiateOptions opts; - // TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto. + const auto* config = (ctx->function_library()) + ? ctx->function_library()->config_proto() + : nullptr; + if (config) { + opts.config_proto = *config; + } #ifndef __ANDROID__ // Android tf library does not include grappler. diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 8b2afd7bf36..3a8fbcfb015 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -132,6 +132,11 @@ message RegisterGraphRequest { // concurrently so that BufRendezvous entries will make the correct // values accessible. int64 collective_graph_key = 7; + + // ConfigProto from the session in which this graph was created. + // Contains additional parameters beyond graph_options, including + // the name of the requested executor. 
+  ConfigProto config_proto = 8;
 }
 
 message RegisterGraphResponse {
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index f351c907fef..fa8cc9799e1 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2237,8 +2237,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
   tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
   tensorflow::OptimizerOptions o_opts;
   tensorflow::ProcessFunctionLibraryRuntime pflr(
-      &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
-      o_opts, nullptr);
+      &device_mgr, tensorflow::Env::Default(), &options.config,
+      TF_GRAPH_DEF_VERSION, &fld,
+      options.config.graph_options().optimizer_options(), nullptr);
   tensorflow::FunctionLibraryRuntime* flr;
   flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
 
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index ce89a8a83b3..107c5ee4801 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -124,8 +124,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
     job_def.tasks[0] = "localhost:{}".format(local_port)
 
   server_def = ServerDef(
-      cluster=cluster_def, job_name=job_name, task_index=task_index,
-      protocol=protocol)
+      cluster=cluster_def,
+      job_name=job_name,
+      task_index=task_index,
+      protocol=protocol,
+      default_session_config=context.context().config)
 
   context.set_server_def(server_def)
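On the kernel side, the new FunctionLibraryRuntime::config_proto() accessor is what replaces the TODO(b/141576771) notes removed above: kernels copy the runtime's ConfigProto into FunctionLibraryRuntime::InstantiateOptions so the instantiated function sees the same session configuration (see captured_function.cc, function_ops.cc and partitioned_function_ops.cc). A minimal sketch of that pattern, using only APIs visible in this patch; the helper name is illustrative:

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

// Copies the runtime's session config, if any, into the instantiation options.
Status InstantiateWithSessionConfig(FunctionLibraryRuntime* lib,
                                    const string& func_name, AttrSlice attrs,
                                    FunctionLibraryRuntime::Handle* handle) {
  FunctionLibraryRuntime::InstantiateOptions opts;
  // config_proto() may be null when the runtime was built with
  // /*config=*/nullptr; fall back to the default-constructed ConfigProto.
  if (const ConfigProto* config = lib->config_proto()) {
    opts.config_proto = *config;
  }
  return lib->Instantiate(func_name, attrs, opts, handle);
}

}  // namespace tensorflow

Together with the remote.py change above, which ships the client's configuration as ServerDef.default_session_config, this closes the loop from the Python client's config to the ConfigProto seen by each remote function instantiation.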