Add types to tpu.py.
PiperOrigin-RevId: 340518665
Change-Id: I2dcbb84f1191b7482b553eaed2a08678e3986981
parent 335255c6b3, commit aea2e0ffd1
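The annotations added in this change are ordinary Python 3 type hints built from `typing` (`Optional`, `Text`, `Callable`, `List`, etc.) plus TensorFlow's own types such as `ops.Operation` and `core_types.Tensor`. The snippet below is only a minimal illustrative sketch of that style; the helper name `tpu_system_device_name_example` is hypothetical and not part of the commit.

from typing import Optional, Text

def tpu_system_device_name_example(job: Optional[Text] = None) -> Text:
  # Mirrors the annotation pattern applied throughout tpu.py:
  # Optional[Text] for an optional job name, Text for the returned string.
  if job is None:
    return "/device:TPU_SYSTEM:0"
  return "/job:{}/device:TPU_SYSTEM:0".format(job)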
tensorflow/python/tpu
@@ -90,6 +90,16 @@ tpu_py_test(
],
)

pytype_library(
name = "device_assignment",
srcs = ["device_assignment.py"],
deps = [
":topology",
"//tensorflow/python:platform",
"//tensorflow/python:tf_export",
],
)

pytype_library(
name = "preempted_hook_py",
srcs = ["preempted_hook.py"],
@@ -154,6 +164,17 @@ py_library(
],
)

pytype_library(
name = "topology",
srcs = ["topology.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/python:tf_export",
"//third_party/py/numpy",
],
)

py_library(
name = "tpu",
srcs = [
@@ -192,31 +213,30 @@ pytype_library(
srcs = [
"__init__.py",
"bfloat16.py",
"device_assignment.py",
"session_support.py",
"tensor_tracer.py",
"tensor_tracer_flags.py",
"tensor_tracer_report.py",
"topology.py",
"tpu_feed.py",
"tpu_optimizer.py",
"tpu_sharding.py",
"tpu_strategy_util.py",
"training_loop.py",
],
srcs_version = "PY2AND3",
deps = [
":datasets",
":device_assignment",
":functional",
":topology",
":tpu_feed",
":tpu_function",
":tpu_ops",
":tpu_sharding",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/core:protos_all_py",
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
"//tensorflow/python:array_ops",
@@ -246,9 +266,13 @@ pytype_library(
name = "tpu_py",
srcs = ["tpu.py"],
deps = [
":device_assignment",
":tpu_feed",
":tpu_function",
":tpu_name_util",
":tpu_ops",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:auto_control_deps",
"//tensorflow/python:c_api_util",
@@ -276,6 +300,21 @@ pytype_library(
],
)

pytype_library(
name = "tpu_feed",
srcs = ["tpu_feed.py"],
deps = [
":tpu_name_util",
":tpu_ops",
":tpu_sharding",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:ops",
"//tensorflow/python:tensor_shape",
],
)

pytype_library(
name = "tpu_function",
srcs = ["tpu_function.py"],
@@ -283,6 +322,14 @@ pytype_library(
],
)

pytype_library(
name = "tpu_sharding",
srcs = ["tpu_sharding.py"],
deps = [
"//tensorflow/python:tensor_shape",
],
)

pytype_library(
name = "tpu_system_metadata",
srcs = ["tpu_system_metadata.py"],
@@ -351,7 +398,7 @@ tf_py_test(
size = "small",
srcs = ["tpu_sharding_test.py"],
deps = [
":tpu",
":tpu_sharding",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
@@ -384,7 +431,7 @@ tf_py_test(
size = "medium",
srcs = ["topology_test.py"],
deps = [
":tpu",
":topology",
"//tensorflow/python:framework_test_lib",
],
)
@@ -426,6 +473,14 @@ pytype_library(
],
)

py_library(
name = "tpu_name_util",
srcs = ["tpu_name_util.py"],
deps = [
"//tensorflow/python:tf_export",
],
)

pytype_library(
name = "feature_column",
srcs = ["feature_column.py"],
tensorflow/python/tpu/tpu.py
@@ -22,7 +22,7 @@ from __future__ import print_function
import collections
import enum
import typing
from typing import Any
from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging
import numpy as np
@@ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
@@ -48,10 +49,16 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

ops.NotDifferentiable("TPUReplicatedInput")
@@ -89,7 +96,10 @@ _OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"


def _tpu_system_device_name(job):
core = tpu_name_util.core


def _tpu_system_device_name(job: Optional[Text]) -> Text:
"""Returns the device name for the TPU_SYSTEM device of `job`."""
if job is None:
return "/device:TPU_SYSTEM:0"
@@ -98,9 +108,11 @@ def _tpu_system_device_name(job)


@tf_export(v1=["tpu.initialize_system"])
def initialize_system(embedding_config=None,
job=None,
compilation_failure_closes_chips=True):
def initialize_system(
embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
job: Optional[Text] = None,
compilation_failure_closes_chips: bool = True
) -> core_types.Tensor:
"""Initializes a distributed TPU system for use with TensorFlow.

Args:
@@ -136,7 +148,10 @@ def initialize_system(embedding_config=None,
return array_ops.identity(topology, name="tpu_init_identity")


def initialize_system_for_tpu_embedding(embedding_config, job=None):
def initialize_system_for_tpu_embedding(
embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
job: Optional[Text] = None,
) -> ops.Operation:
"""Initializes a distributed TPU Embedding system for use with TensorFlow.

The following two are equivalent:
@@ -163,7 +178,7 @@ def initialize_system_for_tpu_embedding(embedding_config, job=None):


@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job=None):
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
"""Shuts down a running a distributed TPU system.

Args:
@@ -177,20 +192,7 @@ def shutdown_system(job=None):
return shutdown_distributed_tpu


@tf_export(v1=["tpu.core"])
def core(num):
"""Returns the device name for a core in a replicated TPU computation.

Args:
num: the virtual core number within each replica to which operators should
be assigned.
Returns:
A device name, suitable for passing to `tf.device()`.
"""
return "device:TPU_REPLICATED_CORE:{}".format(num)


def _enclosing_tpu_context_and_graph():
def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
"""Returns the TPUReplicateContext and its associated graph."""
graph = ops.get_default_graph()
while graph is not None:
@@ -207,13 +209,14 @@ def _enclosing_tpu_context_and_graph():
"a bug.")


def is_tpu_strategy(strategy):
def is_tpu_strategy(strategy: Any) -> bool:
is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
clz = strategy.__class__
return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))


def _enclosing_tpu_device_assignment():
def _enclosing_tpu_device_assignment(
) -> Optional[device_assignment_lib.DeviceAssignment]:
if not distribution_strategy_context.has_strategy():
return None
strategy = distribution_strategy_context.get_strategy()
@@ -223,7 +226,10 @@ def _enclosing_tpu_device_assignment():


@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(op, resource_reads, resource_writes):
def tpu_replicated_input_resolver(
op: ops.Operation,
resource_reads: object_identity.ObjectIdentitySet,
resource_writes: object_identity.ObjectIdentitySet) -> bool:
"""Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
# Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
# control deps on the replicated inputs.
@@ -251,8 +257,8 @@ def tpu_replicated_input_resolver(op, resource_reads, resource_writes):
resource_inputs.update(to_add)
return to_add or to_remove

return (replace_with_unreplicated_resources(resource_reads) or
replace_with_unreplicated_resources(resource_writes))
return bool(replace_with_unreplicated_resources(resource_reads) or
replace_with_unreplicated_resources(resource_writes))


class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
@@ -270,7 +276,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""

def __init__(self, name, num_replicas, pivot):
def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
"""Builds a new TPUReplicateContext.

Args:
@@ -301,8 +307,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}

def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
is_packed=False):
def get_replicated_var_handle(
self,
name: Text,
vars_: List[variables.Variable],
is_mirrored: bool = False,
is_packed: bool = False) -> core_types.Tensor:
"""Returns a variable handle for replicated TPU variable 'var'.

This is a method used by an experimental replicated variable implementation
@@ -368,7 +378,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._replicated_vars[name] = handle
return handle

def report_unsupported_operations(self):
def report_unsupported_operations(self) -> None:
if self._unsupported_ops:
op_str = "\n".join(" %s (%s)" % (op.type, op.name)
for op in self._unsupported_ops[:_MAX_WARNING_LINES])
@@ -378,7 +388,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
logging.warning("... and %d more" %
(len(self._unsupported_ops) - _MAX_WARNING_LINES))

def EnterGradientColocation(self, op, gradient_uid):
def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# If we are in TF 2 functions (control flow V2 functions, or
@@ -432,7 +442,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# The attr was not present: do nothing.
pass

def ExitGradientColocation(self, op, gradient_uid):
def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# Inside a TF2 tf.function or control flow graph and `op` was not
@@ -460,7 +470,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
op.node_def, op, "Badly nested gradient colocation, expected " +
last_op + ", got " + op.name)

def _EnterOutsideCompilationScope(self, cluster=None):
def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):

class FakeOp(object):
"""A helper class to determine the current device.
@@ -515,7 +525,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
graph = ops.get_default_graph()
graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access

def Enter(self):
def Enter(self) -> None:
if not self._outer_device_function_stack:
# Capture the device function stack at the time of first entry
# since that is the stack that will be used outside_compilation.
@@ -525,10 +535,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# pylint: enable=protected-access
super(TPUReplicateContext, self).Enter()

def HostComputeCore(self):
def HostComputeCore(self) -> List[Text]:
return self._host_compute_core

def _RemoveExternalControlEdges(self, op):
def _RemoveExternalControlEdges(
self, op: ops.Operation
) -> Tuple[List[ops.Operation], List[ops.Operation]]:
"""Remove any external control dependency on this op."""
internal_control_inputs = []
external_control_inputs = []
@@ -552,12 +564,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# pylint: enable=protected-access
return internal_control_inputs, external_control_inputs

def AddOp(self, op):
def AddOp(self, op: ops.Operation) -> None:
# pylint: disable=protected-access
if op.type in _DENYLISTED_OPS:
logging.error("Operation of type %s (%s) is not supported on the TPU. "
"Execution will fail if this op is used in the graph. " %
(op.type, op.name))
"Execution will fail if this op is used in the graph. ",
op.type, op.name)

if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
@@ -631,7 +643,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if self._outer_context:
self._outer_context.AddInnerOp(op)

def AddValue(self, val):
def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
"""Add `val` to the current context and its outer context recursively."""
if not self._outer_context:
return val
@@ -651,7 +663,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):

return result

def AddInnerOp(self, op):
def AddInnerOp(self, op: ops.Operation):
self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@@ -671,7 +683,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return self.GetWhileContext().back_prop
return False

def GetControlPivot(self):
def GetControlPivot(self) -> ops.Operation:
return self._pivot

def RequiresUniqueFunctionRetracing(self):
@@ -688,11 +700,11 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
attribute.
"""

def __init__(self, name):
def __init__(self, name: Text):
control_flow_ops.ControlFlowContext.__init__(self)
self._name = name

def AddOp(self, op):
def AddOp(self, op: ops.Operation) -> None:
if self._outer_context:
self._outer_context.AddOp(op)
# pylint: disable=protected-access
@@ -700,7 +712,7 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
# pylint: enable=protected-access

def AddInnerOp(self, op):
def AddInnerOp(self, op: ops.Operation) -> None:
if self._outer_context:
self._outer_context.AddInnerOp(op)
# pylint: disable=protected-access
@@ -713,7 +725,9 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):


@tf_export(v1=["tpu.outside_compilation"])
def outside_compilation(computation, *args, **kwargs):
def outside_compilation(
computation: Callable[..., Any], *args, **kwargs
) -> Any:
"""Builds part of a computation outside any current TPU replicate scope.

`tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
@@ -862,14 +876,15 @@ class XLAOptions(


@tf_export(v1=["tpu.replicate"])
def replicate(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
maximum_shapes=None,
padding_spec=None,
xla_options=None):
def replicate(
computation: Callable[..., Any],
inputs: Optional[List[List[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
maximum_shapes: Any = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None) -> List[Any]:
"""Builds a graph operator that runs a replicated TPU computation.

Example for the basic usage that `inputs` has static shape:
@@ -978,7 +993,11 @@ def _ceil_to_pow_of_n(x, n):
return result


def _pad_all_input(inputs, padded_shapes, padding_spec):
def _pad_all_input(
inputs: Iterable[core_types.Tensor],
padded_shapes: List[Optional[tensor_shape.TensorShape]],
padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
"""Pad all input tensors given padded_shapes.

The real shape tensors will be concatenated with the padded original inputs.
@@ -1150,15 +1169,17 @@ def _flatten_and_filter_composite(maybe_composite, non_composite_output,
return non_composite_output


def split_compile_and_replicate(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
use_tpu=True,
maximum_shapes=None,
padding_spec=None,
xla_options=None):
def split_compile_and_replicate(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
use_tpu: bool = True,
maximum_shapes: Any = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None,
) -> List[List[core_types.Tensor]]:
"""Builds graph operators that runs compilation and replicated computation.

This is a lower level interface than replicate that returns a separate compile
@@ -1517,7 +1538,9 @@ def split_compile_and_replicate(computation,
return [compile_status, replicated_outputs]


def _postprocess_flat_outputs(outputs):
def _postprocess_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.

Args:
@@ -1593,7 +1616,9 @@ def _postprocess_flat_outputs(outputs):
return new_output_tensors, output_operations, pack_template


def _postprocess_non_flat_outputs(outputs):
def _postprocess_non_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.

Args:
@@ -1641,16 +1666,18 @@ def _postprocess_non_flat_outputs(outputs):
return flat_outputs, [], outputs


def split_compile_and_shard(computation,
inputs=None,
num_shards=1,
input_shard_axes=None,
outputs_from_all_shards=True,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def split_compile_and_shard(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
num_shards: int = 1,
input_shard_axes: Optional[List[int]] = None,
outputs_from_all_shards: Union[bool, List[bool]] = True,
output_shard_axes: Optional[List[int]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None,
) -> Tuple[ops.Operation, List[core_types.Tensor]]:
"""Shards `computation` for parallel execution.

`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1796,16 +1823,17 @@ def split_compile_and_shard(computation,


@tf_export(v1=["tpu.shard"])
def shard(computation,
inputs=None,
num_shards=1,
input_shard_axes=None,
outputs_from_all_shards=True,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def shard(
computation: Callable[..., Any],
inputs: Optional[List[core_types.Tensor]] = None,
num_shards: int = 1,
input_shard_axes: Optional[List[int]] = None,
outputs_from_all_shards: Union[bool, List[bool]] = True,
output_shard_axes: Optional[List[int]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
"""Shards `computation` for parallel execution.

`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@@ -1881,13 +1909,14 @@ def shard(computation,


@tf_export(v1=["tpu.batch_parallel"])
def batch_parallel(computation,
inputs=None,
num_shards=1,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def batch_parallel(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
num_shards: int = 1,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None):
"""Shards `computation` along the batch dimension for parallel execution.

Convenience wrapper around shard().
@@ -1942,12 +1971,13 @@ def batch_parallel(computation,


@tf_export(v1=["tpu.rewrite"])
def rewrite(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def rewrite(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None) -> Any:
"""Rewrites `computation` for execution on a TPU system.

Args:
@@ -2012,7 +2042,7 @@ _DENYLISTED_INFERENCE_OPS = set([
])


def under_tpu_inference_context():
def under_tpu_inference_context() -> bool:
"""Check if it is currently under `_TPUInferenceContext`."""
graph = ops.get_default_graph()
while graph:
@@ -2037,7 +2067,7 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
tpu.rewrite_for_inference() computation.
"""

def __init__(self, name, check_ops=True):
def __init__(self, name: Text, check_ops: bool = True):
super(_TPUInferenceContext, self).__init__()
self._name = name
self._check_ops = check_ops
@@ -2069,7 +2099,7 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
return None


def validate_inference_rewrite_for_variables(graph):
def validate_inference_rewrite_for_variables(graph: ops.Graph):
"""Validates whether rewrite_for_inference() 'worked' for variables.

The rewrite_for_inference() method is supposed to append GuaranteeConstOps
@@ -2098,11 +2128,12 @@ def validate_inference_rewrite_for_variables(graph):
"computation.")


def rewrite_for_inference(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None):
def rewrite_for_inference(
computation: Callable[..., Any],
inputs: Optional[List[core_types.Tensor]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None) -> List[core_types.Tensor]:
"""Rewrites `computation` for inference on a TPU system.

Other than 'rewriting' the computation to run on a TPU, if using variables
@@ -2167,7 +2198,7 @@ def rewrite_for_inference(computation,
# pylint: enable=undefined-variable


def prune_unconnected_ops_from_xla(prune_graph):
def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
"""Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.

Args:
tensorflow/python/tpu/tpu_feed.py
@@ -30,7 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops

@@ -502,7 +502,7 @@ class InfeedQueue(object):
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
if tpu_device is not None:
with ops.device(tpu.core(tpu_device)):
with ops.device(tpu_name_util.core(tpu_device)):
dequeue_op = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
else:
@@ -808,7 +808,7 @@ class _PartitionedInfeedQueue(InfeedQueue):
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
with ops.device(tpu.core(tpu_device)):
with ops.device(tpu_name_util.core(tpu_device)):
values = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
return tag_sharding_attribute_for_dequeued_tensors(
tensorflow/python/tpu/tpu_name_util.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Helper functions for TPU device names."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Text
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["tpu.core"])
def core(num: int) -> Text:
  """Returns the device name for a core in a replicated TPU computation.

  Args:
    num: the virtual core number within each replica to which operators should
      be assigned.
  Returns:
    A device name, suitable for passing to `tf.device()`.
  """
  return "device:TPU_REPLICATED_CORE:{}".format(num)
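For reference, a minimal usage sketch of the relocated helper, mirroring how tpu_feed.py calls it above; the constant created inside the scope is hypothetical and only there to show device placement (the helper remains exported publicly as `tf.compat.v1.tpu.core`):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.tpu import tpu_name_util

# Inside a replicated TPU computation, pin ops to virtual core 0 of each
# replica; core(0) returns "device:TPU_REPLICATED_CORE:0".
with ops.device(tpu_name_util.core(0)):
  x = constant_op.constant(1.0)  # hypothetical op, for illustration only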