[CostModel] Generating CostGraphDef with Global Index

This commit is contained in:
Xinan Jiang 2020-11-12 15:29:13 +08:00
parent 4c7aeb0a0c
commit 20d3933a57
2 changed files with 12 additions and 3 deletions
tensorflow/core/graph

View File

@ -479,11 +479,12 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
CostGraphDef* cost_graph) const {
std::vector<const Edge*> inputs;
std::vector<const Edge*> control_inputs;
int offset = cost_graph->node_size();
for (const Node* n : graph->nodes()) {
CostGraphDef::Node* cnode = cost_graph->add_node();
cnode->set_name(n->name());
cnode->set_device(n->assigned_device_name());
cnode->set_id(Id(n));
cnode->set_id(GlobalId(n, offset));
inputs.clear();
inputs.resize(n->num_inputs(), nullptr);
@ -502,7 +503,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
for (const Edge* e : inputs) {
CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
input_info->set_preceding_node(Id(e->src()));
input_info->set_preceding_node(GlobalId(e->src(), offset));
input_info->set_preceding_port(e->src_output());
}
@ -528,7 +529,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
}
for (const Edge* e : control_inputs) {
cnode->add_control_input(Id(e->src()));
cnode->add_control_input(GlobalId(e->src(), offset));
}
cnode->set_temporary_memory_size(TempMemorySize(n).value());

View File

@ -66,6 +66,14 @@ class CostModel {
}
}
inline int GlobalId(const Node* n, int offset) const {
if (is_global_) {
return n->cost_id();
} else {
return n->id() + offset;
}
}
// Initializes cost model for 'g'.
void InitFromGraph(const Graph& g);