Fix the TPU strategy device sorting.
PiperOrigin-RevId: 238252570
parent 3ce012557b
commit 1b64697fd6
@@ -227,15 +227,12 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     # TODO(jhseu): Switch to DeviceAssignment to support pods and model
     # parallelism.
-    self._device_index = {
-        d.name: i for i, d in enumerate(self._tpu_metadata.devices)
-        if "device:TPU:" in d.name
-    }
+    self._tpu_devices = [d.name for d in self._tpu_metadata.devices
+                         if "device:TPU:" in d.name]
     self._host_device = tpu_strategy_util.get_first_tpu_host_device(
         self._tpu_cluster_resolver)
-    # We sort the devices by the indexes in tpu_metadata.devices.
-    self._tpu_devices = tuple(device[0] for device in sorted(
-        self._device_index.items(), key=lambda device: device[1]))
     # Only create variables for the number of replicas we're running.
     self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
     self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
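The strategy no longer builds a name-to-index dict and re-sorts it; it keeps the TPU entries of tpu_metadata.devices in the order the metadata provides, relying on that list arriving pre-sorted (see the tpu_system_metadata change below). A minimal sketch of the new selection logic, using hypothetical device names:

# Minimal sketch, assuming the metadata device names below (hypothetical) are
# already sorted by the metadata query, as the second hunk guarantees.
metadata_device_names = [
    "/job:worker/replica:0/task:0/device:CPU:0",
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
    "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0",
]

# Same filter as the new list comprehension: keep only the TPU compute cores.
tpu_devices = [name for name in metadata_device_names if "device:TPU:" in name]

assert tpu_devices == [
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
]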
@@ -23,6 +23,7 @@ import re
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
@@ -113,6 +114,14 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
 
     topology = _obtain_topology(master_address, cluster_def)
 
+  # We sort the metadata devices so that downstream users get a sorted list
+  # for creating mirrored variables correctly.
+  def _sort_key(device):
+    spec = tf_device.DeviceSpec.from_string(device.name)
+    return (spec.job, spec.replica, spec.task, spec.device_type,
+            spec.device_index)
+  devices = tuple(sorted(devices, key=_sort_key))
+
   metadata = _TPUSystemMetadata(
       num_cores=tpu_core_count,
       num_hosts=len(device_dict),
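The sort key parses each device name into a DeviceSpec and compares (job, replica, task, device_type, device_index), with device_index as an integer. Comparing the raw name strings instead would put "TPU:10" before "TPU:2". A minimal standalone sketch of that ordering difference, with a hypothetical helper standing in for DeviceSpec.from_string:

# Minimal sketch (hypothetical names and helper): string order vs. numeric
# order of the trailing device index.
names = [
    "/job:worker/replica:0/task:0/device:TPU:10",
    "/job:worker/replica:0/task:0/device:TPU:2",
]

def _index_as_int(name):
  # Stand-in for tf_device.DeviceSpec.from_string(name).device_index.
  return int(name.rsplit(":", 1)[1])

assert sorted(names) == names                               # lexicographic: "TPU:10" < "TPU:2"
assert sorted(names, key=_index_as_int)[0].endswith("TPU:2")  # numeric order puts TPU:2 first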