From 37ab9c3bfcbb278ae003cf32f08b3d41a78401a7 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 21 May 2020 12:18:17 -0700 Subject: [PATCH] Take device locality into account during prioritization. After this CL, if multiple devices with identical device type are viable for a placement of an op, the local device (if available) will be selected. (Prior to this change, the device whose job name comes first alphabetically would be selected.) PiperOrigin-RevId: 312716604 Change-Id: I484c00cf0d34acc23c32ab8dd1cc5c394d32f0f3 --- tensorflow/core/common_runtime/device_set.cc | 7 +++++-- tensorflow/core/common_runtime/device_set.h | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index b062529a3ff..902ca2c2ee2 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -116,12 +116,15 @@ void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) { if (a_type_name != b_type_name) { auto a_priority = DeviceFactory::DevicePriority(a_type_name); auto b_priority = DeviceFactory::DevicePriority(b_type_name); - // First sort by prioritized device type (higher is preferred) and - // then by device name (lexicographically). if (a_priority != b_priority) { return a_priority > b_priority; } } + + if (a.first->IsLocal() != b.first->IsLocal()) { + return a.first->IsLocal(); + } + return StringPiece(a.first->name()) < StringPiece(b.first->name()); }; std::sort(vector->begin(), vector->end(), device_sort); diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index 608705c32f7..f59f84c2066 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -90,8 +90,8 @@ class DeviceSet { // // After a call to this function, the argument vector will be sorted by // explicit priority (the second element in the `std::pair`), then by `DeviceTypeOrder` of the device type, and lastly - // by device name. + // int32>`), then by `DeviceTypeOrder` of the device type, then by device + // locality, and lastly by device name. static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); // Sorts a PrioritizedDeviceTypeVector according to types and explicit