[XLA] Allow multiple compatible devices in device_util

PiperOrigin-RevId: 257904907
This commit is contained in:
Eugene Zhulenev 2019-07-12 18:29:28 -07:00 committed by TensorFlower Gardener
parent 4a3cf05668
commit 1670303042
2 changed files with 47 additions and 7 deletions

View File

@ -81,7 +81,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
std::vector<string> names;
device_set.ForEach([&](DeviceId device_id) {
names.push_back(string(GetNameFor(device_id)));
return false;
return true;
});
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
@ -118,19 +118,48 @@ xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
bool multiple_gpu_devices = false;
bool multiple_unknown_devices = false;
// Returns 'true' if d0 and d1 are conflicting devices. If they are
// compatible, update d1 with a more specific one.
// TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
const auto is_multiple_devices =
[&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
const absl::string_view name0 = device_info_cache.GetNameFor(d0);
const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
DeviceNameUtils::ParsedName parsed0, parsed1;
if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
!DeviceNameUtils::ParseFullName(name1, &parsed1) ||
!DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
return true;
}
if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
return false;
}
if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
*d1 = d0;
return false;
}
return true;
};
devices.ForEach([&](jit::DeviceId device) {
if (device_info_cache.IsGpu(device)) {
if (maybe_gpu_device) {
multiple_gpu_devices = true;
return false;
multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
if (multiple_gpu_devices) return false;
} else {
maybe_gpu_device = device;
}
maybe_gpu_device = device;
} else if (device_info_cache.IsCpu(device)) {
if (maybe_cpu_device) {
multiple_cpu_devices = true;
return false;
multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
if (multiple_cpu_devices) return false;
} else {
maybe_cpu_device = device;
}
maybe_cpu_device = device;
} else {
if (maybe_unknown_device) {
multiple_unknown_devices = true;

View File

@ -65,10 +65,21 @@ const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
const char* kCPU0Partial = "/device:CPU:0";
const char* kGPU0Partial = "/device:GPU:0";
const char* kXPU0Partial = "/device:XPU:0";
TEST(PickDeviceForXla, UniqueDevice) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
}
TEST(PickDeviceForXla, MoreSpecificDevice) {
CheckPickDeviceResult(kCPU0, false, {kCPU0, kCPU0Partial});
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0Partial});
// Unknown devices do not support merging of full and partial specifications.
CheckPickDeviceHasError(false, {kXPU1, kXPU0Partial});
}
TEST(PickDeviceForXla, DeviceOrder) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});