diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b9aff298d35..305e228fbca 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -477,6 +477,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index da4e3d61a81..18f76c5253b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <map>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -816,29 +817,51 @@ IdentityValueAndHloOpcodeForScatterReduceComputation(
       "add/or/multiply/add/min/max");
 }
 
-std::vector<int64> DevicesForSharding(
-    const HloSharding& sharding, const std::vector<int64>& available_devices) {
-  std::vector<int64> devices;
-  if (sharding.IsReplicated()) {
-    for (int64 d : available_devices) {
-      if (!HloSharding::IsReservedDevice(d)) {
-        devices.push_back(d);
-      }
+namespace {
+
+void DevicesForShardingInternal(
+    const HloSharding& sharding,
+    const absl::flat_hash_set<int64>& available_devices,
+    absl::flat_hash_set<int64>* used) {
+  if (sharding.IsTuple()) {
+    for (const auto& subsharding : sharding.tuple_elements()) {
+      DevicesForShardingInternal(subsharding, available_devices, used);
     }
-    return devices;
+    return;
   }
 
-  for (int64 i : available_devices) {
-    if (sharding.UsesDevice(i)) {
-      devices.push_back(i);
+  if (sharding.IsReplicated()) {
+    for (int64 device : available_devices) {
+      if (!HloSharding::IsReservedDevice(device)) {
+        used->insert(device);
+      }
+    }
+    return;
+  }
+
+  DCHECK(std::all_of(
+      sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
+      [&](int64 device) { return available_devices.contains(device); }));
+  sharding.tile_assignment().Each([&](absl::Span<const int64> /*indices*/,
+                                      int64 device) { used->insert(device); });
+}
+
+}  // namespace
+
+std::vector<int64> DevicesForSharding(
+    const HloSharding& sharding, const std::vector<int64>& available_devices) {
+  absl::flat_hash_set<int64> available_set;
+  for (int64 device : available_devices) {
+    available_set.insert(device);
+  }
+  absl::flat_hash_set<int64> used_set;
+  DevicesForShardingInternal(sharding, available_set, &used_set);
+  std::vector<int64> devices;
+  for (int64 device : available_devices) {
+    if (used_set.contains(device)) {
+      devices.push_back(device);
     }
   }
-  DCHECK(std::all_of(sharding.tile_assignment().begin(),
-                     sharding.tile_assignment().end(), [&](int64 device) {
-                       return std::find(available_devices.begin(),
-                                        available_devices.end(),
-                                        device) != available_devices.end();
-                     }));
   return devices;
 }
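
Note on the change above: the old DevicesForSharding consulted tile_assignment() directly, which does not apply to tuple shardings (one sub-sharding per leaf of a tuple shape). The rewrite recurses through tuple_elements(), accumulates every device id into an absl::flat_hash_set, and then filters available_devices in its original order; as a side effect, the DCHECK's per-device membership test goes from a linear std::find over the vector to a hash-set lookup.

Below is a minimal standalone sketch of that collect-then-filter pattern, not the XLA code itself: the Sharding struct and both function names are hypothetical stand-ins for HloSharding and the functions in the diff.

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Hypothetical stand-in for HloSharding: a leaf lists the device ids in its
// tile assignment; a non-empty tuple_elements makes it a tuple sharding.
struct Sharding {
  std::vector<int64_t> devices;
  std::vector<Sharding> tuple_elements;
  bool IsTuple() const { return !tuple_elements.empty(); }
};

// Mirrors DevicesForShardingInternal: recurse through tuple elements and
// collect every used device id into one set (deduplicating ids that appear
// in more than one tuple element).
void CollectUsedDevices(const Sharding& sharding,
                        absl::flat_hash_set<int64_t>* used) {
  if (sharding.IsTuple()) {
    for (const Sharding& sub : sharding.tuple_elements) {
      CollectUsedDevices(sub, used);
    }
    return;
  }
  used->insert(sharding.devices.begin(), sharding.devices.end());
}

// Mirrors DevicesForSharding: keep only used devices, preserving the caller's
// ordering of available_devices rather than the set's iteration order.
std::vector<int64_t> DevicesFor(const Sharding& sharding,
                                const std::vector<int64_t>& available_devices) {
  absl::flat_hash_set<int64_t> used;
  CollectUsedDevices(sharding, &used);
  std::vector<int64_t> result;
  for (int64_t device : available_devices) {
    if (used.contains(device)) {
      result.push_back(device);
    }
  }
  return result;
}

int main() {
  // A tuple sharding whose two elements use devices {2, 0} and {3}.
  Sharding tuple{{}, {Sharding{{2, 0}, {}}, Sharding{{3}, {}}}};
  for (int64_t device : DevicesFor(tuple, {0, 1, 2, 3})) {
    std::cout << device << " ";  // prints: 0 2 3
  }
  std::cout << "\n";
}

Filtering available_devices, instead of emitting the set's contents directly, is what keeps the result deterministic: flat_hash_set iteration order is unspecified, so the output order must come from the input vector.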