More typing of TPU codebase.
PiperOrigin-RevId: 353096631 Change-Id: Iee1ee03db223024f656bc3bd23f60b1d840b6acd
parent 05cda15d06
commit 2cbbeaa0bf
@@ -25,14 +25,18 @@ from __future__ import print_function
import os
import threading
import time
from typing import Any, List, Optional, Text

from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
@@ -40,13 +44,14 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
               checkpoint_dir: Text,
               save_secs: Optional[int] = None,
               save_steps: Optional[int] = None,
               saver: Optional[saver_lib.Saver] = None,
               checkpoint_basename: Text = "model.ckpt",
               scaffold: Optional[monitored_session.Scaffold] = None,
               listeners: Optional[List[
                   basic_session_run_hooks.CheckpointSaverListener]] = None):
    """Initializes a `CheckpointSaverHook`.

    Args:
@@ -98,7 +103,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
  def after_create_session(self, session: session_lib.Session, coord: Any):
    global_step = session.run(self._global_step_tensor)

    # We do write graph and saver_def at the first call of before_run.
@@ -122,10 +127,11 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
      self._save(session, global_step)
      self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)
  def before_run(self, run_context: Any):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
  def after_run(self, run_context: session_run_hook.SessionRunContext,
                run_values: Any):
    global_step = run_context.session.run(self._global_step_tensor)
    if self._timer.should_trigger_for_step(global_step):
      self._timer.update_last_triggered_step(global_step)
@@ -133,7 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
      if self._save(run_context.session, global_step):
        run_context.request_stop()

  def end(self, session):
  def end(self, session: session_lib.Session):
    if self._save_thread:
      logging.info("Waiting for any pending checkpoints to finish.")
      self._save_thread.join()
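For reference, a minimal usage sketch of the hook annotated above (not part of the diff); the training graph, a global step, and `train_op` are assumed to already exist, and the checkpoint directory is a placeholder:

from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.training import monitored_session

# Hypothetical setup: save a checkpoint asynchronously every 1000 global steps.
hook = async_checkpoint.AsyncCheckpointSaverHook(
    checkpoint_dir="/tmp/model_dir",
    save_steps=1000,
    checkpoint_basename="model.ckpt")

with monitored_session.MonitoredTrainingSession(hooks=[hook]) as sess:
  while not sess.should_stop():
    sess.run(train_op)  # train_op is assumed to be defined elsewhere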
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Generator, Optional, Text

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
@@ -70,10 +72,18 @@ def _get_custom_getter():

@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope(name=None):
def bfloat16_scope(
    name: Optional[Text] = None
) -> Generator[variable_scope.variable_scope, None, None]:
  """Scope class for bfloat16 variables so that the model uses custom getter.

  This enables variables to be read as bfloat16 type when using get_variable.

  Arguments:
    name: Name to use for scope.

  Yields:
    a variable scope.
  """
  if name is None:
    name = ''
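As an aside, a small sketch of how the annotated context manager is typically used (layer shape and names here are made up); inside the scope, get_variable stores the weight in float32 while the custom getter returns it cast to bfloat16:

import tensorflow.compat.v1 as tf

def dense_bf16(x):
  # The variable is created as float32 under the hood and read back as bfloat16.
  with tf.tpu.bfloat16_scope():
    w = tf.get_variable("w", shape=[x.shape[-1], 128], dtype=tf.bfloat16)
    return tf.matmul(tf.cast(x, tf.bfloat16), w)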
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Callable, Optional, Text, Union

from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@@ -28,13 +30,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops


def _TextLineDataset(filename):
def _TextLineDataset(filename: Text) -> dataset_ops.Dataset:
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
  return dataset


def _TFRecordDataset(filename):
def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
  return dataset
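A hedged aside: the `filetype` argument of StreamingFilesDataset (annotated below) accepts either a string key into _FILETYPE_MAP or a callable with the same shape as the helpers above; a hypothetical custom reader might simply use a larger buffer:

from tensorflow.python.data.ops import readers

def _my_tfrecord_reader(filename):
  # Illustrative only: the 16 MiB buffer size is an assumption, not a TF default.
  return readers.TFRecordDataset(filename, buffer_size=16 * 1024 * 1024)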
@@ -47,15 +49,17 @@ _FILETYPE_MAP = {
}


def StreamingFilesDataset(files,
                          filetype=None,
                          file_reader_job=None,
                          worker_job=None,
                          num_epochs=None,
                          filename_shuffle_buffer_size=None,
                          num_parallel_reads=None,
                          batch_transfer_size=None,
                          sloppy=None):
def StreamingFilesDataset(
    files: Union[Text, dataset_ops.Dataset],
    filetype: Optional[Union[Text, Callable[[Text],
                                            dataset_ops.Dataset]]] = None,
    file_reader_job: Optional[Text] = None,
    worker_job: Optional[Text] = None,
    num_epochs: Optional[int] = None,
    filename_shuffle_buffer_size: Optional[Union[int, bool]] = None,
    num_parallel_reads: Optional[int] = None,
    batch_transfer_size: Optional[Union[int, bool]] = None,
    sloppy: bool = True) -> dataset_ops.Dataset:
  """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).

  Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
@@ -126,9 +130,6 @@ def StreamingFilesDataset(files,
  if batch_transfer_size is None:
    batch_transfer_size = 256

  if sloppy is None:
    sloppy = True

  if file_reader_job == 'coordinator':
    file_reader_device = '/job:coordinator/task:0'
  else:
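A usage sketch for the annotated function (the file pattern and job names are placeholders, and a coordinator/worker TPU cluster is assumed to be configured):

from tensorflow.python.tpu import datasets as tpu_datasets

dataset = tpu_datasets.StreamingFilesDataset(
    files="gs://my-bucket/train-*.tfrecord",  # hypothetical GCS pattern
    filetype="tfrecord",       # or a callable mapping a filename to a Dataset
    file_reader_job="coordinator",
    worker_job="worker")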
@@ -20,6 +20,7 @@ from __future__ import print_function

import enum
import math
from typing import List, Optional, Text, Tuple

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -66,7 +67,7 @@ class DeviceAssignment(object):
  `DeviceAssignment` directly.
  """

  def __init__(self, topology, core_assignment):
  def __init__(self, topology: Topology, core_assignment: np.ndarray):
    """Constructs a `DeviceAssignment` object.

    Args:
@@ -104,22 +105,22 @@ class DeviceAssignment(object):
        self._core_assignment, topology)

  @property
  def topology(self):
  def topology(self) -> Topology:
    """A `Topology` that describes the TPU topology."""
    return self._topology

  @property
  def num_cores_per_replica(self):
  def num_cores_per_replica(self) -> int:
    """The number of cores per replica."""
    return self._num_cores_per_replica

  @property
  def num_replicas(self):
  def num_replicas(self) -> int:
    """The number of replicas of the computation."""
    return self._num_replicas

  @property
  def core_assignment(self):
  def core_assignment(self) -> np.ndarray:
    """The logical to physical core mapping.

    Returns:
@@ -129,11 +130,11 @@ class DeviceAssignment(object):
    """
    return self._core_assignment

  def coordinates(self, replica, logical_core):
  def coordinates(self, replica: int, logical_core: int) -> Tuple:  # pylint:disable=g-bare-generic
    """Returns the physical topology coordinates of a logical core."""
    return tuple(self.core_assignment[replica, logical_core, :])

  def lookup_replicas(self, task_id, logical_core):
  def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
    """Lookup replica ids by task number and logical core.

    Args:
@@ -153,31 +154,38 @@ class DeviceAssignment(object):
        "Can not find any replica in task: {} contains logical_core: {} ".
        format(task_id, logical_core))

  def tpu_ordinal(self, replica=0, logical_core=0):
  def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
    """Returns the ordinal of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_ordinal_at_coordinates(coordinates)

  def host_device(self, replica=0, logical_core=0, job=None):
  def host_device(self,
                  replica: int = 0,
                  logical_core: int = 0,
                  job: Optional[Text] = None) -> Text:
    """Returns the CPU device attached to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)

  def tpu_device(self, replica=0, logical_core=0, job=None):
  def tpu_device(self,
                 replica: int = 0,
                 logical_core: int = 0,
                 job: Optional[Text] = None) -> Text:
    """Returns the name of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)

  @staticmethod
  def build(topology,
            computation_shape=None,
            computation_stride=None,
            num_replicas=1):
  def build(topology: Topology,
            computation_shape: Optional[np.ndarray] = None,
            computation_stride: Optional[np.ndarray] = None,
            num_replicas: int = 1) -> "DeviceAssignment":
    return device_assignment(topology, computation_shape, computation_stride,
                             num_replicas)


def _open_ring_2d(x_size, y_size, z_coord):
def _open_ring_2d(x_size: int, y_size: int,
                  z_coord: int) -> List[Tuple[int, int, int]]:
  """Ring-order of a X by Y mesh, with a fixed Z coordinate.

  For example, in a 4x4 mesh, this returns the following order.
@@ -213,7 +221,8 @@ def _open_ring_2d(x_size, y_size, z_coord):
  return ret


def _ring_3d(x_size, y_size, z_size):
def _ring_3d(x_size: int, y_size: int,
             z_size: int) -> List[Tuple[int, int, int]]:
  """Ring-order of a X by Y by Z mesh.

  Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
@@ -325,11 +334,13 @@ class DeviceOrderMode(enum.IntEnum):
  MESH = 2


def device_assignment(topology,
                      computation_shape=None,
                      computation_stride=None,
                      num_replicas=1,
                      device_order_mode=DeviceOrderMode.AUTO):
def device_assignment(
    topology: Topology,
    computation_shape: Optional[np.ndarray] = None,
    computation_stride: Optional[np.ndarray] = None,
    num_replicas: int = 1,
    device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
) -> DeviceAssignment:
  """Computes a device_assignment of a computation across a TPU topology.

  Attempts to choose a compact grid of cores for locality.
@@ -341,11 +352,12 @@ def device_assignment(topology,
    optimal packing.

  Args:
    topology: A `Topology` object that describes the TPU cluster topology.
      To obtain a TPU topology, evaluate the `Tensor` returned by
    topology: A `Topology` object that describes the TPU cluster topology. To
      obtain a TPU topology, evaluate the `Tensor` returned by
      `initialize_system` using `Session.run`. Either a serialized
      `TopologyProto` or a `Topology` object may be passed. Note: you must
      evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
      evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
      here.
    computation_shape: A rank 1 int32 numpy array with size equal to the
      topology rank, describing the shape of the computation's block of cores.
      If None, the `computation_shape` is `[1] * topology_rank`.
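A short sketch of the annotated functions in use (assumes an existing `Session` named `sess` against a TPU cluster; the replica count is arbitrary):

from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu as tpu_lib

# Evaluate initialize_system() first; the serialized TopologyProto it returns
# is accepted directly by device_assignment().
serialized_topology = sess.run(tpu_lib.initialize_system())
da = device_assignment_lib.device_assignment(serialized_topology, num_replicas=2)
print(da.num_replicas, da.num_cores_per_replica)
print(da.tpu_device(replica=0), da.host_device(replica=0))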
@@ -20,7 +20,7 @@ from __future__ import print_function
from __future__ import unicode_literals

import functools
from typing import Any, Dict, Callable, List, Optional, Text, Tuple
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging
@@ -229,7 +229,6 @@ class TPUEmbedding(tracking.AutoTrackable):
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=1024,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)
@@ -244,7 +243,7 @@ class TPUEmbedding(tracking.AutoTrackable):

  def __init__(
      self,
      feature_config: Any,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.
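An illustrative sketch of what the widened `feature_config` annotation admits: a nested structure of FeatureConfig objects rather than a single one (the vocabulary size, dimension, and learning rate below are made up):

import tensorflow as tf

table = tf.tpu.experimental.embedding.TableConfig(vocabulary_size=10000, dim=64)
feature_config = {
    "watched": tf.tpu.experimental.embedding.FeatureConfig(table=table),
    "liked": tf.tpu.experimental.embedding.FeatureConfig(table=table),
}
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.1))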
@@ -19,15 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, Iterable, List, Optional, Union

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types


def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
@@ -41,10 +49,10 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to condition.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to condition.
    name: (Deprecated) Does nothing.

  Returns:
@@ -178,7 +186,12 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(n, body, inputs=None, infeed_queue=None, name=None):
def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None) -> List[core_types.TensorLike]:
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors correspond to `inputs`.
@@ -188,11 +201,12 @@ def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to condition.
    inputs: a list of initial values passed into the training loop or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to condition.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.
  Raises:
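Finally, a hedged sketch of the annotated loop builder in use; the trivial counter body and the tpu.rewrite wrapping are assumptions made for illustration:

from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import training_loop

def body(total):
  # Receives the loop-carried tensor and returns its updated value.
  return math_ops.add(total, 1.0)

def computation():
  # Run the body a fixed 10 times, starting from the initial value 0.0.
  return training_loop.repeat(10, body, inputs=[0.0])

outputs = tpu.rewrite(computation)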