Fix the TPU strategy device sorting.

PiperOrigin-RevId: 238252570
This commit is contained in:
Sourabh Bajaj 2019-03-13 10:26:56 -07:00 committed by TensorFlower Gardener
parent 3ce012557b
commit 1b64697fd6
2 changed files with 13 additions and 7 deletions

View File

@ -227,15 +227,12 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
# TODO(jhseu): Switch to DeviceAssignment to support pods and model
# parallelism.
self._device_index = {
d.name: i for i, d in enumerate(self._tpu_metadata.devices)
if "device:TPU:" in d.name
}
self._tpu_devices = [d.name for d in self._tpu_metadata.devices
if "device:TPU:" in d.name]
self._host_device = tpu_strategy_util.get_first_tpu_host_device(
self._tpu_cluster_resolver)
# We sort the devices by the indexes in tpu_metadata.devices.
self._tpu_devices = tuple(device[0] for device in sorted(
self._device_index.items(), key=lambda device: device[1]))
# Only create variables for the number of replicas we're running.
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

View File

@ -23,6 +23,7 @@ import re
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
@ -113,6 +114,14 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
topology = _obtain_topology(master_address, cluster_def)
# We sort the metadata devices so that downstream users get a sorted list
# for creating mirrored variables correctly.
def _sort_key(device):
spec = tf_device.DeviceSpec.from_string(device.name)
return (spec.job, spec.replica, spec.task, spec.device_type,
spec.device_index)
devices = tuple(sorted(devices, key=_sort_key))
metadata = _TPUSystemMetadata(
num_cores=tpu_core_count,
num_hosts=len(device_dict),