[XLA] Avoid quadratic behavior in DevicesForSharding
PiperOrigin-RevId: 333652350
Change-Id: I8f0b73f5a584cfcb462a2083ede813331c50c90e
parent 7e6061037f
commit 7893e4bcc1
@@ -477,6 +477,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
     ],
 )
@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
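For context: absl::flat_hash_set is Abseil's open-addressing hash set, whose contains() lookup runs in average O(1), whereas std::find over a std::vector scans linearly. A minimal standalone sketch, assuming Abseil is available on the include path:

#include <cstdint>
#include <iostream>

#include "absl/container/flat_hash_set.h"

int main() {
  // Membership tests are average O(1); a std::vector would need a
  // linear std::find for the same query.
  absl::flat_hash_set<int64_t> devices = {0, 1, 2, 3};
  devices.insert(7);
  std::cout << devices.contains(2) << " " << devices.contains(5) << "\n";
  return 0;
}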
@@ -816,29 +817,51 @@ IdentityValueAndHloOpcodeForScatterReduceComputation(
       "add/or/multiply/add/min/max");
 }
 
-std::vector<int64> DevicesForSharding(
-    const HloSharding& sharding, const std::vector<int64>& available_devices) {
-  std::vector<int64> devices;
-  if (sharding.IsReplicated()) {
-    for (int64 d : available_devices) {
-      if (!HloSharding::IsReservedDevice(d)) {
-        devices.push_back(d);
-      }
-    }
-    return devices;
-  }
-
-  for (int64 i : available_devices) {
-    if (sharding.UsesDevice(i)) {
-      devices.push_back(i);
-    }
-  }
-  DCHECK(std::all_of(sharding.tile_assignment().begin(),
-                     sharding.tile_assignment().end(), [&](int64 device) {
-                       return std::find(available_devices.begin(),
-                                        available_devices.end(),
-                                        device) != available_devices.end();
-                     }));
-  return devices;
-}
+namespace {
+
+void DevicesForShardingInternal(
+    const HloSharding& sharding,
+    const absl::flat_hash_set<int64>& available_devices,
+    absl::flat_hash_set<int64>* used) {
+  if (sharding.IsTuple()) {
+    for (const auto& subsharding : sharding.tuple_elements()) {
+      DevicesForShardingInternal(subsharding, available_devices, used);
+    }
+    return;
+  }
+
+  if (sharding.IsReplicated()) {
+    for (int64 device : available_devices) {
+      if (!HloSharding::IsReservedDevice(device)) {
+        used->insert(device);
+      }
+    }
+    return;
+  }
+
+  DCHECK(std::all_of(
+      sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
+      [&](int64 device) { return available_devices.contains(device); }));
+  sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
+                                      int64 device) { used->insert(device); });
+}
+
+}  // namespace
+
+std::vector<int64> DevicesForSharding(
+    const HloSharding& sharding, const std::vector<int64>& available_devices) {
+  absl::flat_hash_set<int64> available_set;
+  for (int64 device : available_devices) {
+    available_set.insert(device);
+  }
+  absl::flat_hash_set<int64> used_set;
+  DevicesForShardingInternal(sharding, available_set, &used_set);
+  std::vector<int64> devices;
+  for (int64 device : available_devices) {
+    if (used_set.contains(device)) {
+      devices.push_back(device);
+    }
+  }
+  return devices;
+}
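The old DevicesForSharding called sharding.UsesDevice(i) once per available device, and each such probe scans the sharding's tile assignment, so total work was quadratic in the device count; the old DCHECK's std::find over available_devices had the same shape. The rewrite walks the tile assignment once into a hash set and then filters available_devices with average O(1) lookups, which also keeps the output in the caller's device order. A minimal sketch of the same pattern with hypothetical names (UsedDevices, assignment, and available are illustrations, not the XLA API):

#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Build the set of used devices in one pass, then answer each
// membership query in average O(1): O(n + m) total instead of the
// O(n * m) cost of a linear std::find per candidate.
std::vector<int64_t> UsedDevices(const std::vector<int64_t>& assignment,
                                 const std::vector<int64_t>& available) {
  absl::flat_hash_set<int64_t> used(assignment.begin(), assignment.end());
  std::vector<int64_t> result;
  for (int64_t device : available) {
    if (used.contains(device)) {  // Average O(1), not a linear scan.
      result.push_back(device);
    }
  }
  return result;
}

Iterating over available (rather than the unordered set) preserves the caller's ordering, matching the new DevicesForSharding above.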