[CostModel] Generating CostGraphDef with Global Index
This commit is contained in:
parent
4c7aeb0a0c
commit
20d3933a57
tensorflow/core/graph
@ -479,11 +479,12 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
CostGraphDef* cost_graph) const {
|
||||
std::vector<const Edge*> inputs;
|
||||
std::vector<const Edge*> control_inputs;
|
||||
int offset = cost_graph->node_size();
|
||||
for (const Node* n : graph->nodes()) {
|
||||
CostGraphDef::Node* cnode = cost_graph->add_node();
|
||||
cnode->set_name(n->name());
|
||||
cnode->set_device(n->assigned_device_name());
|
||||
cnode->set_id(Id(n));
|
||||
cnode->set_id(GlobalId(n, offset));
|
||||
|
||||
inputs.clear();
|
||||
inputs.resize(n->num_inputs(), nullptr);
|
||||
@ -502,7 +503,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
|
||||
for (const Edge* e : inputs) {
|
||||
CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
|
||||
input_info->set_preceding_node(Id(e->src()));
|
||||
input_info->set_preceding_node(GlobalId(e->src(), offset));
|
||||
input_info->set_preceding_port(e->src_output());
|
||||
}
|
||||
|
||||
@ -528,7 +529,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
}
|
||||
|
||||
for (const Edge* e : control_inputs) {
|
||||
cnode->add_control_input(Id(e->src()));
|
||||
cnode->add_control_input(GlobalId(e->src(), offset));
|
||||
}
|
||||
|
||||
cnode->set_temporary_memory_size(TempMemorySize(n).value());
|
||||
|
@ -66,6 +66,14 @@ class CostModel {
|
||||
}
|
||||
}
|
||||
|
||||
inline int GlobalId(const Node* n, int offset) const {
|
||||
if (is_global_) {
|
||||
return n->cost_id();
|
||||
} else {
|
||||
return n->id() + offset;
|
||||
}
|
||||
}
|
||||
|
||||
// Initializes cost model for 'g'.
|
||||
void InitFromGraph(const Graph& g);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user