Only add sub_index to _Arg nodes.
PiperOrigin-RevId: 311604877 Change-Id: Ib7c941b38e6ea38378bd4d9d44dc1d262ee6dd4a
This commit is contained in:
parent
ec7ea83d9d
commit
501309eef9
|
@ -42,7 +42,9 @@ class ReplicateHelper {
|
||||||
Node* replicated_node = graph->AddNode(node_def, &status);
|
Node* replicated_node = graph->AddNode(node_def, &status);
|
||||||
TF_RETURN_IF_ERROR(status);
|
TF_RETURN_IF_ERROR(status);
|
||||||
replicated_node->set_assigned_device_name(device);
|
replicated_node->set_assigned_device_name(device);
|
||||||
|
if (replicated_node->IsArg()) {
|
||||||
replicated_node->AddAttr("sub_index", i);
|
replicated_node->AddAttr("sub_index", i);
|
||||||
|
}
|
||||||
replicated_nodes[i] = replicated_node;
|
replicated_nodes[i] = replicated_node;
|
||||||
}
|
}
|
||||||
replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
|
replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
|
||||||
|
@ -214,8 +216,10 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||||
// Reuse the original nodes if there is only one allowed device.
|
// Reuse the original nodes if there is only one allowed device.
|
||||||
for (Node* n : cluster_nodes) {
|
for (Node* n : cluster_nodes) {
|
||||||
n->set_assigned_device_name(allowed_devices.at(0));
|
n->set_assigned_device_name(allowed_devices.at(0));
|
||||||
|
if (n->IsArg()) {
|
||||||
n->AddAttr("sub_index", 0);
|
n->AddAttr("sub_index", 0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ReplicateHelper helper;
|
ReplicateHelper helper;
|
||||||
|
|
Loading…
Reference in New Issue