Fix UpdatedMirroredIndices to take in the additional per_replica inputs and not the new total per replica inputs.

PiperOrigin-RevId: 337605082
Change-Id: Id30e7fa068131d3e5f32bf93b0d8828e8c4e8d97
This commit is contained in:
Ken Franko 2020-10-16 17:14:26 -07:00 committed by TensorFlower Gardener
parent 25fcf6ef90
commit e89bb384c7

View File

@ -327,8 +327,8 @@ Status RemoveIdentityNodesForArgRetval(Graph* g) {
}
// Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
// 'new_num_per_replicate_inputs' are added to the inputs of `xla_node`.
Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs,
// 'additional_per_replicate_inputs' are added to the inputs of `xla_node`.
Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
Node* xla_node) {
std::vector<int> mirrored_variable_indices;
if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
@ -340,7 +340,7 @@ Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs,
if (!mirrored_variable_indices.empty()) {
for (int i = 0; i < mirrored_variable_indices.size(); ++i)
mirrored_variable_indices[i] += new_num_per_replica_inputs;
mirrored_variable_indices[i] += additional_per_replica_inputs;
xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
mirrored_variable_indices);
@ -567,8 +567,8 @@ Status MoveHeadOutsideCompilationToHost(
xla_node->ClearAttr("Tinputs");
xla_node->AddAttr("Tinputs", new_input_types);
TF_RETURN_IF_ERROR(
UpdateMirroredVariableIndices(new_num_per_replica_inputs, xla_node));
TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
/*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));
int new_variable_start_index =
num_new_per_replica_input_types / num_replicas + num_distributed_vars +