Fix indexing of _TPUReplicate node distributed and variable arguments when determining argument sharding for the TPU computation.

An offset of (num_replicas - 1) * num_per_replica_args is added when indexing non-per-replica arguments. Per-replica argument indices need no offset: per-replica arguments are ordered by replica first (all of replica 0's arguments, then replica 1's, and so on), so the first replica's inputs align with the encapsulated function's argument indices and are sufficient for determining sharding across all replicas.
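
For concreteness, a minimal standalone sketch of the index arithmetic; the replica and argument counts below are hypothetical, chosen only for illustration:

    #include <cstdio>

    int main() {
      // Hypothetical configuration: 3 replicas, 2 per-replica args, then one
      // distributed arg and one variable arg in the encapsulated function's
      // signature [p0, p1, d0, v0].
      const int num_replicas = 3;
      const int num_per_replica_args = 2;
      const int num_args = 4;

      // _TPUReplicate inputs are laid out replica-major:
      //   [r0p0, r0p1, r1p0, r1p1, r2p0, r2p1, d0, v0]
      // so non-per-replica args must skip the other replicas' inputs.
      const int index_offset = (num_replicas - 1) * num_per_replica_args;

      for (int i = 0; i < num_args; ++i) {
        const bool is_per_replica_arg = i < num_per_replica_args;
        const int input_index = i + (is_per_replica_arg ? 0 : index_offset);
        std::printf("function arg %d -> _TPUReplicate input %d\n", i,
                    input_index);
      }
      return 0;
    }

This maps args 2 and 3 to inputs 6 and 7 (d0 and v0). The unpatched code looked up inputs 2 and 3 instead, i.e. replica 1's per-replica inputs, which is the misindexing this change fixes.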

PiperOrigin-RevId: 356898829
Change-Id: Ie17e89c1c2cfe1468bc928712d562f0a686c9a95
Author:    Andy Ly
Date:      2021-02-10 21:52:59 -08:00
Committer: TensorFlower Gardener
Parent:    69f5aecd05
Commit:    24720e5940


@@ -1974,6 +1974,16 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
   const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
                          replicate_inputs_outputs_by_default_for_xla_spmd_) &&
                         allow_parameter_replication_for_spmd;
+
+  // Offset _TPUReplicate non per replica argument indices by
+  // (num_replicas - 1) * num_per_replica_args as _TPUReplicate nodes are
+  // constructed with all per replica args across all replicas while the
+  // encapsulated function only has 1 replica's per replica args. Per replica
+  // args are ordered by replica first, so the index here does not require an
+  // offset and the first replica's input nodes are sufficient for determining
+  // argument sharding.
+  const int index_offset =
+      (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
   for (int i = 0; i < args.size(); ++i) {
     const Node* n = args[i];
     absl::optional<int64> assigned_core;
@@ -1983,9 +1993,11 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
         *n, num_cores_per_replica, flr, &cached_function_handles,
         &node_and_sharding, &is_fast_mem));
-    if (params_info.IsPerReplicaArg(i) || params_info.IsDistributedArg(i)) {
+    const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
+    if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
       Node* input_node;
-      TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
+      TF_RETURN_IF_ERROR(replicate_node->input_node(
+          i + (is_per_replica_arg ? 0 : index_offset), &input_node));
       if (input_node->type_string() == kTPUPartitionedInput) {
         TF_ASSIGN_OR_RETURN(
             absl::optional<xla::OpSharding> parsed_sharding,
@@ -2002,7 +2014,8 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     if (params_info.IsVariableArg(i)) {
       Node* input_node;
-      TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
+      TF_RETURN_IF_ERROR(
+          replicate_node->input_node(i + index_offset, &input_node));
       if (input_node->type_string() == kVarHandleOp) {
         TF_ASSIGN_OR_RETURN(
             absl::optional<xla::OpSharding> parsed_sharding,