More typing of TPU codebase.

PiperOrigin-RevId: 353096631
Change-Id: Iee1ee03db223024f656bc3bd23f60b1d840b6acd
This commit is contained in:
Revan Sopher 2021-01-21 13:57:36 -08:00 committed by TensorFlower Gardener
parent 05cda15d06
commit 2cbbeaa0bf
6 changed files with 107 additions and 65 deletions

View File

@ -25,14 +25,18 @@ from __future__ import print_function
import os
import threading
import time
from typing import Any, List, Optional, Text
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
@ -40,13 +44,14 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None):
checkpoint_dir: Text,
save_secs: Optional[int] = None,
save_steps: Optional[int] = None,
saver: Optional[saver_lib.Saver] = None,
checkpoint_basename: Text = "model.ckpt",
scaffold: Optional[monitored_session.Scaffold] = None,
listeners: Optional[List[
basic_session_run_hooks.CheckpointSaverListener]] = None):
"""Initializes a `CheckpointSaverHook`.
Args:
@ -98,7 +103,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
for l in self._listeners:
l.begin()
def after_create_session(self, session, coord):
def after_create_session(self, session: session_lib.Session, coord: Any):
global_step = session.run(self._global_step_tensor)
# We do write graph and saver_def at the first call of before_run.
@ -122,10 +127,11 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
def before_run(self, run_context: Any): # pylint: disable=unused-argument
return session_run_hook.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
def after_run(self, run_context: session_run_hook.SessionRunContext,
run_values: Any):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
@ -133,7 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
if self._save(run_context.session, global_step):
run_context.request_stop()
def end(self, session):
def end(self, session: session_lib.Session):
if self._save_thread:
logging.info("Waiting for any pending checkpoints to finish.")
self._save_thread.join()

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Generator, Optional, Text
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
@ -70,10 +72,18 @@ def _get_custom_getter():
@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope(name=None):
def bfloat16_scope(
name: Optional[Text] = None
) -> Generator[variable_scope.variable_scope, None, None]:
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
Arguments:
name: Name to use for scope.
Yields:
a variable scope.
"""
if name is None:
name = ''

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Callable, Optional, Text, Union
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@ -28,13 +30,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops
def _TextLineDataset(filename):
def _TextLineDataset(filename: Text) -> dataset_ops.Dataset:
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
return dataset
def _TFRecordDataset(filename):
def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
return dataset
@ -47,15 +49,17 @@ _FILETYPE_MAP = {
}
def StreamingFilesDataset(files,
filetype=None,
file_reader_job=None,
worker_job=None,
num_epochs=None,
filename_shuffle_buffer_size=None,
num_parallel_reads=None,
batch_transfer_size=None,
sloppy=None):
def StreamingFilesDataset(
files: Union[Text, dataset_ops.Dataset],
filetype: Optional[Union[Text, Callable[[Text],
dataset_ops.Dataset]]] = None,
file_reader_job: Optional[Text] = None,
worker_job: Optional[Text] = None,
num_epochs: Optional[int] = None,
filename_shuffle_buffer_size: Optional[Union[int, bool]] = None,
num_parallel_reads: Optional[int] = None,
batch_transfer_size: Optional[Union[int, bool]] = None,
sloppy: bool = True) -> dataset_ops.Dataset:
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
@ -126,9 +130,6 @@ def StreamingFilesDataset(files,
if batch_transfer_size is None:
batch_transfer_size = 256
if sloppy is None:
sloppy = True
if file_reader_job == 'coordinator':
file_reader_device = '/job:coordinator/task:0'
else:

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import enum
import math
from typing import List, Optional, Text, Tuple
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@ -66,7 +67,7 @@ class DeviceAssignment(object):
`DeviceAssignment` directly.
"""
def __init__(self, topology, core_assignment):
def __init__(self, topology: Topology, core_assignment: np.ndarray):
"""Constructs a `DeviceAssignment` object.
Args:
@ -104,22 +105,22 @@ class DeviceAssignment(object):
self._core_assignment, topology)
@property
def topology(self):
def topology(self) -> Topology:
"""A `Topology` that describes the TPU topology."""
return self._topology
@property
def num_cores_per_replica(self):
def num_cores_per_replica(self) -> int:
"""The number of cores per replica."""
return self._num_cores_per_replica
@property
def num_replicas(self):
def num_replicas(self) -> int:
"""The number of replicas of the computation."""
return self._num_replicas
@property
def core_assignment(self):
def core_assignment(self) -> np.ndarray:
"""The logical to physical core mapping.
Returns:
@ -129,11 +130,11 @@ class DeviceAssignment(object):
"""
return self._core_assignment
def coordinates(self, replica, logical_core):
def coordinates(self, replica: int, logical_core: int) -> Tuple: # pylint:disable=g-bare-generic
"""Returns the physical topology coordinates of a logical core."""
return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
"""Lookup replica ids by task number and logical core.
Args:
@ -153,31 +154,38 @@ class DeviceAssignment(object):
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
def tpu_ordinal(self, replica=0, logical_core=0):
def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
def host_device(self, replica=0, logical_core=0, job=None):
def host_device(self,
replica: int = 0,
logical_core: int = 0,
job: Optional[Text] = None) -> Text:
"""Returns the CPU device attached to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
def tpu_device(self, replica=0, logical_core=0, job=None):
def tpu_device(self,
replica: int = 0,
logical_core: int = 0,
job: Optional[Text] = None) -> Text:
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
@staticmethod
def build(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1):
def build(topology: Topology,
computation_shape: Optional[np.ndarray] = None,
computation_stride: Optional[np.ndarray] = None,
num_replicas: int = 1) -> "DeviceAssignment":
return device_assignment(topology, computation_shape, computation_stride,
num_replicas)
def _open_ring_2d(x_size, y_size, z_coord):
def _open_ring_2d(x_size: int, y_size: int,
z_coord: int) -> List[Tuple[int, int, int]]:
"""Ring-order of a X by Y mesh, with a fixed Z coordinate.
For example, in a 4x4 mesh, this returns the following order.
@ -213,7 +221,8 @@ def _open_ring_2d(x_size, y_size, z_coord):
return ret
def _ring_3d(x_size, y_size, z_size):
def _ring_3d(x_size: int, y_size: int,
z_size: int) -> List[Tuple[int, int, int]]:
"""Ring-order of a X by Y by Z mesh.
Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
@ -325,11 +334,13 @@ class DeviceOrderMode(enum.IntEnum):
MESH = 2
def device_assignment(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1,
device_order_mode=DeviceOrderMode.AUTO):
def device_assignment(
topology: Topology,
computation_shape: Optional[np.ndarray] = None,
computation_stride: Optional[np.ndarray] = None,
num_replicas: int = 1,
device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
) -> DeviceAssignment:
"""Computes a device_assignment of a computation across a TPU topology.
Attempts to choose a compact grid of cores for locality.
@ -341,11 +352,12 @@ def device_assignment(topology,
optimal packing.
Args:
topology: A `Topology` object that describes the TPU cluster topology.
To obtain a TPU topology, evaluate the `Tensor` returned by
topology: A `Topology` object that describes the TPU cluster topology. To
obtain a TPU topology, evaluate the `Tensor` returned by
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
here.
computation_shape: A rank 1 int32 numpy array with size equal to the
topology rank, describing the shape of the computation's block of cores.
If None, the `computation_shape` is `[1] * topology_rank`.

View File

@ -20,7 +20,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import functools
from typing import Any, Dict, Callable, List, Optional, Text, Tuple
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
@ -229,7 +229,6 @@ class TPUEmbedding(tracking.AutoTrackable):
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=1024,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
@ -244,7 +243,7 @@ class TPUEmbedding(tracking.AutoTrackable):
def __init__(
self,
feature_config: Any,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access
pipeline_execution_with_tensor_core: bool = False):
"""Creates the TPUEmbedding mid level API object.

View File

@ -19,15 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, Iterable, List, Optional, Union
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
def while_loop(condition: Callable[..., Any],
body: Callable[..., Any],
inputs: Optional[List[Any]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
name: Any = None) -> Any:
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
@ -41,10 +49,10 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
inputs: a list of initial values passed into the training loop, or None
(equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple of
arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
@ -178,7 +186,12 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
def repeat(
n: int,
body: Callable[..., Union[core_types.TensorLike, Iterable]], # pylint:disable=g-bare-generic
inputs: Optional[List[core_types.TensorLike]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
name: Any = None) -> List[core_types.TensorLike]:
"""Builds a training loop that executes a fixed number of iterations.
The set of loop-carried tensors correspond to `inputs`.
@ -188,11 +201,12 @@ def repeat(n, body, inputs=None, infeed_queue=None, name=None):
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
inputs: a list of initial values passed into the training loop or None
(equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple of
arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises: