diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc index 200e795a2e8..375d30c4cf3 100644 --- a/tensorflow/compiler/jit/device_util.cc +++ b/tensorflow/compiler/jit/device_util.cc @@ -81,7 +81,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { std::vector names; device_set.ForEach([&](DeviceId device_id) { names.push_back(string(GetNameFor(device_id))); - return false; + return true; }); return absl::StrCat("[", absl::StrJoin(names, ","), "]"); @@ -118,19 +118,48 @@ xla::StatusOr> PickDeviceForXlaImpl( bool multiple_gpu_devices = false; bool multiple_unknown_devices = false; + // Returns 'true' if d0 and d1 are conflicting devices. If they are + // compatible, update d1 with a more specific one. + // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache. + const auto is_multiple_devices = + [&](const jit::DeviceId& d0, absl::optional* d1) -> bool { + const absl::string_view name0 = device_info_cache.GetNameFor(d0); + const absl::string_view name1 = device_info_cache.GetNameFor(d1->value()); + + DeviceNameUtils::ParsedName parsed0, parsed1; + if (!DeviceNameUtils::ParseFullName(name0, &parsed0) || + !DeviceNameUtils::ParseFullName(name1, &parsed1) || + !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) { + return true; + } + + if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) { + return false; + } + + if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) { + *d1 = d0; + return false; + } + + return true; + }; + devices.ForEach([&](jit::DeviceId device) { if (device_info_cache.IsGpu(device)) { if (maybe_gpu_device) { - multiple_gpu_devices = true; - return false; + multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device); + if (multiple_gpu_devices) return false; + } else { + maybe_gpu_device = device; } - maybe_gpu_device = device; } else if (device_info_cache.IsCpu(device)) { if (maybe_cpu_device) { - multiple_cpu_devices = true; - return false; + multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device); + if (multiple_cpu_devices) return false; + } else { + maybe_cpu_device = device; } - maybe_cpu_device = device; } else { if (maybe_unknown_device) { multiple_unknown_devices = true; diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc index 9396c49d52e..124cb09cfb7 100644 --- a/tensorflow/compiler/jit/device_util_test.cc +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -65,10 +65,21 @@ const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1"; const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1"; const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1"; +const char* kCPU0Partial = "/device:CPU:0"; +const char* kGPU0Partial = "/device:GPU:0"; +const char* kXPU0Partial = "/device:XPU:0"; + TEST(PickDeviceForXla, UniqueDevice) { CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0}); } +TEST(PickDeviceForXla, MoreSpecificDevice) { + CheckPickDeviceResult(kCPU0, false, {kCPU0, kCPU0Partial}); + CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0Partial}); + // Unknown devices do not support merging of full and partial specifications. + CheckPickDeviceHasError(false, {kXPU1, kXPU0Partial}); +} + TEST(PickDeviceForXla, DeviceOrder) { CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0}); CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});