diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index dc2731c0da0..2cf034061ca 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1342,6 +1342,7 @@ Status DirectSession::CreateGraphs( // Just return '1'. return 1; }; + popts.flib_def = &client_graph->graph.flib_def(); popts.control_flow_added = false; std::unordered_map partitions; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 7f77bf8b4ef..dd7af86e647 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -155,6 +155,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, return PartitionOptions::kIllegalIncarnation; } }; + popts.flib_def = &graph.flib_def(); popts.control_flow_added = true; popts.scheduling_for_recvs = graph_options.enable_recv_scheduling(); TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions)); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 361e89290d2..920a9c53b28 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -76,7 +76,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { debug_opts_(bopts.debug_options), worker_cache_(worker_cache) { VLOG(1) << "Created ReffedClientGraph for node with " - << client_graph_->graph.num_node_ids(); + << client_graph()->graph.num_node_ids(); stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts); @@ -166,8 +166,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // Partitions the graph into subgraphs and registers them on // workers. 
- Status RegisterPartitions(const PartitionOptions& popts, - const FunctionLibraryDefinition& flib_def); + Status RegisterPartitions(const PartitionOptions& popts); // Runs one step of all partitions. Status RunPartitions(const MasterEnv* env, int64 step_id, @@ -263,7 +262,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { PartitionOptions pots, std::unordered_map* out_partitions); Status DoRegisterPartitions( - const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib, + const PartitionOptions& popts, std::unordered_map graph_partitions); // Deregisters the partitions on the workers. Called in the @@ -274,7 +273,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { }; Status MasterSession::ReffedClientGraph::RegisterPartitions( - const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) { + const PartitionOptions& popts) { { // Ensure register once. mu_.lock(); if (!init_started_) { @@ -293,8 +292,7 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( graph_defs_for_publishing.push_back(&name_def.second); } stats_publisher_->PublishGraphProto(graph_defs_for_publishing); - s = DoRegisterPartitions(popts, flib_def.ToProto(), - std::move(graph_defs)); + s = DoRegisterPartitions(popts, std::move(graph_defs)); } mu_.lock(); init_result_ = s; @@ -374,7 +372,7 @@ Status MasterSession::ReffedClientGraph::DoBuildPartitions( } Status MasterSession::ReffedClientGraph::DoRegisterPartitions( - const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib, + const PartitionOptions& popts, std::unordered_map graph_partitions) { partitions_.reserve(graph_partitions.size()); Status s; @@ -408,8 +406,6 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( Call* c = &calls[i]; c->req.set_session_handle(session_handle_); c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); - // For simplicity, we ship the library completely to every worker. 
- *c->req.mutable_graph_def()->mutable_library() = func_def_lib; *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = debug_opts_; VLOG(2) << "Register " << c->req.graph_def().DebugString(); @@ -1305,6 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { mutex_lock l(mu_); return strings::StrCat(prefix, "_S", next_node_id_++); }; + popts.flib_def = rcg->client_graph()->flib_def.get(); popts.get_incarnation = [this](const string& name) -> int64 { Device* d = devices_->FindDeviceByName(name); if (d == nullptr) { @@ -1332,8 +1329,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { popts.need_to_record_start_times = true; } - TF_RETURN_IF_ERROR( - rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def)); + TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts)); return Status::OK(); } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index bf8dcb2fcf2..71d8cdd6ab5 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1165,11 +1165,16 @@ Status Partition(const PartitionOptions& opts, Graph* g, } } + const FunctionLibraryDefinition* flib_def = opts.flib_def; + if (flib_def == nullptr) { + flib_def = &g->flib_def(); + } + // Set versions, function library and send/recv incarnation. for (auto& it : *partitions) { GraphDef* gdef = &it.second; *gdef->mutable_versions() = g->versions(); - *gdef->mutable_library() = g->flib_def().ToProto(); + *gdef->mutable_library() = flib_def->ToProto(); // Traverse the graph to fill every send/recv op's incarnation // information. diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h index 8820b8821f2..67fafddd519 100644 --- a/tensorflow/core/graph/graph_partition.h +++ b/tensorflow/core/graph/graph_partition.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/graph/graph.h" @@ -45,6 +46,10 @@ struct PartitionOptions { typedef std::function GetIncarnationFunc; GetIncarnationFunc get_incarnation = nullptr; + // If set, this function library is copied into every resulting partition + // graph; if nullptr, the library of the graph being partitioned is used. + const FunctionLibraryDefinition* flib_def = nullptr; + // True if all the control flow "code" has already been added. The // control flow code needs to be added when we still have the entire // graph before any partitioning. So this flag should be false for