Fix UpdatedMirroredIndices to take in the additional per_replica inputs and not the new total per replica inputs.
PiperOrigin-RevId: 337605082 Change-Id: Id30e7fa068131d3e5f32bf93b0d8828e8c4e8d97
This commit is contained in:
parent
25fcf6ef90
commit
e89bb384c7
@ -327,8 +327,8 @@ Status RemoveIdentityNodesForArgRetval(Graph* g) {
|
||||
}
|
||||
|
||||
// Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
|
||||
// 'new_num_per_replicate_inputs' are added to the inputs of `xla_node`.
|
||||
Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs,
|
||||
// 'additional_per_replicate_inputs' are added to the inputs of `xla_node`.
|
||||
Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
|
||||
Node* xla_node) {
|
||||
std::vector<int> mirrored_variable_indices;
|
||||
if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
|
||||
@ -340,7 +340,7 @@ Status UpdateMirroredVariableIndices(int new_num_per_replica_inputs,
|
||||
|
||||
if (!mirrored_variable_indices.empty()) {
|
||||
for (int i = 0; i < mirrored_variable_indices.size(); ++i)
|
||||
mirrored_variable_indices[i] += new_num_per_replica_inputs;
|
||||
mirrored_variable_indices[i] += additional_per_replica_inputs;
|
||||
xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
|
||||
xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
|
||||
mirrored_variable_indices);
|
||||
@ -567,8 +567,8 @@ Status MoveHeadOutsideCompilationToHost(
|
||||
xla_node->ClearAttr("Tinputs");
|
||||
xla_node->AddAttr("Tinputs", new_input_types);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateMirroredVariableIndices(new_num_per_replica_inputs, xla_node));
|
||||
TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
|
||||
/*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));
|
||||
|
||||
int new_variable_start_index =
|
||||
num_new_per_replica_input_types / num_replicas + num_distributed_vars +
|
||||
|
Loading…
Reference in New Issue
Block a user