Only add sub_index to _Arg nodes.

PiperOrigin-RevId: 311604877
Change-Id: Ib7c941b38e6ea38378bd4d9d44dc1d262ee6dd4a
This commit is contained in:
Yujing Zhang 2020-05-14 14:22:42 -07:00 committed by TensorFlower Gardener
parent ec7ea83d9d
commit 501309eef9
1 changed files with 6 additions and 2 deletions

View File

@ -42,7 +42,9 @@ class ReplicateHelper {
Node* replicated_node = graph->AddNode(node_def, &status); Node* replicated_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
replicated_node->set_assigned_device_name(device); replicated_node->set_assigned_device_name(device);
if (replicated_node->IsArg()) {
replicated_node->AddAttr("sub_index", i); replicated_node->AddAttr("sub_index", i);
}
replicated_nodes[i] = replicated_node; replicated_nodes[i] = replicated_node;
} }
replicated_nodes_map_.emplace(node, std::move(replicated_nodes)); replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
@ -214,8 +216,10 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
// Reuse the original nodes if there is only one allowed device. // Reuse the original nodes if there is only one allowed device.
for (Node* n : cluster_nodes) { for (Node* n : cluster_nodes) {
n->set_assigned_device_name(allowed_devices.at(0)); n->set_assigned_device_name(allowed_devices.at(0));
if (n->IsArg()) {
n->AddAttr("sub_index", 0); n->AddAttr("sub_index", 0);
} }
}
continue; continue;
} }
ReplicateHelper helper; ReplicateHelper helper;