Fix the TPU strategy device sorting.
PiperOrigin-RevId: 238252570
This commit is contained in:
parent
3ce012557b
commit
1b64697fd6
@ -227,15 +227,12 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
|
||||
|
||||
# TODO(jhseu): Switch to DeviceAssignment to support pods and model
|
||||
# parallelism.
|
||||
self._device_index = {
|
||||
d.name: i for i, d in enumerate(self._tpu_metadata.devices)
|
||||
if "device:TPU:" in d.name
|
||||
}
|
||||
self._tpu_devices = [d.name for d in self._tpu_metadata.devices
|
||||
if "device:TPU:" in d.name]
|
||||
|
||||
self._host_device = tpu_strategy_util.get_first_tpu_host_device(
|
||||
self._tpu_cluster_resolver)
|
||||
# We sort the devices by the indexes in tpu_metadata.devices.
|
||||
self._tpu_devices = tuple(device[0] for device in sorted(
|
||||
self._device_index.items(), key=lambda device: device[1]))
|
||||
|
||||
# Only create variables for the number of replicas we're running.
|
||||
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
|
||||
self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
|
||||
|
||||
@ -23,6 +23,7 @@ import re
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -113,6 +114,14 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
|
||||
|
||||
topology = _obtain_topology(master_address, cluster_def)
|
||||
|
||||
# We sort the metadata devices so that downstream users get a sorted list
|
||||
# for creating mirrored variables correctly.
|
||||
# Sort key for one device entry: parses `device.name` into a DeviceSpec and
# returns the tuple (job, replica, task, device_type, device_index), so
# devices are ordered field by field in that priority.
def _sort_key(device):
|
||||
spec = tf_device.DeviceSpec.from_string(device.name)
|
||||
return (spec.job, spec.replica, spec.task, spec.device_type,
|
||||
spec.device_index)
|
||||
devices = tuple(sorted(devices, key=_sort_key))
|
||||
|
||||
metadata = _TPUSystemMetadata(
|
||||
num_cores=tpu_core_count,
|
||||
num_hosts=len(device_dict),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user