# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
|
|
from tensorflow.core.protobuf import config_pb2
|
|
from tensorflow.python.client import session as session_lib
|
|
from tensorflow.python.distribute import device_util
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import config
|
|
from tensorflow.python.framework import device as tf_device
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.tpu import tpu
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
_RETRY_TIMES = 12 * 24  # 1 day (288 attempts at a 5 minute timeout each)
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 min

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')


@tf_export('tpu.experimental.TPUSystemMetadata')
class TPUSystemMetadata(
    collections.namedtuple('TPUSystemMetadata', [
        'num_cores',
        'num_hosts',
        'num_of_cores_per_host',
        'topology',
        'devices',
    ])):
  """Describes some metadata about the TPU system.

  Attributes:
    num_cores: integer. Total number of TPU cores in the TPU system.
    num_hosts: integer. Total number of hosts (TPU workers) in the TPU system.
    num_of_cores_per_host: integer. Number of TPU cores per host (TPU worker).
    topology: an instance of `tf.tpu.experimental.Topology`, which describes
      the physical topology of the TPU system.
    devices: a tuple of strings, which describes all the TPU devices in the
      system.
  """

  def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology,
              devices):
    return super(TPUSystemMetadata,
                 cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host,
                              topology, devices)


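# Shape of the metadata, illustrated with made-up values for a hypothetical
# single-host, eight-core TPU system (not produced by this module):
#
#   TPUSystemMetadata(
#       num_cores=8,
#       num_hosts=1,
#       num_of_cores_per_host=8,
#       topology=None,
#       devices=('/job:tpu_worker/replica:0/task:0/device:TPU:0', ...))

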
def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()

    # We want the output type to match in both eager and session mode
    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
                                             d.device_type, 0, 0)
               for d in logical_devices]
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker '
               'may not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same; this should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata


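# A minimal usage sketch, assuming a hypothetical gRPC master address (in
# practice the address usually comes from a cluster resolver):
#
#   metadata = _query_tpu_system_metadata(
#       'grpc://10.240.1.2:8470', cluster_def=None, query_topology=False)
#   logging.info('%d cores across %d hosts', metadata.num_cores,
#                metadata.num_hosts)

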
def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize the TPU system with master (%s). '
        'Please double check that the TPU system is functional.' % (
            master_address))


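# Note: `sess.run(tpu.initialize_system())` yields a serialized topology
# proto (bytes). A caller that needs a structured view can wrap it, e.g.
# (a sketch not used in this module; `serialized_topology` is hypothetical):
#
#   from tensorflow.python.tpu import topology as topology_lib
#   topology = topology_lib.Topology(serialized=serialized_topology)

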
def get_session_config_with_timeout(timeout_in_ms, cluster_def):
  """Returns a session config proto with the given timeout (ms) and cluster."""
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_ms, cluster_def=cluster_def)
  return config_proto


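# A minimal usage sketch; the timeout is in milliseconds (it is forwarded to
# `operation_timeout_in_ms`) and `cluster_def` may be None:
#
#   session_config = get_session_config_with_timeout(60 * 1000, None)
#   assert session_config.operation_timeout_in_ms == 60 * 1000

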
def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if (not cluster_def or not cluster_def.job):
    return _DEFAULT_JOB_NAME
  job_names = set(job.name for job in cluster_def.job)
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')


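# Illustrative behavior, with hypothetical inputs (the cluster contents are
# made up for the example):
#
#   master_job('local', None)                  -> None
#   master_job('grpc://10.240.1.2:8470', None) -> 'tpu_worker'
#   # For a ClusterDef with jobs {'coordinator', 'worker'}, the
#   # non-coordinator job name ('worker') is returned.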