[XLA] Allow multiple compatible devices in device_util
PiperOrigin-RevId: 257904907
This commit is contained in:
parent
4a3cf05668
commit
1670303042
@ -81,7 +81,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
|
||||
std::vector<string> names;
|
||||
device_set.ForEach([&](DeviceId device_id) {
|
||||
names.push_back(string(GetNameFor(device_id)));
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
||||
@ -118,19 +118,48 @@ xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
|
||||
bool multiple_gpu_devices = false;
|
||||
bool multiple_unknown_devices = false;
|
||||
|
||||
// Returns 'true' if d0 and d1 are conflicting devices. If they are
|
||||
// compatible, update d1 with a more specific one.
|
||||
// TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
|
||||
const auto is_multiple_devices =
|
||||
[&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
|
||||
const absl::string_view name0 = device_info_cache.GetNameFor(d0);
|
||||
const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
|
||||
|
||||
DeviceNameUtils::ParsedName parsed0, parsed1;
|
||||
if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
|
||||
!DeviceNameUtils::ParseFullName(name1, &parsed1) ||
|
||||
!DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
|
||||
*d1 = d0;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
devices.ForEach([&](jit::DeviceId device) {
|
||||
if (device_info_cache.IsGpu(device)) {
|
||||
if (maybe_gpu_device) {
|
||||
multiple_gpu_devices = true;
|
||||
return false;
|
||||
multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
|
||||
if (multiple_gpu_devices) return false;
|
||||
} else {
|
||||
maybe_gpu_device = device;
|
||||
}
|
||||
maybe_gpu_device = device;
|
||||
} else if (device_info_cache.IsCpu(device)) {
|
||||
if (maybe_cpu_device) {
|
||||
multiple_cpu_devices = true;
|
||||
return false;
|
||||
multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
|
||||
if (multiple_cpu_devices) return false;
|
||||
} else {
|
||||
maybe_cpu_device = device;
|
||||
}
|
||||
maybe_cpu_device = device;
|
||||
} else {
|
||||
if (maybe_unknown_device) {
|
||||
multiple_unknown_devices = true;
|
||||
|
@ -65,10 +65,21 @@ const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
|
||||
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
|
||||
|
||||
const char* kCPU0Partial = "/device:CPU:0";
|
||||
const char* kGPU0Partial = "/device:GPU:0";
|
||||
const char* kXPU0Partial = "/device:XPU:0";
|
||||
|
||||
TEST(PickDeviceForXla, UniqueDevice) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, MoreSpecificDevice) {
|
||||
CheckPickDeviceResult(kCPU0, false, {kCPU0, kCPU0Partial});
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0Partial});
|
||||
// Unknown devices do not support merging of full and partial specifications.
|
||||
CheckPickDeviceHasError(false, {kXPU1, kXPU0Partial});
|
||||
}
|
||||
|
||||
TEST(PickDeviceForXla, DeviceOrder) {
|
||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
|
||||
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
|
||||
|
Loading…
Reference in New Issue
Block a user