Add AllPermute interface to Dtensor to support halo exchange.
PiperOrigin-RevId: 332947906 Change-Id: Ib7c80eb201441d4a4f9399797bc25b26a5c85661
This commit is contained in:
parent
3f94cbfa73
commit
d526d49e19
@ -71,6 +71,9 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
|
||||
case GATHER_COLLECTIVE:
|
||||
return "RingGather";
|
||||
|
||||
case PERMUTE_COLLECTIVE:
|
||||
return "Permute";
|
||||
|
||||
default:
|
||||
return "undef";
|
||||
}
|
||||
|
@ -85,6 +85,8 @@ CollInstanceParams& CollInstanceParams::operator=(
|
||||
other.impl_details.subdiv_source_rank.begin(),
|
||||
other.impl_details.subdiv_source_rank.end());
|
||||
impl_details.dependencies = other.impl_details.dependencies;
|
||||
devices.assign(other.devices.begin(), other.devices.end());
|
||||
permutation.assign(other.permutation.begin(), other.permutation.end());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -125,8 +127,18 @@ string CollInstanceParams::ToString() const {
|
||||
strings::StrAppend(&v, r, ",");
|
||||
}
|
||||
strings::StrAppend(&v, "}");
|
||||
} // all subdivs
|
||||
if (type == PERMUTE_COLLECTIVE) {
|
||||
strings::StrAppend(&v, "}, permute_devices {");
|
||||
for (const auto& d : devices) {
|
||||
strings::StrAppend(&v, d, ",");
|
||||
}
|
||||
strings::StrAppend(&v, "}, permute_permutation {");
|
||||
for (const auto& p : permutation) {
|
||||
strings::StrAppend(&v, p, ",");
|
||||
}
|
||||
strings::StrAppend(&v, "}");
|
||||
}
|
||||
strings::StrAppend(&v, "}"); // all subdivs
|
||||
return v;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user