Only add sub_index to _Arg nodes.
PiperOrigin-RevId: 311604877 Change-Id: Ib7c941b38e6ea38378bd4d9d44dc1d262ee6dd4a
This commit is contained in:
parent
ec7ea83d9d
commit
501309eef9
|
@ -42,7 +42,9 @@ class ReplicateHelper {
|
|||
Node* replicated_node = graph->AddNode(node_def, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
replicated_node->set_assigned_device_name(device);
|
||||
if (replicated_node->IsArg()) {
|
||||
replicated_node->AddAttr("sub_index", i);
|
||||
}
|
||||
replicated_nodes[i] = replicated_node;
|
||||
}
|
||||
replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
|
||||
|
@ -214,8 +216,10 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
|||
// Reuse the original nodes if there is only one allowed device.
|
||||
for (Node* n : cluster_nodes) {
|
||||
n->set_assigned_device_name(allowed_devices.at(0));
|
||||
if (n->IsArg()) {
|
||||
n->AddAttr("sub_index", 0);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
ReplicateHelper helper;
|
||||
|
|
Loading…
Reference in New Issue