[XLA] Allow multiple compatible devices in device_util
PiperOrigin-RevId: 257904907
This commit is contained in:
parent
4a3cf05668
commit
1670303042
@ -81,7 +81,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
|
|||||||
std::vector<string> names;
|
std::vector<string> names;
|
||||||
device_set.ForEach([&](DeviceId device_id) {
|
device_set.ForEach([&](DeviceId device_id) {
|
||||||
names.push_back(string(GetNameFor(device_id)));
|
names.push_back(string(GetNameFor(device_id)));
|
||||||
return false;
|
return true;
|
||||||
});
|
});
|
||||||
|
|
||||||
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
||||||
@ -118,19 +118,48 @@ xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
|
|||||||
bool multiple_gpu_devices = false;
|
bool multiple_gpu_devices = false;
|
||||||
bool multiple_unknown_devices = false;
|
bool multiple_unknown_devices = false;
|
||||||
|
|
||||||
|
// Returns 'true' if d0 and d1 are conflicting devices. If they are
|
||||||
|
// compatible, update d1 with a more specific one.
|
||||||
|
// TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
|
||||||
|
const auto is_multiple_devices =
|
||||||
|
[&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
|
||||||
|
const absl::string_view name0 = device_info_cache.GetNameFor(d0);
|
||||||
|
const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
|
||||||
|
|
||||||
|
DeviceNameUtils::ParsedName parsed0, parsed1;
|
||||||
|
if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
|
||||||
|
!DeviceNameUtils::ParseFullName(name1, &parsed1) ||
|
||||||
|
!DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
|
||||||
|
*d1 = d0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
devices.ForEach([&](jit::DeviceId device) {
|
devices.ForEach([&](jit::DeviceId device) {
|
||||||
if (device_info_cache.IsGpu(device)) {
|
if (device_info_cache.IsGpu(device)) {
|
||||||
if (maybe_gpu_device) {
|
if (maybe_gpu_device) {
|
||||||
multiple_gpu_devices = true;
|
multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
|
||||||
return false;
|
if (multiple_gpu_devices) return false;
|
||||||
|
} else {
|
||||||
|
maybe_gpu_device = device;
|
||||||
}
|
}
|
||||||
maybe_gpu_device = device;
|
|
||||||
} else if (device_info_cache.IsCpu(device)) {
|
} else if (device_info_cache.IsCpu(device)) {
|
||||||
if (maybe_cpu_device) {
|
if (maybe_cpu_device) {
|
||||||
multiple_cpu_devices = true;
|
multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
|
||||||
return false;
|
if (multiple_cpu_devices) return false;
|
||||||
|
} else {
|
||||||
|
maybe_cpu_device = device;
|
||||||
}
|
}
|
||||||
maybe_cpu_device = device;
|
|
||||||
} else {
|
} else {
|
||||||
if (maybe_unknown_device) {
|
if (maybe_unknown_device) {
|
||||||
multiple_unknown_devices = true;
|
multiple_unknown_devices = true;
|
||||||
|
@ -65,10 +65,21 @@ const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
|
|||||||
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
|
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
|
||||||
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
|
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
|
||||||
|
|
||||||
|
const char* kCPU0Partial = "/device:CPU:0";
|
||||||
|
const char* kGPU0Partial = "/device:GPU:0";
|
||||||
|
const char* kXPU0Partial = "/device:XPU:0";
|
||||||
|
|
||||||
TEST(PickDeviceForXla, UniqueDevice) {
|
TEST(PickDeviceForXla, UniqueDevice) {
|
||||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
|
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PickDeviceForXla, MoreSpecificDevice) {
|
||||||
|
CheckPickDeviceResult(kCPU0, false, {kCPU0, kCPU0Partial});
|
||||||
|
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0Partial});
|
||||||
|
// Unknown devices do not support merging of full and partial specifications.
|
||||||
|
CheckPickDeviceHasError(false, {kXPU1, kXPU0Partial});
|
||||||
|
}
|
||||||
|
|
||||||
TEST(PickDeviceForXla, DeviceOrder) {
|
TEST(PickDeviceForXla, DeviceOrder) {
|
||||||
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
|
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
|
||||||
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
|
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user