[XLA] Avoid quadratic behavior in DevicesForSharding
PiperOrigin-RevId: 333652350 Change-Id: I8f0b73f5a584cfcb462a2083ede813331c50c90e
This commit is contained in:
parent
7e6061037f
commit
7893e4bcc1
@ -477,6 +477,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/array.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
@ -816,29 +817,51 @@ IdentityValueAndHloOpcodeForScatterReduceComputation(
|
||||
"add/or/multiply/add/min/max");
|
||||
}
|
||||
|
||||
std::vector<int64> DevicesForSharding(
|
||||
const HloSharding& sharding, const std::vector<int64>& available_devices) {
|
||||
std::vector<int64> devices;
|
||||
if (sharding.IsReplicated()) {
|
||||
for (int64 d : available_devices) {
|
||||
if (!HloSharding::IsReservedDevice(d)) {
|
||||
devices.push_back(d);
|
||||
namespace {
|
||||
|
||||
void DevicesForShardingInternal(
|
||||
const HloSharding& sharding,
|
||||
const absl::flat_hash_set<int64>& available_devices,
|
||||
absl::flat_hash_set<int64>* used) {
|
||||
if (sharding.IsTuple()) {
|
||||
for (const auto& subsharding : sharding.tuple_elements()) {
|
||||
DevicesForShardingInternal(subsharding, available_devices, used);
|
||||
}
|
||||
}
|
||||
return devices;
|
||||
return;
|
||||
}
|
||||
|
||||
for (int64 i : available_devices) {
|
||||
if (sharding.UsesDevice(i)) {
|
||||
devices.push_back(i);
|
||||
if (sharding.IsReplicated()) {
|
||||
for (int64 device : available_devices) {
|
||||
if (!HloSharding::IsReservedDevice(device)) {
|
||||
used->insert(device);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
DCHECK(std::all_of(
|
||||
sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
|
||||
[&](int64 device) { return available_devices.contains(device); }));
|
||||
sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
|
||||
int64 device) { used->insert(device); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<int64> DevicesForSharding(
|
||||
const HloSharding& sharding, const std::vector<int64>& available_devices) {
|
||||
absl::flat_hash_set<int64> available_set;
|
||||
for (int64 device : available_devices) {
|
||||
available_set.insert(device);
|
||||
}
|
||||
absl::flat_hash_set<int64> used_set;
|
||||
DevicesForShardingInternal(sharding, available_set, &used_set);
|
||||
std::vector<int64> devices;
|
||||
for (int64 device : available_devices) {
|
||||
if (used_set.contains(device)) {
|
||||
devices.push_back(device);
|
||||
}
|
||||
}
|
||||
DCHECK(std::all_of(sharding.tile_assignment().begin(),
|
||||
sharding.tile_assignment().end(), [&](int64 device) {
|
||||
return std::find(available_devices.begin(),
|
||||
available_devices.end(),
|
||||
device) != available_devices.end();
|
||||
}));
|
||||
return devices;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user