Consolidating the code to fill the partition's function library
into one place. Previously, Partition() and MasterSession::RegisterPartition() both fills in the partitioned graph's function library. PiperOrigin-RevId: 163400992
This commit is contained in:
parent
28373cfe70
commit
6b7314de49
tensorflow/core
common_runtime
distributed_runtime
graph
@ -1342,6 +1342,7 @@ Status DirectSession::CreateGraphs(
|
||||
// Just return '1'.
|
||||
return 1;
|
||||
};
|
||||
popts.flib_def = &client_graph->graph.flib_def();
|
||||
popts.control_flow_added = false;
|
||||
|
||||
std::unordered_map<string, GraphDef> partitions;
|
||||
|
@ -155,6 +155,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
return PartitionOptions::kIllegalIncarnation;
|
||||
}
|
||||
};
|
||||
popts.flib_def = &graph.flib_def();
|
||||
popts.control_flow_added = true;
|
||||
popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
|
||||
TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
|
||||
|
@ -76,7 +76,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
debug_opts_(bopts.debug_options),
|
||||
worker_cache_(worker_cache) {
|
||||
VLOG(1) << "Created ReffedClientGraph for node with "
|
||||
<< client_graph_->graph.num_node_ids();
|
||||
<< client_graph()->graph.num_node_ids();
|
||||
|
||||
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
|
||||
|
||||
@ -166,8 +166,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
|
||||
// Partitions the graph into subgraphs and registers them on
|
||||
// workers.
|
||||
Status RegisterPartitions(const PartitionOptions& popts,
|
||||
const FunctionLibraryDefinition& flib_def);
|
||||
Status RegisterPartitions(const PartitionOptions& popts);
|
||||
|
||||
// Runs one step of all partitions.
|
||||
Status RunPartitions(const MasterEnv* env, int64 step_id,
|
||||
@ -263,7 +262,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
PartitionOptions pots,
|
||||
std::unordered_map<string, GraphDef>* out_partitions);
|
||||
Status DoRegisterPartitions(
|
||||
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib,
|
||||
const PartitionOptions& popts,
|
||||
std::unordered_map<string, GraphDef> graph_partitions);
|
||||
|
||||
// Deregisters the partitions on the workers. Called in the
|
||||
@ -274,7 +273,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
};
|
||||
|
||||
Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
||||
const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
|
||||
const PartitionOptions& popts) {
|
||||
{ // Ensure register once.
|
||||
mu_.lock();
|
||||
if (!init_started_) {
|
||||
@ -293,8 +292,7 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
||||
graph_defs_for_publishing.push_back(&name_def.second);
|
||||
}
|
||||
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
|
||||
s = DoRegisterPartitions(popts, flib_def.ToProto(),
|
||||
std::move(graph_defs));
|
||||
s = DoRegisterPartitions(popts, std::move(graph_defs));
|
||||
}
|
||||
mu_.lock();
|
||||
init_result_ = s;
|
||||
@ -374,7 +372,7 @@ Status MasterSession::ReffedClientGraph::DoBuildPartitions(
|
||||
}
|
||||
|
||||
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
|
||||
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib,
|
||||
const PartitionOptions& popts,
|
||||
std::unordered_map<string, GraphDef> graph_partitions) {
|
||||
partitions_.reserve(graph_partitions.size());
|
||||
Status s;
|
||||
@ -408,8 +406,6 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
|
||||
Call* c = &calls[i];
|
||||
c->req.set_session_handle(session_handle_);
|
||||
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
|
||||
// For simplicity, we ship the library completely to every worker.
|
||||
*c->req.mutable_graph_def()->mutable_library() = func_def_lib;
|
||||
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
|
||||
*c->req.mutable_debug_options() = debug_opts_;
|
||||
VLOG(2) << "Register " << c->req.graph_def().DebugString();
|
||||
@ -1305,6 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||
mutex_lock l(mu_);
|
||||
return strings::StrCat(prefix, "_S", next_node_id_++);
|
||||
};
|
||||
popts.flib_def = rcg->client_graph()->flib_def.get();
|
||||
popts.get_incarnation = [this](const string& name) -> int64 {
|
||||
Device* d = devices_->FindDeviceByName(name);
|
||||
if (d == nullptr) {
|
||||
@ -1332,8 +1329,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
|
||||
popts.need_to_record_start_times = true;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
|
||||
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1165,11 +1165,16 @@ Status Partition(const PartitionOptions& opts, Graph* g,
|
||||
}
|
||||
}
|
||||
|
||||
const FunctionLibraryDefinition* flib_def = opts.flib_def;
|
||||
if (flib_def == nullptr) {
|
||||
flib_def = &g->flib_def();
|
||||
}
|
||||
|
||||
// Set versions, function library and send/recv incarnation.
|
||||
for (auto& it : *partitions) {
|
||||
GraphDef* gdef = &it.second;
|
||||
*gdef->mutable_versions() = g->versions();
|
||||
*gdef->mutable_library() = g->flib_def().ToProto();
|
||||
*gdef->mutable_library() = flib_def->ToProto();
|
||||
|
||||
// Traverse the graph to fill every send/recv op's incarnation
|
||||
// information.
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -45,6 +46,10 @@ struct PartitionOptions {
|
||||
typedef std::function<uint64(const string&)> GetIncarnationFunc;
|
||||
GetIncarnationFunc get_incarnation = nullptr;
|
||||
|
||||
// If specified, flib_def defines a function library that should be
|
||||
// partitioned and replicated into each resulting partition graphs.
|
||||
const FunctionLibraryDefinition* flib_def = nullptr;
|
||||
|
||||
// True if all the control flow "code" has already been added. The
|
||||
// control flow code needs to be added when we still have the entire
|
||||
// graph before any partitioning. So this flag should be false for
|
||||
|
Loading…
Reference in New Issue
Block a user