Add AllPermute interface to Dtensor to support halo exchange.

PiperOrigin-RevId: 332947906
Change-Id: Ib7c80eb201441d4a4f9399797bc25b26a5c85661
This commit is contained in:
A. Unique TensorFlower 2020-09-21 15:27:18 -07:00 committed by TensorFlower Gardener
parent 3f94cbfa73
commit d526d49e19
2 changed files with 16 additions and 1 deletions

View File

@ -71,6 +71,9 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
case GATHER_COLLECTIVE:
return "RingGather";
case PERMUTE_COLLECTIVE:
return "Permute";
default:
return "undef";
}

View File

@ -85,6 +85,8 @@ CollInstanceParams& CollInstanceParams::operator=(
other.impl_details.subdiv_source_rank.begin(),
other.impl_details.subdiv_source_rank.end());
impl_details.dependencies = other.impl_details.dependencies;
devices.assign(other.devices.begin(), other.devices.end());
permutation.assign(other.permutation.begin(), other.permutation.end());
}
return *this;
}
@ -125,8 +127,18 @@ string CollInstanceParams::ToString() const {
strings::StrAppend(&v, r, ",");
}
strings::StrAppend(&v, "}");
} // all subdivs
if (type == PERMUTE_COLLECTIVE) {
strings::StrAppend(&v, "}, permute_devices {");
for (const auto& d : devices) {
strings::StrAppend(&v, d, ",");
}
strings::StrAppend(&v, "}, permute_permutation {");
for (const auto& p : permutation) {
strings::StrAppend(&v, p, ",");
}
strings::StrAppend(&v, "}");
}
strings::StrAppend(&v, "}"); // all subdivs
return v;
}