Add types to tpu.py.

PiperOrigin-RevId: 340518665
Change-Id: I2dcbb84f1191b7482b553eaed2a08678e3986981
Revan Sopher 2020-11-03 13:55:05 -08:00 committed by TensorFlower Gardener
parent 335255c6b3
commit aea2e0ffd1
4 changed files with 239 additions and 118 deletions
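
The pattern throughout the diff is signature-only annotation: defaults and behavior are unchanged, and the new names come from the added typing and core_types imports in tpu.py. A condensed, illustrative before/after (taken from the shutdown_system hunk further down):

from typing import Optional, Text
from tensorflow.python.framework import ops

# Before: untyped v1 signature.
def shutdown_system(job=None):
  ...

# After (as in this commit): annotated argument and explicit return type.
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
  ...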

tensorflow/python/tpu/BUILD

@ -90,6 +90,16 @@ tpu_py_test(
],
)
pytype_library(
name = "device_assignment",
srcs = ["device_assignment.py"],
deps = [
":topology",
"//tensorflow/python:platform",
"//tensorflow/python:tf_export",
],
)
pytype_library(
name = "preempted_hook_py",
srcs = ["preempted_hook.py"],
@ -154,6 +164,17 @@ py_library(
],
)
pytype_library(
name = "topology",
srcs = ["topology.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/python:tf_export",
"//third_party/py/numpy",
],
)
py_library(
name = "tpu",
srcs = [
@ -192,31 +213,30 @@ pytype_library(
srcs = [
"__init__.py",
"bfloat16.py",
"device_assignment.py",
"session_support.py",
"tensor_tracer.py",
"tensor_tracer_flags.py",
"tensor_tracer_report.py",
"topology.py",
"tpu_feed.py",
"tpu_optimizer.py",
"tpu_sharding.py",
"tpu_strategy_util.py",
"training_loop.py",
],
srcs_version = "PY2AND3",
deps = [
":datasets",
":device_assignment",
":functional",
":topology",
":tpu_feed",
":tpu_function",
":tpu_ops",
":tpu_sharding",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/core:protos_all_py",
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
"//tensorflow/python:array_ops",
@ -246,9 +266,13 @@ pytype_library(
name = "tpu_py",
srcs = ["tpu.py"],
deps = [
":device_assignment",
":tpu_feed",
":tpu_function",
":tpu_name_util",
":tpu_ops",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:auto_control_deps",
"//tensorflow/python:c_api_util",
@ -276,6 +300,21 @@ pytype_library(
],
)
pytype_library(
name = "tpu_feed",
srcs = ["tpu_feed.py"],
deps = [
":tpu_name_util",
":tpu_ops",
":tpu_sharding",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:ops",
"//tensorflow/python:tensor_shape",
],
)
pytype_library(
name = "tpu_function",
srcs = ["tpu_function.py"],
@ -283,6 +322,14 @@ pytype_library(
],
)
pytype_library(
name = "tpu_sharding",
srcs = ["tpu_sharding.py"],
deps = [
"//tensorflow/python:tensor_shape",
],
)
pytype_library(
name = "tpu_system_metadata",
srcs = ["tpu_system_metadata.py"],
@ -351,7 +398,7 @@ tf_py_test(
size = "small",
srcs = ["tpu_sharding_test.py"],
deps = [
":tpu",
":tpu_sharding",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
@ -384,7 +431,7 @@ tf_py_test(
size = "medium",
srcs = ["topology_test.py"],
deps = [
":tpu",
":topology",
"//tensorflow/python:framework_test_lib",
],
)
@ -426,6 +473,14 @@ pytype_library(
],
)
py_library(
name = "tpu_name_util",
srcs = ["tpu_name_util.py"],
deps = [
"//tensorflow/python:tf_export",
],
)
pytype_library(
name = "feature_column",
srcs = ["feature_column.py"],

tensorflow/python/tpu/tpu.py

@ -22,7 +22,7 @@ from __future__ import print_function
import collections
import enum
import typing
from typing import Any
from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
import numpy as np
@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
@ -48,10 +49,16 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("TPUReplicatedInput")
@ -89,7 +96,10 @@ _OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"
def _tpu_system_device_name(job):
core = tpu_name_util.core
def _tpu_system_device_name(job: Optional[Text]) -> Text:
"""Returns the device name for the TPU_SYSTEM device of `job`."""
if job is None:
return "/device:TPU_SYSTEM:0"
@ -98,9 +108,11 @@ def _tpu_system_device_name(job):
@tf_export(v1=["tpu.initialize_system"])
def initialize_system(embedding_config=None,
job=None,
compilation_failure_closes_chips=True):
def initialize_system(
embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
job: Optional[Text] = None,
compilation_failure_closes_chips: bool = True
) -> core_types.Tensor:
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
@ -136,7 +148,10 @@ def initialize_system(embedding_config=None,
return array_ops.identity(topology, name="tpu_init_identity")
def initialize_system_for_tpu_embedding(embedding_config, job=None):
def initialize_system_for_tpu_embedding(
embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
job: Optional[Text] = None,
) -> ops.Operation:
"""Initializes a distributed TPU Embedding system for use with TensorFlow.
The following two are equivalent:
@ -163,7 +178,7 @@ def initialize_system_for_tpu_embedding(embedding_config, job=None):
@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job=None):
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
"""Shuts down a running a distributed TPU system.
Args:
@ -177,20 +192,7 @@ def shutdown_system(job=None):
return shutdown_distributed_tpu
@tf_export(v1=["tpu.core"])
def core(num):
"""Returns the device name for a core in a replicated TPU computation.
Args:
num: the virtual core number within each replica to which operators should
be assigned.
Returns:
A device name, suitable for passing to `tf.device()`.
"""
return "device:TPU_REPLICATED_CORE:{}".format(num)
def _enclosing_tpu_context_and_graph():
def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
"""Returns the TPUReplicateContext and its associated graph."""
graph = ops.get_default_graph()
while graph is not None:
@ -207,13 +209,14 @@ def _enclosing_tpu_context_and_graph():
"a bug.")
def is_tpu_strategy(strategy):
def is_tpu_strategy(strategy: Any) -> bool:
is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
clz = strategy.__class__
return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))
def _enclosing_tpu_device_assignment():
def _enclosing_tpu_device_assignment(
) -> Optional[device_assignment_lib.DeviceAssignment]:
if not distribution_strategy_context.has_strategy():
return None
strategy = distribution_strategy_context.get_strategy()
@ -223,7 +226,10 @@ def _enclosing_tpu_device_assignment():
@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(op, resource_reads, resource_writes):
def tpu_replicated_input_resolver(
op: ops.Operation,
resource_reads: object_identity.ObjectIdentitySet,
resource_writes: object_identity.ObjectIdentitySet) -> bool:
"""Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
# Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
# control deps on the replicated inputs.
@ -251,8 +257,8 @@ def tpu_replicated_input_resolver(op, resource_reads, resource_writes):
resource_inputs.update(to_add)
return to_add or to_remove
return (replace_with_unreplicated_resources(resource_reads) or
replace_with_unreplicated_resources(resource_writes))
return bool(replace_with_unreplicated_resources(resource_reads) or
replace_with_unreplicated_resources(resource_writes))
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
@ -270,7 +276,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
def __init__(self, name, num_replicas, pivot):
def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
"""Builds a new TPUReplicateContext.
Args:
@ -301,8 +307,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
is_packed=False):
def get_replicated_var_handle(
self,
name: Text,
vars_: List[variables.Variable],
is_mirrored: bool = False,
is_packed: bool = False) -> core_types.Tensor:
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
@ -368,7 +378,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._replicated_vars[name] = handle
return handle
def report_unsupported_operations(self):
def report_unsupported_operations(self) -> None:
if self._unsupported_ops:
op_str = "\n".join(" %s (%s)" % (op.type, op.name)
for op in self._unsupported_ops[:_MAX_WARNING_LINES])
@ -378,7 +388,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
logging.warning("... and %d more" %
(len(self._unsupported_ops) - _MAX_WARNING_LINES))
def EnterGradientColocation(self, op, gradient_uid):
def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# If we are in TF 2 functions (control flow V2 functions, or
@ -432,7 +442,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# The attr was not present: do nothing.
pass
def ExitGradientColocation(self, op, gradient_uid):
def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# Inside a TF2 tf.function or control flow graph and `op` was not
@ -460,7 +470,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
op.node_def, op, "Badly nested gradient colocation, expected " +
last_op + ", got " + op.name)
def _EnterOutsideCompilationScope(self, cluster=None):
def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):
class FakeOp(object):
"""A helper class to determine the current device.
@ -515,7 +525,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
graph = ops.get_default_graph()
graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access
def Enter(self):
def Enter(self) -> None:
if not self._outer_device_function_stack:
# Capture the device function stack at the time of first entry
# since that is the stack that will be used outside_compilation.
@ -525,10 +535,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# pylint: enable=protected-access
super(TPUReplicateContext, self).Enter()
def HostComputeCore(self):
def HostComputeCore(self) -> List[Text]:
return self._host_compute_core
def _RemoveExternalControlEdges(self, op):
def _RemoveExternalControlEdges(
self, op: ops.Operation
) -> Tuple[List[ops.Operation], List[ops.Operation]]:
"""Remove any external control dependency on this op."""
internal_control_inputs = []
external_control_inputs = []
@ -552,12 +564,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# pylint: enable=protected-access
return internal_control_inputs, external_control_inputs
def AddOp(self, op):
def AddOp(self, op: ops.Operation) -> None:
# pylint: disable=protected-access
if op.type in _DENYLISTED_OPS:
logging.error("Operation of type %s (%s) is not supported on the TPU. "
"Execution will fail if this op is used in the graph. " %
(op.type, op.name))
"Execution will fail if this op is used in the graph. ",
op.type, op.name)
if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
@ -631,7 +643,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if self._outer_context:
self._outer_context.AddInnerOp(op)
def AddValue(self, val):
def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
"""Add `val` to the current context and its outer context recursively."""
if not self._outer_context:
return val
@ -651,7 +663,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return result
def AddInnerOp(self, op):
def AddInnerOp(self, op: ops.Operation):
self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@ -671,7 +683,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return self.GetWhileContext().back_prop
return False
def GetControlPivot(self):
def GetControlPivot(self) -> ops.Operation:
return self._pivot
def RequiresUniqueFunctionRetracing(self):
@ -688,11 +700,11 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
attribute.
"""
def __init__(self, name):
def __init__(self, name: Text):
control_flow_ops.ControlFlowContext.__init__(self)
self._name = name
def AddOp(self, op):
def AddOp(self, op: ops.Operation) -> None:
if self._outer_context:
self._outer_context.AddOp(op)
# pylint: disable=protected-access
@ -700,7 +712,7 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
# pylint: enable=protected-access
def AddInnerOp(self, op):
def AddInnerOp(self, op: ops.Operation) -> None:
if self._outer_context:
self._outer_context.AddInnerOp(op)
# pylint: disable=protected-access
@ -713,7 +725,9 @@ class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
@tf_export(v1=["tpu.outside_compilation"])
def outside_compilation(computation, *args, **kwargs):
def outside_compilation(
computation: Callable[..., Any], *args, **kwargs
) -> Any:
"""Builds part of a computation outside any current TPU replicate scope.
`tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
@ -862,14 +876,15 @@ class XLAOptions(
@tf_export(v1=["tpu.replicate"])
def replicate(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
maximum_shapes=None,
padding_spec=None,
xla_options=None):
def replicate(
computation: Callable[..., Any],
inputs: Optional[List[List[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
maximum_shapes: Any = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None) -> List[Any]:
"""Builds a graph operator that runs a replicated TPU computation.
Example for the basic usage that `inputs` has static shape:
@ -978,7 +993,11 @@ def _ceil_to_pow_of_n(x, n):
return result
def _pad_all_input(inputs, padded_shapes, padding_spec):
def _pad_all_input(
inputs: Iterable[core_types.Tensor],
padded_shapes: List[Optional[tensor_shape.TensorShape]],
padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
"""Pad all input tensors given padded_shapes.
The real shape tensors will be concatenated with the padded original inputs.
@ -1150,15 +1169,17 @@ def _flatten_and_filter_composite(maybe_composite, non_composite_output,
return non_composite_output
def split_compile_and_replicate(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
use_tpu=True,
maximum_shapes=None,
padding_spec=None,
xla_options=None):
def split_compile_and_replicate(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
use_tpu: bool = True,
maximum_shapes: Any = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None,
) -> List[List[core_types.Tensor]]:
"""Builds graph operators that runs compilation and replicated computation.
This is a lower level interface than replicate that returns a separate compile
@ -1517,7 +1538,9 @@ def split_compile_and_replicate(computation,
return [compile_status, replicated_outputs]
def _postprocess_flat_outputs(outputs):
def _postprocess_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.
Args:
@ -1593,7 +1616,9 @@ def _postprocess_flat_outputs(outputs):
return new_output_tensors, output_operations, pack_template
def _postprocess_non_flat_outputs(outputs):
def _postprocess_non_flat_outputs(
outputs: Any
) -> Tuple[List[core_types.Tensor], List[ops.Operation], List[Any]]:
"""Validates non-flat outputs, add backs device assignments and other attrs.
Args:
@ -1641,16 +1666,18 @@ def _postprocess_non_flat_outputs(outputs):
return flat_outputs, [], outputs
def split_compile_and_shard(computation,
inputs=None,
num_shards=1,
input_shard_axes=None,
outputs_from_all_shards=True,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def split_compile_and_shard(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
num_shards: int = 1,
input_shard_axes: Optional[List[int]] = None,
outputs_from_all_shards: Union[bool, List[bool]] = True,
output_shard_axes: Optional[List[int]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None,
) -> Tuple[ops.Operation, List[core_types.Tensor]]:
"""Shards `computation` for parallel execution.
`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@ -1796,16 +1823,17 @@ def split_compile_and_shard(computation,
@tf_export(v1=["tpu.shard"])
def shard(computation,
inputs=None,
num_shards=1,
input_shard_axes=None,
outputs_from_all_shards=True,
output_shard_axes=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def shard(
computation: Callable[..., Any],
inputs: Optional[List[core_types.Tensor]] = None,
num_shards: int = 1,
input_shard_axes: Optional[List[int]] = None,
outputs_from_all_shards: Union[bool, List[bool]] = True,
output_shard_axes: Optional[List[int]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
"""Shards `computation` for parallel execution.
`inputs` must be a list of Tensors or None (equivalent to an empty list), each
@ -1881,13 +1909,14 @@ def shard(computation,
@tf_export(v1=["tpu.batch_parallel"])
def batch_parallel(computation,
inputs=None,
num_shards=1,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def batch_parallel(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
num_shards: int = 1,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None):
"""Shards `computation` along the batch dimension for parallel execution.
Convenience wrapper around shard().
@ -1942,12 +1971,13 @@ def batch_parallel(computation,
@tf_export(v1=["tpu.rewrite"])
def rewrite(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
xla_options=None):
def rewrite(
computation: Callable[..., Any],
inputs: List[List[Optional[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
xla_options: Optional[XLAOptions] = None) -> Any:
"""Rewrites `computation` for execution on a TPU system.
Args:
@ -2012,7 +2042,7 @@ _DENYLISTED_INFERENCE_OPS = set([
])
def under_tpu_inference_context():
def under_tpu_inference_context() -> bool:
"""Check if it is currently under `_TPUInferenceContext`."""
graph = ops.get_default_graph()
while graph:
@ -2037,7 +2067,7 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
tpu.rewrite_for_inference() computation.
"""
def __init__(self, name, check_ops=True):
def __init__(self, name: Text, check_ops: bool = True):
super(_TPUInferenceContext, self).__init__()
self._name = name
self._check_ops = check_ops
@ -2069,7 +2099,7 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
return None
def validate_inference_rewrite_for_variables(graph):
def validate_inference_rewrite_for_variables(graph: ops.Graph):
"""Validates whether rewrite_for_inference() 'worked' for variables.
The rewrite_for_inference() method is supposed to append GuaranteeConstOps
@ -2098,11 +2128,12 @@ def validate_inference_rewrite_for_variables(graph):
"computation.")
def rewrite_for_inference(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None):
def rewrite_for_inference(
computation: Callable[..., Any],
inputs: Optional[List[core_types.Tensor]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None) -> List[core_types.Tensor]:
"""Rewrites `computation` for inference on a TPU system.
Other than 'rewriting' the computation to run on a TPU, if using variables
@ -2167,7 +2198,7 @@ def rewrite_for_inference(computation,
# pylint: enable=undefined-variable
def prune_unconnected_ops_from_xla(prune_graph):
def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
"""Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
Args:
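
A minimal sketch (not part of the commit) of what the newly annotated replicate() signature expects from callers: a Callable computation, a per-replica list of tensor lists for inputs, and a List[Any] of per-replica outputs. This only builds the graph; actually running it would additionally need an initialized TPU system (tf.tpu.initialize_system / shutdown_system) and a session.

import tensorflow.compat.v1 as tf

def doubled(x):
  # The replicated computation: runs once per TPU replica.
  return 2.0 * x

with tf.Graph().as_default():
  # One inner list per replica: here, two replicas with one scalar tensor each.
  per_replica_inputs = [[tf.constant(1.0)], [tf.constant(2.0)]]
  # Per the new annotation, replicate() returns List[Any]:
  # one output structure per replica.
  outputs = tf.tpu.replicate(doubled, inputs=per_replica_inputs)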

tensorflow/python/tpu/tpu_feed.py

@ -30,7 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops
@ -502,7 +502,7 @@ class InfeedQueue(object):
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
if tpu_device is not None:
with ops.device(tpu.core(tpu_device)):
with ops.device(tpu_name_util.core(tpu_device)):
dequeue_op = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
else:
@ -808,7 +808,7 @@ class _PartitionedInfeedQueue(InfeedQueue):
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
with ops.device(tpu.core(tpu_device)):
with ops.device(tpu_name_util.core(tpu_device)):
values = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
return tag_sharding_attribute_for_dequeued_tensors(

tensorflow/python/tpu/tpu_name_util.py

@ -0,0 +1,35 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Helper functions for TPU device names."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Text
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["tpu.core"])
def core(num: int) -> Text:
"""Returns the device name for a core in a replicated TPU computation.
Args:
num: the virtual core number within each replica to which operators should
be assigned.
Returns:
A device name, suitable for passing to `tf.device()`.
"""
return "device:TPU_REPLICATED_CORE:{}".format(num)