# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu
from tensorflow.python.util.tf_export import tf_export

_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
_RETRY_TIMES = 12 * 24 # 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')


@tf_export('tpu.experimental.TPUSystemMetadata')
class TPUSystemMetadata(
collections.namedtuple('TPUSystemMetadata', [
'num_cores',
'num_hosts',
'num_of_cores_per_host',
'topology',
'devices',
])):
"""Describes some metadata about the TPU system.
Attributes:
num_cores: interger. Total number of TPU cores in the TPU system.
num_hosts: interger. Total number of hosts (TPU workers) in the TPU system.
num_of_cores_per_host: interger. Number of TPU cores per host (TPU worker).
topology: an instance of `tf.tpu.experimental.Topology`, which describes the
physical topology of TPU system.
devices: a tuple of strings, which describes all the TPU devices in the
system.
"""
def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology,
devices):
return super(TPUSystemMetadata,
cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host,
topology, devices)
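

# A minimal usage sketch (the field values below are illustrative placeholders,
# not produced by this module): `TPUSystemMetadata` is an immutable namedtuple,
# so its fields can be read directly once it has been built, e.g. by
# `_query_tpu_system_metadata` below.
#
#   metadata = TPUSystemMetadata(
#       num_cores=2,
#       num_hosts=1,
#       num_of_cores_per_host=2,
#       topology=None,
#       devices=('/job:tpu_worker/replica:0/task:0/device:TPU:0',
#                '/job:tpu_worker/replica:0/task:0/device:TPU:1'))
#   assert metadata.num_cores == (
#       metadata.num_hosts * metadata.num_of_cores_per_host)

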
def _query_tpu_system_metadata(master_address, cluster_def=None,
query_topology=False):
"""Automatically detects the TPU system metadata in the system."""
tpu_core_count = 0
devices = []
device_dict = collections.defaultdict(list)
if context.executing_eagerly():
logical_devices = config.list_logical_devices()
# We want the output type to match in both eager and session mode
devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name), # pylint: disable=protected-access
d.device_type, 0, 0)
for d in logical_devices]
else:
# TODO(b/120564445): Replace with standard library for retries.
retry_count = 1
while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
try:
with ops.Graph().as_default():
with session_lib.Session(
master_address,
config=get_session_config_with_timeout(
_PINGING_MASTER_TIMEOUT_IN_MS,
cluster_def)) as sess:
devices = sess.list_devices()
break
except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker may '
               'not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))
# TODO(xiejw): For local or grpc master we might not need retry logic
# here.
if retry_count <= _RETRY_TIMES:
logging.warning('%s', msg)
logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
retry_count += 1
else:
raise ValueError(msg)
for device in devices:
spec = tf_device.DeviceSpec.from_string(device.name)
if spec.device_type == 'TPU':
device_dict[spec.task].append(spec.device_index)
tpu_core_count += 1
num_of_cores_per_host = 0
if tpu_core_count:
num_cores_per_host_set = set(
[len(core_ids) for core_ids in device_dict.values()])
if len(num_cores_per_host_set) != 1:
raise RuntimeError(
            'The number of TPU cores on each host is not the same. This '
            'should not happen. devices: {}'.format(devices))
num_of_cores_per_host = num_cores_per_host_set.pop()
topology = None
if query_topology:
if not tpu_core_count:
raise RuntimeError(
'Cannot find any TPU cores in the system (master address {}). '
'This usually means the master address is incorrect or the '
'TPU worker has some problems. Available devices: {}'.format(
master_address, devices))
topology = _obtain_topology(master_address, cluster_def)
# We sort the metadata devices so that downstream users get a sorted list
# for creating mirrored variables correctly.
def _sort_key(device):
spec = tf_device.DeviceSpec.from_string(device.name)
return (spec.job, spec.replica, spec.task, spec.device_type,
spec.device_index)
devices = tuple(sorted(devices, key=_sort_key))
metadata = TPUSystemMetadata(
num_cores=tpu_core_count,
num_hosts=len(device_dict),
num_of_cores_per_host=num_of_cores_per_host,
topology=topology,
devices=devices)
if tpu_core_count:
logging.info('Found TPU system:')
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
logging.info('*** Num TPU Cores Per Worker: %d',
metadata.num_of_cores_per_host)
for device in metadata.devices:
logging.info('*** Available Device: %s', device)
else:
logging.info('Failed to find TPU: %s', metadata)
return metadata
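

# A minimal usage sketch for `_query_tpu_system_metadata` (the gRPC master
# address below is a placeholder assumption; a reachable TPU worker is
# required for the query to succeed). In graph mode the helper retries until
# the master answers; in eager mode it reads the already-initialized logical
# devices instead.
#
#   metadata = _query_tpu_system_metadata(
#       'grpc://10.240.1.2:8470', cluster_def=None, query_topology=False)
#   logging.info('Cores per host: %d', metadata.num_of_cores_per_host)

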
def _obtain_topology(master_address, cluster_def):
"""Obtains TPU fabric topology."""
try:
logging.info('Initializing TPU system (master: %s) to fetch topology '
'for model parallelism. This might take a while.',
master_address)
with ops.Graph().as_default():
session_config = get_session_config_with_timeout(
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
with session_lib.Session(
master_address, config=session_config) as sess:
topology = sess.run(tpu.initialize_system())
return topology
except errors.DeadlineExceededError:
raise ValueError(
        'Failed to initialize the TPU system with master (%s). '
        'Please double-check that the TPU system is functional.' % (
            master_address))
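

# A minimal sketch of how the serialized topology returned above can be used
# (the master address is a placeholder assumption, and the `topology` import
# is not part of this module):
#
#   from tensorflow.python.tpu import topology as topology_lib
#   serialized = _obtain_topology('grpc://10.240.1.2:8470', cluster_def=None)
#   tpu_topology = topology_lib.Topology(serialized=serialized)
#   logging.info('Mesh shape: %s', tpu_topology.mesh_shape)

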
def get_session_config_with_timeout(timeout_in_secs, cluster_def):
  """Returns a session config proto given an operation timeout and a cluster.

  Note: `timeout_in_secs` is forwarded to `operation_timeout_in_ms`, so the
  value is interpreted in milliseconds.
  """
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
  return config_proto
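

# A minimal sketch of how the config above is consumed elsewhere in this file
# (the master address is a placeholder assumption): every operation issued on
# the session, including the initial device listing, is capped at the given
# number of milliseconds.
#
#   session_config = get_session_config_with_timeout(
#       _PINGING_MASTER_TIMEOUT_IN_MS, cluster_def=None)
#   with session_lib.Session('grpc://10.240.1.2:8470',
#                            config=session_config) as sess:
#     devices = sess.list_devices()

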
def master_job(master, cluster_def):
"""Returns the canonical job name to use to place TPU computations on.
Args:
master: A `string` representing the TensorFlow master to use.
cluster_def: A ClusterDef object describing the TPU cluster.
Returns:
A string containing the job name, or None if no job should be specified.
Raises:
ValueError: If the user needs to specify a tpu_job_name, because we are
unable to infer the job name automatically, or if the user-specified job
names are inappropriate.
"""
  # An explicitly specified tpu_job_name takes precedence and is handled by
  # the caller; here the job name is inferred from the master and cluster_def.
if master in _LOCAL_MASTERS:
return None
if (not cluster_def or not cluster_def.job):
return _DEFAULT_JOB_NAME
job_names = set(job.name for job in cluster_def.job)
if _DEFAULT_JOB_NAME in job_names:
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
raise ValueError('Currently, tpu_worker is not an allowed job name.')
if len(job_names) == 1:
return cluster_def.job[0].name
if len(job_names) == 2:
if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
return job_names.pop()
# TODO(b/67716447): Include more sophisticated heuristics.
raise ValueError('Could not infer TPU job name.')
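

# A minimal sketch of the inference rules above (the ClusterDef contents are
# illustrative assumptions; `cluster_pb2` is not imported by this module):
#
#   from tensorflow.core.protobuf import cluster_pb2
#   cluster_def = cluster_pb2.ClusterDef()
#   cluster_def.job.add(name='coordinator')
#   cluster_def.job.add(name='worker')
#   master_job('grpc://10.240.1.2:8470', cluster_def)  # -> 'worker'
#   master_job('local', cluster_def)                   # -> None
#   master_job('grpc://10.240.1.2:8470', None)         # -> 'tpu_worker'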