[XLA] Avoid quadratic behavior in DevicesForSharding

PiperOrigin-RevId: 333652350
Change-Id: I8f0b73f5a584cfcb462a2083ede813331c50c90e
Author: Yuanzhong Xu (2020-09-24 20:01:03 -07:00), committed by TensorFlower Gardener
parent 7e6061037f
commit 7893e4bcc1
2 changed files with 42 additions and 18 deletions
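Background on the fix, before the diffs: the old DevicesForSharding answered every membership question with a linear scan — sharding.UsesDevice(i) once per available device, plus a std::find over available_devices inside the DCHECK — which is quadratic overall (roughly O(|available_devices| x |tile_assignment|)). The new version builds absl::flat_hash_set indices once so each lookup is expected O(1). A minimal standalone sketch of the before/after lookup pattern (illustrative names only, not the XLA code itself):

#include <algorithm>
#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Quadratic pattern: std::find walks `available` once per queried device.
std::vector<int64_t> UsedDevicesSlow(const std::vector<int64_t>& available,
                                     const std::vector<int64_t>& queried) {
  std::vector<int64_t> used;
  for (int64_t d : queried) {
    if (std::find(available.begin(), available.end(), d) != available.end()) {
      used.push_back(d);
    }
  }
  return used;
}

// Near-linear pattern: build a flat_hash_set once; each contains() lookup is
// expected O(1), so the whole pass is O(|available| + |queried|).
std::vector<int64_t> UsedDevicesFast(const std::vector<int64_t>& available,
                                     const std::vector<int64_t>& queried) {
  absl::flat_hash_set<int64_t> available_set(available.begin(),
                                             available.end());
  std::vector<int64_t> used;
  for (int64_t d : queried) {
    if (available_set.contains(d)) {
      used.push_back(d);
    }
  }
  return used;
}

The same idea drives the diff below: convert available_devices to a flat_hash_set up front, then collect used devices via set insertions instead of repeated scans.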


@@ -477,6 +477,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
     ],
 )
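The BUILD hunk above declares the Bazel dependency on Abseil's flat_hash_set; the matching #include and its uses follow in the second file.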


@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -816,29 +817,51 @@ IdentityValueAndHloOpcodeForScatterReduceComputation(
       "add/or/multiply/add/min/max");
 }
 
+namespace {
+
+void DevicesForShardingInternal(
+    const HloSharding& sharding,
+    const absl::flat_hash_set<int64>& available_devices,
+    absl::flat_hash_set<int64>* used) {
+  if (sharding.IsTuple()) {
+    for (const auto& subsharding : sharding.tuple_elements()) {
+      DevicesForShardingInternal(subsharding, available_devices, used);
+    }
+    return;
+  }
+
+  if (sharding.IsReplicated()) {
+    for (int64 device : available_devices) {
+      if (!HloSharding::IsReservedDevice(device)) {
+        used->insert(device);
+      }
+    }
+    return;
+  }
+
+  DCHECK(std::all_of(
+      sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
+      [&](int64 device) { return available_devices.contains(device); }));
+  sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
+                                      int64 device) { used->insert(device); });
+}
+
+}  // namespace
+
 std::vector<int64> DevicesForSharding(
     const HloSharding& sharding, const std::vector<int64>& available_devices) {
-  std::vector<int64> devices;
-  if (sharding.IsReplicated()) {
-    for (int64 d : available_devices) {
-      if (!HloSharding::IsReservedDevice(d)) {
-        devices.push_back(d);
-      }
-    }
-    return devices;
-  }
-  for (int64 i : available_devices) {
-    if (sharding.UsesDevice(i)) {
-      devices.push_back(i);
-    }
-  }
-  DCHECK(std::all_of(sharding.tile_assignment().begin(),
-                     sharding.tile_assignment().end(), [&](int64 device) {
-                       return std::find(available_devices.begin(),
-                                        available_devices.end(),
-                                        device) != available_devices.end();
-                     }));
+  absl::flat_hash_set<int64> available_set;
+  for (int64 device : available_devices) {
+    available_set.insert(device);
+  }
+  absl::flat_hash_set<int64> used_set;
+  DevicesForShardingInternal(sharding, available_set, &used_set);
+  std::vector<int64> devices;
+  for (int64 device : available_devices) {
+    if (used_set.contains(device)) {
+      devices.push_back(device);
+    }
+  }
   return devices;
 }
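Two design points worth noting in the rewritten function: the result is still built by iterating available_devices in the caller's order and filtering against used_set, so the output stays deterministic and deduplicated rather than inheriting flat_hash_set's unspecified iteration order; and the DCHECK moves into the tiled branch of the helper, where available_devices.contains keeps the validation linear in the tile assignment instead of quadratic.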