diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index f21820d13e2..b7c71dc5cba 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -327,8 +327,8 @@ Status RemoveIdentityNodesForArgRetval(Graph* g) { } // Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when -// 'new_num_per_replicate_inputs' are added to the inputs of `xla_node`. -Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs, +// 'additional_per_replicate_inputs' are added to the inputs of `xla_node`. +Status UpdateMirroredVariableIndices(int additional_per_replica_inputs, Node* xla_node) { std::vector mirrored_variable_indices; if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) != @@ -340,7 +340,7 @@ Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs, if (!mirrored_variable_indices.empty()) { for (int i = 0; i < mirrored_variable_indices.size(); ++i) - mirrored_variable_indices[i] += new_num_per_replica_inputs; + mirrored_variable_indices[i] += additional_per_replica_inputs; xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR); xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR, mirrored_variable_indices); @@ -567,8 +567,8 @@ Status MoveHeadOutsideCompilationToHost( xla_node->ClearAttr("Tinputs"); xla_node->AddAttr("Tinputs", new_input_types); - TF_RETURN_IF_ERROR( - UpdateMirroredVariableIndices(new_num_per_replica_inputs, xla_node)); + TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices( + /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node)); int new_variable_start_index = num_new_per_replica_input_types / num_replicas + num_distributed_vars +