Move TPUClusterResolver into tpu subdirectory.

PiperOrigin-RevId: 311410592
Change-Id: I7c4ca01621ae27cd4c36ff996cf90237328d75e4
Bruce Fontaine 2020-05-13 14:56:38 -07:00 committed by TensorFlower Gardener
parent 2560d6fd31
commit 2046f7c450
8 changed files with 403 additions and 364 deletions
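The implementation now lives under tensorflow/python/distribute/cluster_resolver/tpu/, while the original module is reduced to a shim that re-exports the same symbols, so the public tf.distribute.cluster_resolver.TPUClusterResolver API is unchanged. A minimal sketch of what this implies for imports (an illustration, assuming a TensorFlow build that includes this change; not part of the diff itself):

```python
# Both module paths should resolve to the same class after this change:
# the new canonical location, and the old location via the compatibility shim.
from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver as new_mod
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as old_mod

assert old_mod.TPUClusterResolver is new_mod.TPUClusterResolver
```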


@@ -6,8 +6,6 @@ tensorflow/compat_template_v1.__init__.py
tensorflow/compiler/mlir/glob_lit_test.bzl
tensorflow/lite/micro/build_def.bzl
tensorflow/python/autograph/core/config.py
tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py
tensorflow/python/eager/benchmarks_test_base.py
tensorflow/python/tpu/profiler/pip_package/BUILD
tensorflow/python/tpu/profiler/pip_package/README


@@ -1,10 +1,6 @@
# Description: Operations defined for Cluster Resolvers
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_rpc_deps",
)
package(
default_visibility = [
@@ -64,12 +60,7 @@ py_library(
name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/tpu/client",
] + tf_additional_rpc_deps(),
deps = ["//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py"],
)
py_library(
@@ -137,25 +128,6 @@ tf_py_test(
],
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",
srcs = ["tpu_cluster_resolver_test.py"],
grpc_enabled = True,
main = "tpu_cluster_resolver_test.py",
python_version = "PY3",
deps = [
":tpu_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client",
"@absl_py//absl/testing:flagsaver",
],
)
tf_py_test(
name = "slurm_cluster_resolver_py_test",
size = "small",


@@ -0,0 +1,44 @@
# Description: OSS only cluster resolvers
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_rpc_deps",
)
package(
default_visibility = [
"//tensorflow:internal",
],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:training_server_lib",
"//tensorflow/python/distribute/cluster_resolver:base_cluster_resolver_py",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/tpu/client",
] + tf_additional_rpc_deps(),
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",
srcs = ["tpu_cluster_resolver_test.py"],
grpc_enabled = True,
main = "tpu_cluster_resolver_test.py",
python_version = "PY3",
deps = [
":tpu_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client",
],
)


@@ -0,0 +1,349 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.debug(
'Falling back to TensorFlow client; we recommend you install the Cloud '
'TPU client directly with pip install cloud-tpu-client.')
from tensorflow.python.tpu.client import client # pylint: disable=g-import-not-at-top
def is_running_in_gce():
return True
_TPU_DEVICE_REGEX = re.compile(
r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
'DeviceDetails', ['device_map', 'total_cores'])
class TPUClusterResolver(cluster_resolver.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify an API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
TPUClusterResolver supports the following distinct environments:
Google Compute Engine
Google Kubernetes Engine
Google internal
"""
@staticmethod
def _get_device_dict_and_cores(devices):
"""Returns a dict of hosts to cores and total cores given devices names.
Returns a namedtuple with two attributes:
device_map: A map of host_ids to a list of core_ids.
total_cores: The total number of cores within the TPU system.
Args:
devices: A list of devices returned by session.list_devices()
"""
device_map = collections.defaultdict(list)
num_cores = 0
for device in devices:
match = _TPU_DEVICE_REGEX.match(device.name)
if match:
host_id = match.group('host_id')
core_id = match.group('core_id')
device_map[host_id].append(core_id)
num_cores += 1
return DeviceDetails(device_map, num_cores)
@staticmethod
def _verify_and_return_same_core_count(device_dict):
"""Verifies that every device in device_dict has the same # of cores."""
num_cores_per_host_set = (
{len(core_ids) for core_ids in device_dict.values()})
if len(num_cores_per_host_set) != 1:
raise RuntimeError('The number of TPU cores on each device is not the same. '
'This should never happen. Devices: {}'.format(device_dict))
return num_cores_per_host_set.pop()
def __init__(self,
tpu=None,
zone=None,
project=None,
job_name='worker',
coordinator_name=None,
coordinator_address=None,
credentials='default',
service=None,
discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
for the IP addresses and ports of each Cloud TPU listed.
Args:
tpu: A string corresponding to the TPU to use. If the string is an empty
string, the string 'local', or a string that begins with 'grpc://', then
it is assumed to not correspond with a Cloud TPU and will instead be
passed as the session master and no ClusterSpec propagation will be
done. In the future, this may also support a list of strings when
multiple Cloud TPUs are used.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
project: Name of the GCP project containing Cloud TPUs. If omitted or
empty, we will try to discover the project name of the GCE VM from the
GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
coordinator_name: The name to use for the coordinator. Set to None if the
coordinator should not be included in the computed ClusterSpec.
coordinator_address: The address of the coordinator (typically an ip:port
pair). If set to None, a TF server will be started. If coordinator_name
is None, a TF server will not be started even if coordinator_address is
None.
credentials: GCE Credentials. If None, then we use default credentials
from the oauth2client
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
discovery_url: A URL template that points to the location of the discovery
service. It should have two parameters {api} and {apiVersion} that when
filled in produce an absolute URL to the discovery document for that
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
this.
Raises:
ImportError: If the googleapiclient is not installed.
ValueError: If no TPUs are specified.
RuntimeError: If an empty TPU name is specified and this is running in a
Google Cloud environment.
"""
self._cloud_tpu_client = client.Client(
tpu=tpu,
zone=zone,
project=project,
credentials=credentials,
service=service,
discovery_url=discovery_url)
self._tpu = self._cloud_tpu_client.name()
# By default the task_type is 'worker' and the task_id is 0 (which is the
# first worker in the task).
self.task_type = job_name
self.task_id = 0
self._coordinator_name = coordinator_name
if (coordinator_name and not coordinator_address):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
def __enter__(self):
self._cloud_tpu_client.enter()
def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin
self._cloud_tpu_client.exit(type, value, traceback)
def master(self, task_type=None, task_id=None, rpc_layer=None):
"""Get the Master string to be used for the session.
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of the
first instance in the ClusterSpec returned by the cluster_spec function.
If a non-TPU name is used when constructing a TPUClusterResolver, that will
be returned instead (e.g. if the tpu argument's value when constructing
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
'grpc://10.240.1.2:8470' will be returned).
Args:
task_type: (Optional, string) The type of the TensorFlow task of the
master.
task_id: (Optional, integer) The index of the TensorFlow task of the
master.
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
communicate with TPUs.
Returns:
string, the connection string to use when creating a session.
Raises:
ValueError: If none of the TPUs specified exists.
"""
cluster_spec = self.cluster_spec()
if task_type is not None and task_id is not None:
# task_type and task_id are from the function parameters
master = cluster_spec.task_address(task_type, task_id)
elif self.task_type is not None and self.task_id is not None:
# task_type and task_id are from the object
master = cluster_spec.task_address(self.task_type, self.task_id)
else:
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
return cluster_resolver.format_master_url(master, 'grpc')
def get_master(self):
return self.master()
def get_job_name(self):
return self.task_type
def get_tpu_system_metadata(self):
"""Returns the metadata of the TPU system.
Users can call this method to get some facts about the TPU system, like
the total number of cores, the number of TPU workers and the devices. E.g.
```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tpu_system_metadata = resolver.get_tpu_system_metadata()
num_hosts = tpu_system_metadata.num_hosts
```
Returns:
A `tf.tpu.experimental.TPUSystemMetadata` object.
"""
cluster_spec = self.cluster_spec()
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access
self.master(),
cluster_def=cluster_def,
query_topology=False))
return tpu_system_metadata
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
We retrieve the information from the GCE APIs every time this method is
called.
Returns:
A ClusterSpec containing host information returned from Cloud TPUs,
or None.
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
############################################################################
# There are two potential cases this code must handle:
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
# a. Create a ClusterSpec that includes the coordinator job
# b. Create a ClusterSpec without the coordinator job.
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
# tasks and
# a. Create a ClusterSpec with the coordinator
# b. Create a ClusterSpec without the coordinator
############################################################################
network_endpoints = self._cloud_tpu_client.network_endpoints()
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in network_endpoints
]
cluster_spec = {self.task_type: worker_list}
if self._coordinator_address:
# {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
def num_accelerators(self,
task_type=None,
task_id=None,
config_proto=None):
"""Returns the number of TPU cores per worker.
Connects to the master, lists all the devices present on the master,
and counts them. Also verifies that the device count per host in the
cluster is the same before returning the number of TPU cores per host.
Args:
task_type: Unused.
task_id: Unused.
config_proto: Used to create a connection to a TPU master in order to
retrieve the system metadata.
Raises:
RuntimeError: If we cannot talk to a TPU worker after retrying or if the
number of TPU devices per host is different.
"""
retry_count = 1
# TODO(b/120564445): Replace with standard library for retries.
while True:
try:
device_details = TPUClusterResolver._get_device_dict_and_cores(
cluster_resolver.get_accelerator_devices(
self.master(), config_proto=config_proto))
break
except errors.DeadlineExceededError:
error_message = ('Failed to connect to master. The TPU might not be '
'ready (e.g. still scheduling) or the master '
'address is incorrect: got (%s)' % self.master())
if retry_count <= _TPU_CONN_RETRIES:
logging.warning(error_message)
logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
retry_count += 1
else:
raise RuntimeError(error_message)
if device_details.total_cores:
return {'TPU': TPUClusterResolver._verify_and_return_same_core_count(
device_details.device_map)}
return {'TPU': 0}
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
return self._environment
def _start_local_server(self):
address = compat.as_text(self._cloud_tpu_client.get_local_ip())
self._server = server_lib.Server({'local': ['0.0.0.0:0']},
protocol='grpc',
config=None,
start=True)
# self._server.target is of the form: grpc://ipaddress:port
target = compat.as_bytes(self._server.target)
splits = target.split(compat.as_bytes(':'))
assert len(splits) == 3, self._server.target
assert splits[0] == compat.as_bytes('grpc'), self._server.target
self._coordinator_port = compat.as_text(splits[2])
self._coordinator_address = '%s:%s' % (
address, compat.as_text(self._coordinator_port))
def __deepcopy__(self, memo):
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
return self
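
For orientation, a hedged usage sketch of the relocated class, based only on the constructor and methods shown above; the TPU name, zone, and project are placeholders, and a reachable Cloud TPU plus default credentials are assumed:

```python
# Illustrative only; 'my-tpu', the zone, and the project are placeholders.
resolver = TPUClusterResolver(tpu='my-tpu', zone='us-central1-b', project='my-project')

cluster_spec = resolver.cluster_spec()  # ClusterSpec built from the TPU's network endpoints
master = resolver.master()              # e.g. 'grpc://10.240.1.2:8470'
cores = resolver.num_accelerators()     # e.g. {'TPU': 8}, cores per worker
```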


@@ -25,7 +25,7 @@ from six.moves.urllib.error import URLError
from tensorflow.python import framework
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver as resolver
from tensorflow.python.eager.context import LogicalDevice
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
@@ -41,7 +41,7 @@ except ImportError:
logging.debug(
'Falling back to TensorFlow client; we recommended you install the Cloud '
'TPU client directly with pip install cloud-tpu-client.')
from tensorflow.python.tpu.client import client
from tensorflow.python.tpu.client import client # pylint: disable=g-import-not-at-top
class MockRequestClass(object):


@@ -12,339 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""
"""Shim so that direct imports of tpu_cluster_resolver get correct symbols.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import is_running_in_gce # pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.util.tf_export import tf_export
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.debug(
'Falling back to TensorFlow client; we recommended you install the Cloud '
'TPU client directly with pip install cloud-tpu-client.')
from tensorflow.python.tpu.client import client
def is_running_in_gce():
return True
_TPU_DEVICE_REGEX = re.compile(
r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
'DeviceDetails', ['device_map', 'total_cores'])
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(cluster_resolver.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify a API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
TPUClusterResolver supports the following distinct environments:
Google Compute Engine
Google Kubernetes Engine
Google internal
"""
@staticmethod
def _get_device_dict_and_cores(devices):
"""Returns a dict of hosts to cores and total cores given devices names.
Returns a namedtuple with two attributes:
device_map: A map of host_ids to a list of core_ids.
total_cores: The total number of cores within the TPU system.
Args:
devices: A list of devices returned by session.list_devices()
"""
device_map = collections.defaultdict(list)
num_cores = 0
for device in devices:
match = _TPU_DEVICE_REGEX.match(device.name)
if match:
host_id = match.group('host_id')
core_id = match.group('core_id')
device_map[host_id].append(core_id)
num_cores += 1
return DeviceDetails(device_map, num_cores)
@staticmethod
def _verify_and_return_same_core_count(device_dict):
"""Verifies that every device in device_dict has the same # of cores."""
num_cores_per_host_set = (
{len(core_ids) for core_ids in device_dict.values()})
if len(num_cores_per_host_set) != 1:
raise RuntimeError('TPU cores on each device is not the same. This '
'should never happen. Devices: {}'.format(device_dict))
return num_cores_per_host_set.pop()
def __init__(self,
tpu=None,
zone=None,
project=None,
job_name='worker',
coordinator_name=None,
coordinator_address=None,
credentials='default',
service=None,
discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
for the IP addresses and ports of each Cloud TPU listed.
Args:
tpu: A string corresponding to the TPU to use. If the string is an empty
string, the string 'local', or a string that begins with 'grpc://', then
it is assumed to not correspond with a Cloud TPU and will instead be
passed as the session master and no ClusterSpec propagation will be
done. In the future, this may also support a list of strings when
multiple Cloud TPUs are used.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
project: Name of the GCP project containing Cloud TPUs. If omitted or
empty, we will try to discover the project name of the GCE VM from the
GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
coordinator_name: The name to use for the coordinator. Set to None if the
coordinator should not be included in the computed ClusterSpec.
coordinator_address: The address of the coordinator (typically an ip:port
pair). If set to None, a TF server will be started. If coordinator_name
is None, a TF server will not be started even if coordinator_address is
None.
credentials: GCE Credentials. If None, then we use default credentials
from the oauth2client
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
discovery_url: A URL template that points to the location of the discovery
service. It should have two parameters {api} and {apiVersion} that when
filled in produce an absolute URL to the discovery document for that
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
this.
Raises:
ImportError: If the googleapiclient is not installed.
ValueError: If no TPUs are specified.
RuntimeError: If an empty TPU name is specified and this is running in a
Google Cloud environment.
"""
self._cloud_tpu_client = client.Client(
tpu=tpu,
zone=zone,
project=project,
credentials=credentials,
service=service,
discovery_url=discovery_url)
self._tpu = self._cloud_tpu_client.name()
# By default the task_type is 'worker` and the task_id is 0 (which is the
# first worker in the task).
self.task_type = job_name
self.task_id = 0
self._coordinator_name = coordinator_name
if (coordinator_name and not coordinator_address):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
def __enter__(self):
self._cloud_tpu_client.enter()
def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin
self._cloud_tpu_client.exit(type, value, traceback)
def master(self, task_type=None, task_id=None, rpc_layer=None):
"""Get the Master string to be used for the session.
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
first instance in the ClusterSpec returned by the cluster_spec function.
If a non-TPU name is used when constructing a TPUClusterResolver, that will
be returned instead (e.g. If the tpus argument's value when constructing
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
'grpc://10.240.1.2:8470' will be returned).
Args:
task_type: (Optional, string) The type of the TensorFlow task of the
master.
task_id: (Optional, integer) The index of the TensorFlow task of the
master.
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
communicate with TPUs.
Returns:
string, the connection string to use when creating a session.
Raises:
ValueError: If none of the TPUs specified exists.
"""
cluster_spec = self.cluster_spec()
if task_type is not None and task_id is not None:
# task_type and task_id is from the function parameter
master = cluster_spec.task_address(task_type, task_id)
elif self.task_type is not None and self.task_id is not None:
# task_type and task_id is from the object
master = cluster_spec.task_address(self.task_type, self.task_id)
else:
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
return cluster_resolver.format_master_url(master, 'grpc')
def get_master(self):
return self.master()
def get_job_name(self):
return self.task_type
def get_tpu_system_metadata(self):
"""Returns the metadata of the TPU system.
Users can call this method to get some facts of the TPU system, like
total number of cores, number of TPU workers and the devices. E.g.
```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tpu_system_medata = resolver.get_tpu_system_metadata()
num_hosts = tpu_system_medata.num_hosts
```
Returns:
A `tf.tpu.experimental.TPUSystemMetadata` object.
"""
cluster_spec = self.cluster_spec()
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access
self.master(),
cluster_def=cluster_def,
query_topology=False))
return tpu_system_metadata
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
We retrieve the information from the GCE APIs every time this method is
called.
Returns:
A ClusterSpec containing host information returned from Cloud TPUs,
or None.
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
############################################################################
# There are 5 potential cases this code must handle:
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
# a. Create a ClusterSpec that includes the coordinator job
# b. Create a ClusterSpec without the coordinator job.
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
# tasks and
# a. Create a ClusterSpec with the coordinator
# b. Create a ClusterSpec without the coordinator
############################################################################
network_endpoints = self._cloud_tpu_client.network_endpoints()
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in network_endpoints
]
cluster_spec = {self.task_type: worker_list}
if self._coordinator_address:
# {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
def num_accelerators(self,
task_type=None,
task_id=None,
config_proto=None):
"""Returns the number of TPU cores per worker.
Connects to the master and list all the devices present in the master,
and counts them up. Also verifies that the device counts per host in the
cluster is the same before returning the number of TPU cores per host.
Args:
task_type: Unused.
task_id: Unused.
config_proto: Used to create a connection to a TPU master in order to
retrieve the system metadata.
Raises:
RuntimeError: If we cannot talk to a TPU worker after retrying or if the
number of TPU devices per host is different.
"""
retry_count = 1
# TODO(b/120564445): Replace with standard library for retries.
while True:
try:
device_details = TPUClusterResolver._get_device_dict_and_cores(
cluster_resolver.get_accelerator_devices(
self.master(), config_proto=config_proto))
break
except errors.DeadlineExceededError:
error_message = ('Failed to connect to master. The TPU might not be '
'ready (e.g. still scheduling) or the master '
'address is incorrect: got (%s)' % self.master())
if retry_count <= _TPU_CONN_RETRIES:
logging.warning(error_message)
logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
retry_count += 1
else:
raise RuntimeError(error_message)
if device_details.total_cores:
return {'TPU': TPUClusterResolver._verify_and_return_same_core_count(
device_details.device_map)}
return {'TPU': 0}
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
return self._environment
def _start_local_server(self):
address = compat.as_text(self._cloud_tpu_client.get_local_ip())
self._server = server_lib.Server({'local': ['0.0.0.0:0']},
protocol='grpc',
config=None,
start=True)
# self._server.target is of the form: grpc://ipaddress:port
target = compat.as_bytes(self._server.target)
splits = target.split(compat.as_bytes(':'))
assert len(splits) == 3, self._server.target
assert splits[0] == compat.as_bytes('grpc'), self._server.target
self._coordinator_port = compat.as_text(splits[2])
self._coordinator_address = '%s:%s' % (
address, compat.as_text(self._coordinator_port))
def __deepcopy__(self, memo):
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
return self
tf_export('distribute.cluster_resolver.TPUClusterResolver')(TPUClusterResolver)


@@ -1,6 +1,6 @@
path: "tensorflow.distribute.cluster_resolver.TPUClusterResolver"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver\'>"
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver\'>"
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.cluster_resolver.ClusterResolver\'>"
is_instance: "<type \'object\'>"
member {


@@ -1,6 +1,6 @@
path: "tensorflow.distribute.cluster_resolver.TPUClusterResolver"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver\'>"
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver\'>"
is_instance: "<class \'tensorflow.python.distribute.cluster_resolver.cluster_resolver.ClusterResolver\'>"
is_instance: "<type \'object\'>"
member {