diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index b062529a3ff..902ca2c2ee2 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -116,12 +116,15 @@ void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) { if (a_type_name != b_type_name) { auto a_priority = DeviceFactory::DevicePriority(a_type_name); auto b_priority = DeviceFactory::DevicePriority(b_type_name); - // First sort by prioritized device type (higher is preferred) and - // then by device name (lexicographically). if (a_priority != b_priority) { return a_priority > b_priority; } } + + if (a.first->IsLocal() != b.first->IsLocal()) { + return a.first->IsLocal(); + } + return StringPiece(a.first->name()) < StringPiece(b.first->name()); }; std::sort(vector->begin(), vector->end(), device_sort); diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index 608705c32f7..f59f84c2066 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -90,8 +90,8 @@ class DeviceSet { // // After a call to this function, the argument vector will be sorted by // explicit priority (the second element in the `std::pair`), then by `DeviceTypeOrder` of the device type, and lastly - // by device name. + // int32>`), then by `DeviceTypeOrder` of the device type, then by device + // locality, and lastly by device name. static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); // Sorts a PrioritizedDeviceTypeVector according to types and explicit