diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc index cfbcde82ce2..fbae80aef55 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc @@ -42,7 +42,9 @@ class ReplicateHelper { Node* replicated_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); replicated_node->set_assigned_device_name(device); - replicated_node->AddAttr("sub_index", i); + if (replicated_node->IsArg()) { + replicated_node->AddAttr("sub_index", i); + } replicated_nodes[i] = replicated_node; } replicated_nodes_map_.emplace(node, std::move(replicated_nodes)); @@ -214,7 +216,9 @@ Status ReplicatePerReplicaNodesInFunctionGraph( // Reuse the original nodes if there is only one allowed device. for (Node* n : cluster_nodes) { n->set_assigned_device_name(allowed_devices.at(0)); - n->AddAttr("sub_index", 0); + if (n->IsArg()) { + n->AddAttr("sub_index", 0); + } } continue; }