Move TPU python files to TF core.

PiperOrigin-RevId: 234851893
Jonathan Hseu 2019-02-20 12:59:38 -08:00 committed by TensorFlower Gardener
parent 2825b62b27
commit d23fc2b7ff
72 changed files with 14714 additions and 13802 deletions
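
Throughout this change, each contrib module body is replaced by a thin stub that wildcard-imports the implementation from its new home under tensorflow/python/tpu, so existing contrib import paths keep working. A minimal sketch of the effect (illustrative only; it assumes, as the stubs below imply, that the re-export binds the same objects):

# Old contrib import paths continue to resolve, because the contrib module is
# now just a re-export of the moved implementation.
from tensorflow.contrib.tpu.python.tpu import bfloat16 as contrib_bfloat16
from tensorflow.python.tpu import bfloat16 as core_bfloat16

# After this change both names refer to the same function object.
assert contrib_bfloat16.bfloat16_scope is core_bfloat16.bfloat16_scope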

View File

@@ -28,8 +28,7 @@ py_library(
srcs = ["python/ops/tpu_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python/tpu:tpu_py",
],
)
@@ -38,19 +37,7 @@ py_library(
srcs = ["python/tpu/async_checkpoint.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/tpu:async_checkpoint",
],
)
@@ -72,24 +59,7 @@ py_library(
":tpu_embedding",
":tpu_lib",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:function",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:util",
"@six_archive//:six",
"//tensorflow/python/tpu:tpu_estimator",
],
)
@@ -101,7 +71,7 @@ py_library(
"//visibility:public",
],
deps = [
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python/tpu:functional",
],
)
@@ -110,10 +80,7 @@ py_library(
srcs = ["python/profiler/__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc",
"//tensorflow/core/profiler:profiler_analysis_proto_py",
"//tensorflow/core/profiler:protos_all_py",
"//tensorflow/python:util",
"//tensorflow/python/tpu/profiler",
],
)
@@ -130,6 +97,7 @@ py_library(
":tpu_embedding",
":tpu_estimator",
":tpu_lib",
"//tensorflow/python/tpu",
],
)
@@ -196,29 +164,9 @@ py_library(
":functional",
":profiler",
":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/compiler:xla",
"//tensorflow/core:protos_all_py",
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/losses",
"//tensorflow/python/tpu:tpu_lib",
],
)
@@ -229,106 +177,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
tf_py_test(
name = "datasets_test",
size = "medium",
srcs = ["python/tpu/datasets_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
":datasets",
],
grpc_enabled = True,
shard_count = 4,
tags = ["no_oss"],
)
tf_py_test(
name = "tpu_test",
size = "small",
srcs = ["python/tpu/tpu_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:layers",
],
tags = ["no_windows"], # TODO: needs investigation on Windows
)
tf_py_test(
name = "tpu_sharding_test",
size = "small",
srcs = ["python/tpu/tpu_sharding_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "bfloat16_test",
size = "small",
srcs = ["python/tpu/bfloat16_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "tpu_infeed_test",
size = "small",
srcs = ["python/tpu/tpu_infeed_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_config_test",
size = "small",
srcs = ["python/tpu/tpu_config_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_estimator_signals_test",
size = "small",
srcs = ["python/tpu/tpu_estimator_signals_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "topology_test",
size = "medium",
srcs = ["python/tpu/topology_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/tpu:datasets",
],
)
@@ -341,16 +190,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":tpu_lib",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:partitioned_variables",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"@six_archive//:six",
"//tensorflow/python/tpu:tpu_embedding",
],
)
@@ -359,31 +199,6 @@ py_library(
srcs = ["python/tpu/feature_column.py"],
deps = [
":tpu_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/tpu:feature_column",
],
)
tf_py_test(
name = "feature_column_test",
srcs = [
"python/tpu/feature_column_test.py",
],
additional_deps = [
":feature_column",
"//third_party/py/numpy",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:variables",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
],
main = "python/tpu/feature_column_test.py",
)

View File

@@ -63,11 +63,3 @@ tf_cc_test(
"@jsoncpp_git//:jsoncpp",
],
)
py_library(
name = "tpu_profiler_analysis_pb2_grpc",
srcs = ["tpu_profiler_analysis_pb2_grpc.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
)

View File

@@ -1,418 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
def _create_default_group_assignment():
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
logging.warning(
"cross_replica_sum should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
group_assignment = [list(range(num_shards))]
return group_assignment
def all_to_all(x,
concat_dimension,
split_dimension,
split_count,
group_assignment=None,
name=None):
"""Exchange data across TPU replicas.
Args:
x: The local tensor.
concat_dimension: The dimension number to concatenate.
split_dimension: The dimension number to split.
split_count: The number of splits; this number must equal the sub-group
size (group_assignment.get_shape()[1]).
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
A `Tensor` which is concatenated by data from different replicas.
"""
if group_assignment is None:
group_assignment = _create_default_group_assignment()
return gen_tpu_ops.all_to_all(
x,
group_assignment,
concat_dimension=concat_dimension,
split_dimension=split_dimension,
split_count=split_count,
name=name)
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
# The gradient of an all-to-all is also an all-to-all, but the
# split_dimension and concat_dimension are swapped.
# The gradient with respect to group_assignment is None.
return [
gen_tpu_ops.all_to_all(
grad,
op.inputs[1],
concat_dimension=op.get_attr("split_dimension"),
split_dimension=op.get_attr("concat_dimension"),
split_count=op.get_attr("split_count")), None
]
def cross_replica_sum(x, group_assignment=None, name=None):
"""Sum the input tensor across replicas according to group_assignment.
Args:
x: The local tensor to be summed.
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
A `Tensor` which is summed across replicas.
"""
if group_assignment is None:
group_assignment = _create_default_group_assignment()
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
def collective_permute(x, source_target_pairs, name=None):
"""Permute the input tensor across replicas given source_target_pairs.
For each source_target_pair <a, b>, we send replica a's input to replica b.
Each replica id must only appear once in the source column. Also it must
only appear once in the target column.
For replica ids not appearing in the target column, this op returns a zero tensor
with the same shape and dtype as the input x.
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
`[0, A, B, C]`.
Args:
x: The local tensor to be permuted.
source_target_pairs: 2d int lists with shape [num_pairs, 2].
source_target_pairs[i][0] represents the source replica id and
source_target_pairs[i][1] represents the target replica id.
name: Optional op name.
Returns:
A `Tensor` which is permuted.
"""
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
# The gradient of a collective permute operation is also a collective
# permute, but with source/target pairs reversed. The gradient with respect
# to input argument `source_target_pairs` is `None`.
source_target_pairs = op.inputs[1][:, ::-1]
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross-replica sum is also a cross-replica sum.
# The gradient with respect to group_assignment is None.
return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
# types are supported.
_SUPPORTED_INFEED_DTYPES = set([
dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
dtypes.complex64, dtypes.uint32
])
@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
"""Saves the gradient of embedding activations ops in a graph collection."""
g = ops.get_default_graph()
table_id = activations_op.get_attr("table_id")
lookup_id = activations_op.get_attr("lookup_id")
table_gradients = g.get_collection_ref(
"tpu_embedding_gradients_table_%d" % table_id)
if not table_gradients:
raise RuntimeError(
"Gradients for TPUEmbedding have been generated in non-training mode."
"This is not expected. Consider putting your Optimizer.minimize code "
"behind the training mode condition check. For Estimator, you can "
"do \n\n"
" if mode == tf.estimator.ModeKeys.TRAIN:\n"
" train_op = opt.minimize(loss)\n"
"\n")
table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
return [
# RegisterGradient requires that value be returned for all inputs. Since
# the first argument (tpu_gradient_variable_{table_name}) has shape [1],
# we will return zeros(shape=[1]). The actual gradient w.r.t. the
# embedding activations (grad_wrt_activations) has the same shape as the
# activations returned by embedding_activations.
array_ops.zeros(arg.shape, dtype=dtypes.float32)
for arg in activations_op.inputs
]
def infeed_dequeue(dtype, shape, name=None):
"""A placeholder op for a value that will be fed into the computation.
Args:
dtype: A `tf.DType`. The type of elements in the tensor.
shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `dtype`.
A tensor that will be provided using the infeed mechanism.
Raises:
TypeError: If `dtype` is not a supported infeed type.
"""
if dtype not in _SUPPORTED_INFEED_DTYPES:
raise TypeError(
"{} is not a supported TPU infeed type. Supported types are: "
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
"""A placeholder op for values fed into the TPU simultaneously as a tuple.
Args:
dtypes: A list of `tf.DType`s that has length `>= 1`.
The element types of each element in `outputs`.
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
The shapes of each tensor in `outputs`.
name: A name for the operation (optional).
Returns:
A list of `Tensor` objects of type `dtypes`.
A list of tensors that will be provided using the infeed mechanism.
Raises:
TypeError: If a type in `dtypes` is not a supported infeed type.
"""
for dtype in dtypes:
if dtype not in _SUPPORTED_INFEED_DTYPES:
raise TypeError(
"{} is not a supported TPU infeed type. Supported types are: "
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
config,
learning_rates=None,
name=None):
"""A placeholder op for feeding per-sample gradients to the embedding layer.
Args:
inputs: A TensorList of gradients with which to update embedding tables.
This argument has the same length and shapes as the return value of
RecvTPUEmbeddingActivations, but contains gradients of the model's
loss with respect to the embedding activations. The embedding tables
are updated from these gradients via the optimizers specified in the
TPU embedding configuration given to tpu.initialize_system.
config: Serialized TPUEmbeddingConfiguration proto.
learning_rates: A TensorList of float32 scalars, one for each dynamic
learning rate tag: see the comments in
//third_party/tensorflow/core/protobuf/tpu/
optimization_parameters.proto.
Multiple tables can share the same dynamic learning rate tag as
specified in the configuration. If the learning rates for all tables
are constant, this list should be empty.
name: A name for the operation (optional).
Returns:
A SendTPUEmbeddingGradients operation.
"""
if learning_rates is None:
learning_rates = []
return gen_tpu_ops.send_tpu_embedding_gradients(
inputs=inputs, learning_rates=learning_rates, config=config, name=name)
send_tpu_embedding_gradients.__doc__ = (
gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
device_ordinal,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
batch: A list of 1D tensors, one for each embedding table, containing the
indices into the tables.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingIntegerBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
batch=batch,
device_ordinal=device_ordinal,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_integer_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
embedding_indices,
aggregation_weights,
device_ordinal,
combiners=None,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 1 Tensors specifying the training example
and feature to which the corresponding embedding_indices and
aggregation_weights values belong. sample_indices[i] must equal b * nf +
f, where nf is the number of features from the corresponding table, f is
in [0, nf), and b is in [0, batch size).
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables.
aggregation_weights: A list of rank 1 Tensors containing per sample --
i.e. per (training example, feature) -- aggregation weights.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
is to use 'sum' for all tables (optional).
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingSparseBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
sample_indices=sample_indices,
embedding_indices=embedding_indices,
aggregation_weights=aggregation_weights,
device_ordinal=device_ordinal,
combiners=combiners,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_sparse_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
embedding_indices,
aggregation_weights,
table_ids,
device_ordinal,
combiners=None,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 1 Tensors specifying the training example
to which the corresponding embedding_indices and aggregation_weights
values belong. It corresponds to sp_ids.indices[:,0] in
embedding_lookup_sparse().
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
aggregation_weights: A list of rank 1 Tensors containing per training
example aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse().
table_ids: A list of integers specifying the identifier of the embedding
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
lookup the corresponding input. The ith input is looked up using
table_ids[i]. The size of the table_ids list must be equal to that of
sample_indices, embedding_indices and aggregation_weights.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
is to use 'sum' for all tables (optional).
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingSparseTensorBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
sample_indices=sample_indices,
embedding_indices=embedding_indices,
aggregation_weights=aggregation_weights,
table_ids=table_ids,
device_ordinal=device_ordinal,
combiners=combiners,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
pass
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.ops.tpu_ops import *
# pylint: enable=wildcard-import,unused-import
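
For reference, a hedged usage sketch of the cross_replica_sum API documented above, using the new tensorflow.python.tpu module paths; the shard count and the helper name sharded_sum are illustrative, not part of the commit:

from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu.ops import tpu_ops


def sharded_sum(x):
  # Inside an 8-way shard context the default group assignment computed by
  # _create_default_group_assignment is [[0, 1, 2, 3, 4, 5, 6, 7]]: a single
  # group containing every replica, so x is summed across all 8 replicas.
  with tpu_function.tpu_shard_context(8):
    return tpu_ops.cross_replica_sum(x)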

View File

@@ -1,35 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations to select TPU core to run."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.python.ops.gen_tpu_ops import tpu_ordinal_selector
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
pass
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.ops.tpu_ordinal_selector_op import *
# pylint: enable=wildcard-import,unused-import

View File

@@ -1,31 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Classes for TPU trace events."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.core.profiler.trace_events_pb2 import *
from tensorflow.core.profiler.profiler_analysis_pb2 import *
from tensorflow.python.tpu.profiler import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent']
remove_undocumented(__name__, _allowed_symbols)

View File

@@ -1,20 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Ops related to Tensor Processing Units."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu import *
# pylint: enable=wildcard-import,unused-import

View File

@@ -1,334 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tooling for support TPU embedding in TPUEstimator."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc
from tensorflow.contrib.tpu.python.tpu import tpu_embedding
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.feature_column import feature_column as core_fc
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
# pylint: disable=protected-access
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
tpu_fc._TPUSharedEmbeddingColumn)
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
core_fc_lib.EmbeddingColumn,
core_fc._SharedEmbeddingColumn)
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
# pylint: enable=protected-access
_TABLE_NAME_PREFIX = 'tbl_'
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
def _get_table_name_from_embedding_var_name(embedding_var_name):
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
def _get_embedding_var_name_from_table_name(table_name):
return table_name[_LEN_TABLE_NAME_PREFIX:]
def _get_embedding_variable_name(scope_name, var_name):
return '{}/{}'.format(scope_name, var_name)
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
"""Return embedding variable names which are consistent with CPU runs."""
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
return tpu_embedding.AdagradSlotVariableName(
'{}/{}/Adagrad'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
return tpu_embedding.AdamSlotVariableNames(
'{}/{}/Adam/m'.format(scope_name, var_name),
'{}/{}/Adam/v'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters,
tpu_embedding.StochasticGradientDescentParameters):
return None
else:
raise ValueError('Support to infer full variable name '
'for optimization_parameter {} has not been added.'
.format(optimization_parameters))
def get_full_variable_names(
graph, table_to_config_dict, optimization_parameters):
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
if not collection:
raise RuntimeError(
'Embedding feature column did not capture anything. Make sure the '
'feature columns passed to the TPUEstimator constructor are properly '
'used in model_fn.')
embedding_variable_name_by_table = {}
slot_variable_names_by_table = {}
for table_name in table_to_config_dict:
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
(scope_name, var_name) = collection[0][embedding_var_name]
embedding_variable_name_by_table[table_name] = (
_get_embedding_variable_name(scope_name, var_name))
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
scope_name, var_name, optimization_parameters)
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
return embedding_variable_name_by_table, slot_variable_names_by_table
def get_tpu_embedding_config_from_feature_columns(feature_columns):
"""Create configs for TPUEmbedding from a list of feature columns.
This function will place one embedding tensor per table and the return value is
intended to be used as input to TPUEmbedding.
Args:
feature_columns: a list of supported feature columns.
Returns:
A pair of dicts, the first maps tables to their config, the second maps
features to tables.
"""
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
for column in feature_columns:
if not isinstance(column, allowed):
raise TypeError(
'Unsupported feature column {}. Supported types are {}.'.format(
type(column), allowed))
table_to_config = {}
feature_to_table = {}
for column in feature_columns:
feature_name = column.get_feature_key_name()
table_name = _get_table_name_from_embedding_var_name(
column.get_embedding_var_name())
if feature_name in feature_to_table:
raise ValueError(
'Feature column {} is used with multiple embeddings and this is '
'not supported.'.format(feature_name))
feature_to_table[feature_name] = table_name
vocabulary_size, dimension = column.get_embedding_table_size()
table_to_config[table_name] = tpu_embedding.TableConfig(
vocabulary_size=vocabulary_size,
dimension=dimension,
initializer=column.get_initializer(),
combiner=column.get_combiner())
return table_to_config, feature_to_table
def _get_tpu_embedding_optimization_parameters(embedding_config_spec):
"""Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec."""
if embedding_config_spec.optimizer_type == 'adagrad':
return tpu_embedding.AdagradParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.adagrad_initial_accumulator,
embedding_config_spec.use_gradient_accumulation)
elif embedding_config_spec.optimizer_type == 'sgd':
return tpu_embedding.StochasticGradientDescentParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.use_gradient_accumulation)
elif embedding_config_spec.optimizer_type == 'adam':
return tpu_embedding.AdamParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.adam_parameters.beta1,
embedding_config_spec.adam_parameters.beta2,
embedding_config_spec.adam_parameters.epsilon,
use_gradient_accumulation=embedding_config_spec
.use_gradient_accumulation)
else:
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
AdamParameters = collections.namedtuple('AdamParameters',
['beta1', 'beta2', 'epsilon'])
# TODO(shizhiw): Improve the API to support more optimizer parameters.
class EmbeddingConfigSpec(
collections.namedtuple('EmbeddingConfigSpec', [
'feature_columns', 'learning_rate', 'optimizer_type',
'adagrad_initial_accumulator', 'clipping_limit',
'use_gradient_accumulation', 'adam_parameters'
])):
"""Class to keep track of embedding config specification."""
def __new__(cls,
feature_columns,
learning_rate,
optimizer_type='adagrad',
adagrad_initial_accumulator=None,
clipping_limit=None,
use_gradient_accumulation=False,
adam_parameters=None):
"""Creates an EmbeddingConfigSpec instance.
Args:
feature_columns: All `FeatureColumn`s used by model.
learning_rate: embedding optimizer learning rate.
optimizer_type: (String) Name of the optimizer for embedding gradients
updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default
value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam'
(`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will
be applied to all embedding variables specified by `feature_columns`.
adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when
optimizer_type is 'adagrad'. Default is `0.1`.
clipping_limit: (Optional) Clipping limit (absolute value).
use_gradient_accumulation: (Experimental) Whether to accumulate the
gradients across TPU embedding mini-batches. Gradient accumulation does
not affect SGD and therefore this is applicable only for Adagrad.
adam_parameters: AdamParameters. Used when optimizer_type is 'adam'.
Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon.
Returns:
An EmbeddingConfigSpec instance.
Raises:
ValueError: If the feature_columns are not specified.
TypeError: If the feature columns are not of the correct type (one of
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or
_EMBEDDING_COLUMN_CLASSES).
ValueError: If use_gradient_accumulation is True for SGD.
ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or
"adam".
"""
if not feature_columns:
raise ValueError('`feature_columns` cannot be `None` or empty.')
# It is unknown at this moment whether the TPUEstimator is running in CPU
# or TPU mode, so allow non-TPU embedding columns also.
supported_classes = tuple(
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
list(_EMBEDDING_COLUMN_CLASSES))
for column in feature_columns:
if not isinstance(column, supported_classes):
raise TypeError(
'All feature columns must be supported types in {}. Got {}'.format(
supported_classes, type(column)))
if optimizer_type == 'adagrad':
if adagrad_initial_accumulator is None:
adagrad_initial_accumulator = 0.1
if adagrad_initial_accumulator <= 0:
raise ValueError('Adagrad initial_accumulator must be positive')
elif optimizer_type == 'sgd':
if use_gradient_accumulation:
raise ValueError('Gradient accumulation makes sense for Adagrad only.')
elif optimizer_type == 'adam':
if adam_parameters is None:
adam_parameters = AdamParameters(0.9, 0.999, 1e-8)
if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.:
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(
adam_parameters.beta1))
if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.:
raise ValueError('beta2 must be between 0. and 1; got {}.'.format(
adam_parameters.beta2))
if adam_parameters.epsilon <= 0.:
raise ValueError('epsilon must be positive; got {}.'.format(
adam_parameters.epsilon))
else:
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
return super(EmbeddingConfigSpec, cls).__new__(
cls,
feature_columns=feature_columns,
learning_rate=learning_rate,
optimizer_type=optimizer_type,
adagrad_initial_accumulator=adagrad_initial_accumulator,
clipping_limit=clipping_limit,
use_gradient_accumulation=use_gradient_accumulation,
adam_parameters=adam_parameters)
class EmbeddingConfig(object):
"""This is the internal immutable object for embedding config.
`_EmbeddingConfig` is responsible for _translating_ the user-provided
`EmbeddingConfigSpec` into internal data structures, mostly constructor
arguments of `TPUEmbedding`.
"""
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
num_hosts, num_cores, master):
self._embedding_config_spec = embedding_config_spec
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._num_hosts = num_hosts
self._num_cores = num_cores
self._master = master
self._table_to_config_dict, self._feature_to_table_dict = (
get_tpu_embedding_config_from_feature_columns(
embedding_config_spec.feature_columns))
self._optimization_parameters = _get_tpu_embedding_optimization_parameters(
self._embedding_config_spec)
self._mode_to_tpu_embedding_dict = {}
self.dummy_table_variables = None
def has_embedding_tables(self):
return bool(self._table_to_config_dict)
def _create_tpu_embedding(self, mode):
"""Create tpu_embedding.TPUEmbedding based on mode."""
if mode == model_fn_lib.ModeKeys.TRAIN:
batch_size = self._train_batch_size
else:
batch_size = self._eval_batch_size
if mode == model_fn_lib.ModeKeys.TRAIN:
tpu_embedding_mode = tpu_embedding.TRAINING
elif (mode == model_fn_lib.ModeKeys.EVAL or
mode == model_fn_lib.ModeKeys.PREDICT):
tpu_embedding_mode = tpu_embedding.INFERENCE
else:
raise ValueError('Mode {} is not supported.'.format(mode))
tpu_embedding_ = tpu_embedding.TPUEmbedding(
self._table_to_config_dict,
self._feature_to_table_dict,
batch_size,
tpu_embedding_mode,
self._master,
self._optimization_parameters,
)
return tpu_embedding_
def get_tpu_embedding(self, mode):
if mode not in self._mode_to_tpu_embedding_dict:
self._mode_to_tpu_embedding_dict[mode] = (
self._create_tpu_embedding(mode))
return self._mode_to_tpu_embedding_dict[mode]
def split_inputs(ctx, features, labels):
"""Splits the dense and sparse tensors inside the features and labels."""
sparse_features = collections.OrderedDict()
if ctx.embedding_config:
tpu_embedding_ = ctx.embedding_config.tpu_embedding
for feature_key in tpu_embedding_.feature_to_table_dict:
sparse_features[feature_key] = features.pop(feature_key)
return features, labels, sparse_features
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import *
# pylint: enable=wildcard-import,unused-import
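
A hedged sketch of constructing the EmbeddingConfigSpec defined above; embedding_columns stands in for a list of TPU embedding feature columns and the hyperparameter values are illustrative, not taken from the commit:

# embedding_columns: hypothetical list of _TPUEmbeddingColumn /
# _TPUSharedEmbeddingColumn instances built elsewhere.
spec = EmbeddingConfigSpec(
    feature_columns=embedding_columns,
    learning_rate=0.05,
    optimizer_type='adagrad',         # 'sgd' and 'adam' are also accepted
    adagrad_initial_accumulator=0.1)  # defaults to 0.1 when left as None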

View File

@@ -1,212 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
allow execution to continue on the main thread.
"""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None):
"""Initializes a `CheckpointSaverHook`.
Args:
checkpoint_dir: `str`, base directory for the checkpoint files.
save_secs: `int`, save every N secs.
save_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object.
listeners: List of `CheckpointSaverListener` subclass instances. Used for
callbacks that run immediately before or after this hook saves the
checkpoint.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
self._save_thread = None
self._write_graph_thread = None
self._checkpoint_dir = checkpoint_dir
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold
self._timer = basic_session_run_hooks.SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
self._listeners = listeners or []
self._steps_per_run = 1
self._summary_writer = None
self._global_step_tensor = None
self._last_checkpoint_step = None
def _set_steps_per_run(self, steps_per_run):
self._steps_per_run = steps_per_run
def begin(self):
self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use CheckpointSaverHook.")
for l in self._listeners:
l.begin()
def after_create_session(self, session, coord):
global_step = session.run(self._global_step_tensor)
# We write the graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks change the graph and
# add variables in begin. The graph is finalized after all begin calls.
def _write_graph_fn(self):
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
self._write_graph_thread = threading.Thread(target=_write_graph_fn,
args=[self])
self._write_graph_thread.start()
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
logging.info("Triggering checkpoint. %s", global_step)
if self._save(run_context.session, global_step):
run_context.request_stop()
def end(self, session):
if self._save_thread:
logging.info("Waiting for any pending checkpoints to finish.")
self._save_thread.join()
if self._write_graph_thread:
logging.info("Waiting for any pending write_graph to finish.")
self._write_graph_thread.join()
last_step = session.run(self._global_step_tensor)
if self._last_checkpoint_step != last_step:
self._save(session, last_step, asynchronous=False)
for l in self._listeners:
l.end(session, last_step)
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
# Skip saving on step 0
if step == 0:
return
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
start_time = time.time()
for l in self._listeners:
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
self._summary_writer.add_session_log(
SessionLog(
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
step)
for l in self._listeners:
l.after_save(session, step)
end_time = time.time()
logging.info("Checkpoint actual writing time: (%.3f sec)",
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
if not asynchronous:
self._last_checkpoint_step = step
_save_fn()
return
if self._save_thread is not None:
self._save_thread.join(timeout=0.1)
if self._save_thread.is_alive():
logging.info("Saver thread still in progress, skipping checkpoint.")
return
self._last_checkpoint_step = step
self._save_thread = threading.Thread(target=_save_fn)
self._save_thread.start()
def _get_saver(self):
if self._saver is not None:
return self._saver
elif self._scaffold is not None:
return self._scaffold.saver
# Get saver from the SAVERS collection if present.
collection_key = ops.GraphKeys.SAVERS
savers = ops.get_collection(collection_key)
if not savers:
raise RuntimeError(
"No items in collection {}. Please add a saver to the collection "
"or provide a saver or scaffold.".format(collection_key))
elif len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
self._saver = savers[0]
return savers[0]
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.async_checkpoint import *
# pylint: enable=wildcard-import,unused-import
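
A hedged usage sketch for the AsyncCheckpointSaverHook above; the checkpoint directory, save interval, and train_op are placeholders for user code:

import tensorflow as tf

hook = AsyncCheckpointSaverHook(checkpoint_dir='/tmp/model', save_steps=1000)
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
  while not sess.should_stop():
    # train_op is assumed to be defined by the surrounding training code.
    sess.run(train_op)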

View File

@@ -1,77 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper context for running models with bfloat16."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_contextlib
def _get_custom_getter():
"""Returns a custom getter that this class's methods must be called under.
All methods of this class must be called under a variable scope that was
passed this custom getter. Example:
```python
network = ConvNetBuilder(...)
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
network.conv(...)
# Call more methods of network here
```
Currently, this custom getter only does anything if self.use_tf_layers is
True. In that case, it causes variables to be stored as dtype
self.variable_type, then cast to the requested dtype, instead of directly
storing the variable as the requested dtype.
"""
def inner_custom_getter(getter, *args, **kwargs):
"""Custom getter that forces variables to have type self.variable_type."""
cast_to_bfloat16 = False
requested_dtype = kwargs['dtype']
if requested_dtype == dtypes.bfloat16:
# Only change the variable dtype if doing so does not decrease variable
# precision.
kwargs['dtype'] = dtypes.float32
cast_to_bfloat16 = True
var = getter(*args, **kwargs)
# This if statement is needed to guard the cast, because batch norm
# assigns directly to the return value of this custom getter. The cast
# makes the return value not a variable so it cannot be assigned. Batch
# norm variables are always in fp32 so this if statement is never
# triggered for them.
if cast_to_bfloat16:
var = math_ops.cast(var, dtypes.bfloat16)
return var
return inner_custom_getter
@tf_contextlib.contextmanager
def bfloat16_scope():
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
"""
with variable_scope.variable_scope(
'', custom_getter=_get_custom_getter()) as varscope:
yield varscope
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.bfloat16 import *
# pylint: enable=wildcard-import,unused-import
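
A hedged sketch of bfloat16_scope as described in the docstring above; my_model and features are placeholders for user code:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu.bfloat16 import bfloat16_scope

with bfloat16_scope():
  # Variables created here are stored as float32 and read back as bfloat16.
  logits = my_model(features)
# Typically cast back to float32 before computing the loss.
logits = math_ops.cast(logits, dtypes.float32)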

View File

@@ -1,191 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of Cloud TPU helper functions for data loading."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops
def _TextLineDataset(filename):
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
return dataset
def _TFRecordDataset(filename):
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
return dataset
_FILETYPE_MAP = {
'tfrecord': _TFRecordDataset,
'textline': _TextLineDataset,
'text': _TextLineDataset,
}
def StreamingFilesDataset(files,
filetype=None,
file_reader_job=None,
worker_job=None,
num_epochs=None,
filename_shuffle_buffer_size=None,
num_parallel_reads=None,
batch_transfer_size=None,
sloppy=None):
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
files local to your GCE VM. In order to train using files stored on your local
VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
helper to generate a dataset to feed your Cloud TPU with files from your GCE
VM.
The resulting dataset may return an OutOfRangeError if there are no files
found as a result of the fileglob expansion.
Note: StreamingFilesDataset assumes that the session is using a
TPUClusterResolver and therefore has a worker and a coordinator job. File
loading will be done on the coordinator job.
Args:
files: A string glob to match files, or a `tf.data.Dataset` generating file
names.
filetype: A string (one of 'tfrecord', or 'textline') or a single-argument
TensorFlow function that when given a filename returns a dataset.
file_reader_job: An optional string that corresponds to the job that should
perform the file reads.
worker_job: An optional string that corresponds to the job that should
process the tensors (i.e. your GPU or TPU worker).
num_epochs: The number of epochs through the training set that should be
generated. By default, it will repeat infinitely.
filename_shuffle_buffer_size: An optional integer whose value controls the
shuffling of the file names. If you would like to read from the files in
the same order, set to 0 or False.
num_parallel_reads: An optional integer controlling the number of files to
read from concurrently. (Set to 1 for no parallelism.)
batch_transfer_size: An optional integer controlling the batching used to
amortize the remote function invocation overhead. Set to a very large
number to increase throughput. Set to a very small number to reduce memory
consumption. Set to False to skip batching.
sloppy: (Optional.) If `False`, read input data while maintaining a
deterministic order. (This may have significant performance impacts.)
sloppy defaults to True.
Returns:
A `tf.data.Dataset` with an infinite stream of elements generated by a
parallel interleaving of the set of files matched (or generated) by `files`
whose element type is the output of the dataset specified by `filetype`.
Raises:
ValueError: if any argument is not of the expected type.
"""
if filetype is None:
filetype = 'tfrecord'
if isinstance(filetype, str):
if filetype not in _FILETYPE_MAP:
raise ValueError('Unexpected filetype: %s' % filetype)
reader_fn = _FILETYPE_MAP[filetype]
elif callable(filetype):
reader_fn = filetype
else:
raise ValueError('filetype should be a string or a callable')
file_reader_job = file_reader_job or 'coordinator'
worker_job = worker_job or 'worker'
if filename_shuffle_buffer_size is None:
filename_shuffle_buffer_size = 4096
num_parallel_reads = num_parallel_reads or 8
if batch_transfer_size is None:
batch_transfer_size = 256
if sloppy is None:
sloppy = True
with ops.device('/job:%s' % file_reader_job):
if isinstance(files, str):
source_dataset = dataset_ops.Dataset.list_files(files)
elif isinstance(files, dataset_ops.DatasetV2):
source_dataset = files
else:
raise ValueError('files was not a string or a dataset: %s' % files)
if filename_shuffle_buffer_size:
source_dataset = source_dataset.shuffle(
buffer_size=filename_shuffle_buffer_size)
source_dataset = source_dataset.apply(
interleave_ops.parallel_interleave(
reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
source_dataset = source_dataset.repeat(num_epochs)
if batch_transfer_size:
source_dataset = source_dataset.batch(batch_transfer_size)
source_dataset = source_dataset.prefetch(1)
source_iterator = dataset_ops.make_one_shot_iterator(source_dataset)
source_handle = source_iterator.string_handle()
@function.Defun(dtypes.string)
def LoadingFunc(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, source_dataset.output_types, source_dataset.output_shapes)
return remote_iterator.get_next()
def MapFn(unused_input):
if isinstance(source_dataset.output_types, dtypes.DType):
output_types = [source_dataset.output_types]
elif isinstance(source_dataset.output_types, (list, tuple)):
output_types = source_dataset.output_types
else:
raise ValueError('source dataset has invalid output types')
remote_calls = functional_ops.remote_call(
args=[source_handle],
Tout=output_types,
f=LoadingFunc,
target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
if len(remote_calls) == 1:
return remote_calls[0]
else:
return remote_calls
with ops.device('/job:%s' % worker_job):
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
MapFn, num_parallel_calls=4 if sloppy else None)
output_dataset = output_dataset.prefetch(1)
if batch_transfer_size:
# Undo the batching used during the transfer.
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
return output_dataset
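# --- Illustrative sketch (editor's addition, not part of the original file) ---
# A minimal example of how a caller might use StreamingFilesDataset to feed a
# Cloud TPU from files stored on the coordinator VM's local SSD. The file glob
# and job names below are hypothetical placeholders.
def _example_streaming_files_usage():
  dataset = StreamingFilesDataset(
      '/mnt/disks/ssd0/train-*.tfrecord',  # hypothetical local file glob
      filetype='tfrecord',
      file_reader_job='coordinator',
      worker_job='worker')
  iterator = dataset_ops.make_one_shot_iterator(dataset)
  # Each element is one serialized tf.Example record (a scalar string).
  return iterator.get_next()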
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.datasets import *
# pylint: enable=wildcard-import,unused-import


@ -1,313 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.tpu.topology import Topology
SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]]
def _compute_task_and_cores_to_replicas(core_assignment, topology):
"""Computes a nested dict which maps task and logical core to replicas."""
task_and_cores_to_replicas = {}
for replica in xrange(core_assignment.shape[0]):
for logical_core in xrange(core_assignment.shape[1]):
coordinates = core_assignment[replica, logical_core, :]
task_id = topology.task_ordinal_at_coordinates(coordinates)
if task_id not in task_and_cores_to_replicas:
task_and_cores_to_replicas[task_id] = {}
if logical_core not in task_and_cores_to_replicas[task_id]:
task_and_cores_to_replicas[task_id][logical_core] = set()
task_and_cores_to_replicas[task_id][logical_core].add(replica)
task_to_sorted_replica_id = {}
for task, core_to_replicas in task_and_cores_to_replicas.items():
core_to_sorted_replicas = {}
for core, replicas in core_to_replicas.items():
core_to_sorted_replicas[core] = sorted(replicas)
task_to_sorted_replica_id[task] = core_to_sorted_replicas
return task_to_sorted_replica_id
class DeviceAssignment(object):
"""Mapping from logical cores in a computation to the physical TPU topology.
Prefer to use the `device_assignment()` helper to construct a
`DeviceAssignment`; it is easier, if less flexible, than constructing a
`DeviceAssignment` directly.
"""
def __init__(self, topology, core_assignment):
"""Constructs a `DeviceAssignment` object.
Args:
topology: A `Topology` object that describes the physical TPU topology.
core_assignment: A logical to physical core mapping, represented as a
rank 3 numpy array. See the description of the `core_assignment`
property for more details.
Raises:
ValueError: If `topology` is not a `Topology` object.
ValueError: If `core_assignment` is not a rank 3 numpy array.
"""
if not isinstance(topology, Topology):
raise ValueError("topology must be a Topology object, got {}".format(
type(topology)))
core_assignment = np.asarray(core_assignment, dtype=np.int32)
self._topology = topology
if core_assignment.ndim != 3:
raise ValueError("core_assignment must be a rank 3 numpy array, "
"got shape {}".format(core_assignment.shape))
self._num_replicas = core_assignment.shape[0]
self._num_cores_per_replica = core_assignment.shape[1]
if core_assignment.shape[-1] != topology.mesh_rank:
raise ValueError(
"minor dimension of core_assignment must have size equal to topology "
"rank ({}), got shape {}".format(topology.mesh_rank,
core_assignment.shape))
self._core_assignment = core_assignment
self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
self._core_assignment, topology)
@property
def topology(self):
"""A `Topology` that describes the TPU topology."""
return self._topology
@property
def num_cores_per_replica(self):
"""The number of cores per replica."""
return self._num_cores_per_replica
@property
def num_replicas(self):
"""The number of replicas of the computation."""
return self._num_replicas
@property
def core_assignment(self):
"""The logical to physical core mapping.
Returns:
An integer numpy array of rank 3, with shape
`[num_replicas, num_cores_per_replica, topology_rank]`. Maps
(replica, logical core) pairs to physical topology coordinates.
"""
return self._core_assignment
def coordinates(self, replica, logical_core):
"""Returns the physical topology coordinates of a logical core."""
return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
"""Lookup replica ids by task number and logical core.
Args:
task_id: TensorFlow task number.
logical_core: An integer, identifying a logical core.
Returns:
A sorted list of the replicas that are attached to that task and
logical_core.
Raises:
ValueError: If no replica exists in the task which contains the logical
core.
"""
try:
return self._task_and_cores_to_replicas[task_id][logical_core]
except KeyError:
raise ValueError(
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
def tpu_ordinal(self, replica=0, logical_core=0):
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
def host_device(self, replica=0, logical_core=0, job=None):
"""Returns the CPU device attached to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
def tpu_device(self, replica=0, logical_core=0, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
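# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Given a `DeviceAssignment` built elsewhere (for example by the
# `device_assignment()` helper below), these accessors map a
# (replica, logical_core) pair back to TensorFlow device names.
def _example_device_assignment_queries(assignment):
  tpu_name = assignment.tpu_device(replica=0, logical_core=0)
  host_name = assignment.host_device(replica=0, logical_core=0)
  ordinal = assignment.tpu_ordinal(replica=0, logical_core=0)
  return tpu_name, host_name, ordinal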
def device_assignment(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1):
"""Computes a device_assignment of a computation across a TPU topology.
Attempts to choose a compact grid of cores for locality.
Returns a `DeviceAssignment` that describes the cores in the topology assigned
to each core of each replica.
`computation_shape` and `computation_stride` values should be powers of 2 for
optimal packing.
Args:
topology: A `Topology` object that describes the TPU cluster topology.
To obtain a TPU topology, evaluate the `Tensor` returned by
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
computation_shape: A rank 1 int32 numpy array with size equal to the
topology rank, describing the shape of the computation's block of cores.
If None, the `computation_shape` is `[1] * topology_rank`.
computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
describing the inter-core spacing of the `computation_shape` cores in the
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
Returns:
A DeviceAssignment object, which describes the mapping between the logical
cores in each computation replica and the physical cores in the TPU
topology.
Raises:
ValueError: If `topology` is not a valid `Topology` object.
ValueError: If `computation_shape` or `computation_stride` are not 1D int32
numpy arrays with shape [3] where all values are positive.
ValueError: If computation's replicas cannot fit into the TPU topology.
"""
# Deserialize the Topology proto, if it is a string.
if isinstance(topology, bytes):
topology = Topology(serialized=topology)
if not isinstance(topology, Topology):
raise ValueError("`topology` is not a Topology object; got {}".format(
type(topology)))
topology_rank = len(topology.mesh_shape)
mesh_shape = topology.mesh_shape
if computation_shape is None:
computation_shape = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_shape = np.asarray(computation_shape, dtype=np.int32)
if computation_stride is None:
computation_stride = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_stride = np.asarray(computation_stride, dtype=np.int32)
if computation_shape.shape != (topology_rank,):
raise ValueError("computation_shape must have shape [{}]; got {}".format(
topology_rank, computation_shape.shape))
if computation_stride.shape != (topology_rank,):
raise ValueError("computation_stride must have shape [{}]; got {}".format(
topology_rank, computation_stride.shape))
if any(computation_shape < 1):
raise ValueError(
"computation_shape must be positive; got computation_shape={}".format(
computation_shape))
if any(computation_stride < 1):
raise ValueError(
"computation_stride must be positive; got computation_stride={}".format(
computation_stride))
# Computes the physical size of one computation instance.
computation_footprint = computation_shape * computation_stride
if any(computation_footprint > mesh_shape):
raise ValueError(
"computation footprint {} does not fit in TPU topology shape {}".format(
computation_footprint, mesh_shape))
# Computes how many copies of the computation footprint fit in the mesh.
block_counts = mesh_shape // computation_footprint
replica_counts = block_counts * computation_stride
max_replicas = np.prod(replica_counts)
if num_replicas > max_replicas:
raise ValueError(
"requested {} replicas but only {} replicas with shape {} and "
"computation_stride {} fit in a TPU mesh of shape {}".format(
num_replicas, max_replicas, computation_shape, computation_stride,
mesh_shape))
def ceil_of_ratio(n, m):
return (n + m - 1) // m
replica_shape = [0] * topology_rank
if num_replicas > 0:
remaining_replicas = num_replicas
remaining_dims = topology_rank
# Choose dimensions as close to an equal cube as possible, in order of
# increasing dimension size. By visiting dimensions in increasing size, we
# assign the most constrained dimension first, so we won't make infeasible
# choices.
#
# As a secondary sort order, visit the dimensions in reverse order. This
# means we try to use both cores on the same chip in preference to two cores
# on different chips.
for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
i = -ni
target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
replica_shape[i] = min(target_size, x)
remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
remaining_dims -= 1
assert remaining_replicas == 1 and remaining_dims == 0
# Assigns an offset to each replica such that no two replicas overlap.
replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
for replica in xrange(num_replicas):
# Chooses a replica number in each axis.
t = replica
pos = []
for dim in replica_shape[::-1]:
pos.append(t % dim)
t //= dim
replica_pos = np.array(pos[::-1], dtype=np.int32)
# Determines where that replica starts in each axis.
outer = replica_pos // computation_stride
inner = replica_pos % computation_stride
replica_offsets[replica, :] = outer * computation_footprint + inner
# Computes a complete logical core -> physical core mapping for each replica.
indices = [
np.arange(0, computation_shape[i] * computation_stride[i],
computation_stride[i]) for i in xrange(topology_rank)
]
indices = np.concatenate(
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
axis=-1)
indices = indices.reshape((-1, topology_rank))
assignment = indices + replica_offsets[:, np.newaxis, :]
return DeviceAssignment(topology, core_assignment=assignment)
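# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Builds a small model-parallel device assignment from the serialized
# `TopologyProto` returned by running the TPU system initialization op. The
# computation shape below (two cores of the same chip per replica) is a
# hypothetical choice; session setup is assumed to happen in the caller.
def _example_device_assignment(serialized_topology):
  topology = Topology(serialized=serialized_topology)
  return device_assignment(
      topology,
      computation_shape=[1, 1, 2],   # two cores per replica
      computation_stride=[1, 1, 1],
      num_replicas=2)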
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.device_assignment import *
# pylint: enable=wildcard-import,unused-import,redefined-builtin


@ -1,132 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""ErrorRendezvous handler for collecting errors from multiple threads."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import sys
import threading
import time
import six
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
_UNINTERESTING_ERRORS = (errors.CancelledError,)
class ErrorRendezvous(object):
"""Resolve errors from multiple threads during TPU execution.
TPU errors can occur on the infeed or outfeed threads as well as the main
training thread.
Depending on which thread "wins" and receives the session error first, we may
end up showing users a confusing and non-actionable error message (session
cancelled) instead of a root cause (e.g. a bad filename).
The rendezvous object provides a location to capture these errors until all
threads terminate. At that point we can choose the most informative error
to report.
"""
def __init__(self, num_sources):
# string -> (message, traceback)
self._errors = {}
self._num_sources = num_sources
self._session_cancel_timer = None
def record_error(self, source, exc_info, session=None):
"""Report an exception from the given source.
If a session is passed, a timer will be registered to close it after a few
seconds. This is necessary to ensure the main training loop does not hang
if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
interesting error from another thread to propagate.
Args:
source: string, source of the error
exc_info: Output from `sys.exc_info` (type, value, traceback)
session: Session to close after delay.
"""
_, value, _ = exc_info
self._errors[source] = exc_info
logging.info('Error recorded from %s: %s', source, value)
if session is not None and self._session_cancel_timer is None:
def _cancel_session():
time.sleep(5)
try:
session.close()
except: # pylint: disable=bare-except
pass
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
self._session_cancel_timer.daemon = True
self._session_cancel_timer.start()
def record_done(self, source):
"""Mark execution source `source` as done.
If an error was originally reported from `source` it is left intact.
Args:
source: `str`, source being recorded
"""
logging.info('%s marked as finished', source)
if source not in self._errors:
self._errors[source] = None
@contextlib.contextmanager
def catch_errors(self, source, session=None):
"""Context manager to report any errors within a block."""
try:
yield
except Exception: # pylint: disable=broad-except
self.record_error(source, sys.exc_info(), session)
def raise_errors(self, timeout_sec=0):
"""Wait for up to `timeout` seconds for all error sources to finish.
Preferentially raise "interesting" errors (errors not in the
_UNINTERESTING_ERRORS set).
Args:
timeout_sec: Seconds to wait for other error sources.
"""
for _ in range(timeout_sec):
if len(self._errors) == self._num_sources:
break
time.sleep(1)
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
# First check for any interesting errors, then fall back on the session
# cancelled errors etc.
for k, (typ, value, traceback) in kept_errors:
if isinstance(value, _UNINTERESTING_ERRORS):
continue
else:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
for k, (typ, value, traceback) in kept_errors:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
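# --- Illustrative sketch (editor's addition, not part of the original file) ---
# How a training loop with one infeed thread might share a rendezvous so that
# the most informative error wins. `train_fn` and `infeed_fn` are hypothetical
# callables supplied by the caller.
def _example_error_rendezvous(sess, train_fn, infeed_fn):
  rendezvous = ErrorRendezvous(num_sources=2)

  def _infeed_thread_fn():
    with rendezvous.catch_errors(source='infeed', session=sess):
      infeed_fn()
    rendezvous.record_done('infeed')

  infeed_thread = threading.Thread(target=_infeed_thread_fn)
  infeed_thread.start()
  with rendezvous.catch_errors(source='training', session=sess):
    train_fn()
  rendezvous.record_done('training')
  infeed_thread.join()
  # Re-raise the most interesting captured error, if any.
  rendezvous.raise_errors(timeout_sec=5)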
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.error_handling import *
# pylint: enable=wildcard-import,unused-import


@ -1,435 +1,30 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
# pylint: disable=protected-access
_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
fc._VocabularyFileCategoricalColumn,
fc._VocabularyListCategoricalColumn,
fc._WeightedCategoricalColumn,
fc_lib.IdentityCategoricalColumn,
fc_lib.VocabularyFileCategoricalColumn,
fc_lib.VocabularyListCategoricalColumn,
fc_lib.WeightedCategoricalColumn)
def embedding_column(categorical_column,
dimension,
combiner='mean',
initializer=None):
"""TPU embedding_column for `tf.feature_column.embedding_column`.
Note that the interface for TPU embedding_column is different from the non-TPU
version. The following args available for the non-TPU version are NOT
supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.
Args:
categorical_column: A categorical_column returned from
categorical_column_with_identity, weighted_categorical_column,
categorical_column_with_vocabulary_list or
categorical_column_with_vocabulary_file.
dimension: An integer specifying dimension of the embedding, must be > 0.
combiner: A string specifying how to reduce if there are multiple entries
in a single row. For more information, see
`tf.feature_column.embedding_column`.
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean `0.0` and standard deviation
`1/sqrt(dimension)`.
Returns:
A _TPUEmbeddingColumn.
Raises:
ValueError: if `dimension` not > 0.
ValueError: if `initializer` is specified but not callable.
"""
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '
' embedding_column must be type %s, got %s.' % (' or '.join([
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
]), type(categorical_column)))
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified. '
'Embedding of column_name: {}'.format(
categorical_column.name))
if initializer is None:
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
def _creator(weight_collections, scope):
embedding_column_layer = fc._EmbeddingColumnLayer(
embedding_shape=embedding_shape,
initializer=initializer,
weight_collections=weight_collections,
trainable=True,
name='embedding_column_layer')
return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
column = _TPUEmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
layer_creator=_creator,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
# For Embedding column, the initializer is hidden inside the creator Fn, which
# is not accessible later. So, we attach it to a special field. Also note
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
# initializer differently. See shared_embedding_columns for details.
column._tpu_initializer = initializer
return column
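# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Creating a TPU embedding column from a v1 identity categorical column. The
# feature name, vocabulary size, and dimension are hypothetical.
def _example_tpu_embedding_column():
  video_id = fc.categorical_column_with_identity(
      key='video_id', num_buckets=1000000)
  return embedding_column(video_id, dimension=64, combiner='mean')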
def shared_embedding_columns(categorical_columns,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None):
"""List of dense columns that convert from sparse, categorical input."""
for categorical_column in categorical_columns:
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '
' shared_embedding_columns must be type %s, got %s.' % (' or '.join([
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
]), type(categorical_column)))
columns = fc_lib.shared_embedding_columns(
categorical_columns,
dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
# Use the initializer and shared_embedding_collection_name to create TPU
# version
initializer = columns[0].initializer
shared_embedding_collection_name = columns[0].shared_embedding_collection_name
tpu_columns = []
# Create the state (_SharedEmbeddingColumnLayer) here.
for categorical_column in categorical_columns:
column = _TPUSharedEmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
tpu_columns.append(column)
return tpu_columns
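# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Two categorical features sharing a single embedding table on TPU. Feature
# names and sizes are hypothetical.
def _example_tpu_shared_embedding_columns():
  query_token = fc.categorical_column_with_identity(
      key='query_token', num_buckets=50000)
  title_token = fc.categorical_column_with_identity(
      key='title_token', num_buckets=50000)
  return shared_embedding_columns([query_token, title_token], dimension=32)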
class _TPUBaseEmbeddingColumn(object):
"""Base class for TPU Embedding Column."""
def __init__(self, categorical_column):
self._tpu_categorical_column = categorical_column
def get_combiner(self):
"""Returns the embedding combiner."""
raise NotImplementedError('not implemented')
def get_embedding_table_size(self):
"""Returns the embedding table size, tuple of vocab size and dimension."""
raise NotImplementedError('not implemented')
def get_feature_key_name(self):
"""Returns the feature key name in the features dict."""
raise NotImplementedError('not impl')
def get_weight_key_name(self):
"""Return the key name for weights."""
raise NotImplementedError('not impl')
def get_embedding_var_name(self):
"""Returns the embedding variable name.
Feature key name and embedding variable name usually map one-to-one, but for
shared embedding columns the mapping is many-to-one.
"""
raise NotImplementedError('not impl')
def get_initializer(self):
"""Returns the initializer."""
raise NotImplementedError('not impl')
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
raise NotImplementedError('not impl')
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
"""Core Embedding Column."""
def __new__(cls,
categorical_column,
dimension,
combiner='mean',
layer_creator=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
# are not supported on TPU. They are solely for matching the signature of
# __new__ of parent class fc._EmbeddingColumn.
return fc._EmbeddingColumn.__new__(
cls,
categorical_column,
dimension,
combiner=combiner,
layer_creator=layer_creator,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable)
def __init__(self,
categorical_column,
dimension,
combiner='mean',
layer_creator=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
self._key = None
def get_combiner(self):
return self.combiner
def get_embedding_table_size(self):
"""Returns num_ids and width."""
return (self.categorical_column._num_buckets, self.dimension)
def get_feature_key_name(self):
"""get_feature_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.categorical_column.name
return self.categorical_column.name
def get_weight_key_name(self):
"""get_weight_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.weight_feature_key
return None
def get_embedding_var_name(self):
"""get_embedding_var_name."""
return self.categorical_column.name
def get_initializer(self):
return self._tpu_initializer
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
if isinstance(
self.categorical_column,
(
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
fc_lib.WeightedCategoricalColumn)):
return True
return False
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
def host_computation():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
# TPU mode
# Get the embeddings from the LazyBuilder.
tensor = inputs.get(self.get_feature_key_name())
# Add to collection for _create_tpu_embedding_variables_and_ops
_record_variable_scope_and_name(self.get_embedding_var_name(),
'embedding_weights')
return tensor
class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
fc._SharedEmbeddingColumn):
"""Core Shared Embedding Column."""
def __new__(cls,
categorical_column,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
return fc._SharedEmbeddingColumn.__new__(
cls,
categorical_column,
dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable)
def __init__(self,
categorical_column,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
self._key = None
def get_combiner(self):
return self.combiner
def get_embedding_table_size(self):
"""Returns num_ids and width."""
return (self.categorical_column._num_buckets, self.dimension)
def get_feature_key_name(self):
"""get_feature_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.categorical_column.name
return self.categorical_column.name
def get_weight_key_name(self):
"""get_weight_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.weight_feature_key
return None
def get_embedding_var_name(self):
"""get_embedding_var_name."""
return self.shared_embedding_collection_name
def get_initializer(self):
return self.initializer
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
if isinstance(
self.categorical_column,
(
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
fc_lib.WeightedCategoricalColumn)):
return True
return False
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
def host_computation():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
# TPU mode
# Get the embeddings from the LazyBuilder.
tensor = inputs.get(self.get_feature_key_name())
# Add to collection for _create_tpu_embedding_variables_and_ops
_record_variable_scope_and_name(
self.get_embedding_var_name(),
'embedding_weights',
is_shared_embedding=True)
return tensor
def _record_variable_scope_and_name(embedding_var_name,
embedding_var_name_in_fc,
is_shared_embedding=False):
"""Add embedding variable name and scope to collection."""
g = ops.get_default_graph()
collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
if not collection:
collection.append({})
var_def_dict = collection[0]
captured_scope = variable_scope.get_variable_scope()
captured_scope_name = captured_scope.name
if embedding_var_name in var_def_dict:
if (var_def_dict[embedding_var_name][0] != captured_scope_name
and not is_shared_embedding):
raise ValueError(
'For embedding var name {}, the variable scope name is different, '
'got {}; expected {}'.format(embedding_var_name,
captured_scope_name,
var_def_dict[embedding_var_name][0]))
if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
raise ValueError(
'For embedding var name {}, the embedding name is different, '
'got {}; expected {}'.format(embedding_var_name,
embedding_var_name_in_fc,
var_def_dict[embedding_var_name][1]))
else:
var_def_dict[embedding_var_name] = (captured_scope_name,
embedding_var_name_in_fc)
def _is_running_on_cpu():
"""Returns True if the current context is CPU model."""
return tpu_function.get_tpu_context().number_of_shards is None
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.feature_column import *
# used by tests
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _TPU_FC_TO_SCOPE
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.tpu.feature_column import _TPUEmbeddingColumn
from tensorflow.python.tpu.feature_column import _TPUSharedEmbeddingColumn
# pylint: enable=wildcard-import,unused-import


@ -1,23 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.ops import tpu_ops
TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.functional import *
# pylint: enable=wildcard-import,unused-import


@ -1,438 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Operations for handling session logging and shutdown notifications."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from google.protobuf import text_format
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
_WATCHDOG = None
class CoordinatorShutdownException(Exception):
"""Raised when the coordinator needs to shutdown."""
pass
def _clone_session(session, graph=None):
return session_lib.Session(
target=session.sess_str,
config=session._config, # pylint: disable=protected-access
graph=graph if graph else session.graph)
def _make_heartbeat_op(session, device, request_ph):
"""Return a heartbeat op or None if heartbeats are not supported by device."""
try:
# Test if we can connect in an isolated graph + session
with ops.Graph().as_default():
with _clone_session(session) as temp_session:
with ops.device(device):
heartbeat_op = tpu_ops.worker_heartbeat('')
options = config_pb2.RunOptions(timeout_in_ms=5000)
temp_session.run(heartbeat_op, options=options)
except errors.InvalidArgumentError as _:
logging.warning('Error running heartbeat on %s', device)
return None
except errors.DeadlineExceededError as _:
logging.warning('Timeout connecting to %s when testing heartbeat', device)
return None
# If we successfully connected and pinged the worker, go ahead and construct
# the operation.
with ops.device(device):
return tpu_ops.worker_heartbeat(request_ph)
class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers."""
def __init__(self, session, devices, heartbeat_ops, request_placeholder):
"""Construct a new WorkerHeartbeatManager.
(Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
Args:
session: `tf.Session`, session to use for heartbeat operations.
devices: `list[string]` Set of devices to connect to.
heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
the WorkerHeartbeatRequest protocol buffer.
"""
self._session = session
self._devices = devices
self._ops = heartbeat_ops
self._request_placeholder = request_placeholder
@staticmethod
def from_devices(session, devices):
"""Construct a heartbeat manager for the given devices."""
if not devices:
logging.error('Trying to create heartbeat manager with no devices?')
logging.info('Creating heartbeat manager for %s', devices)
request_placeholder = array_ops.placeholder(
name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = []
kept_devices = []
for device in devices:
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
if heartbeat_op is not None:
kept_devices.append(device)
heartbeat_ops.append(heartbeat_op)
else:
logging.warning('Heartbeat support not available for %s', device)
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
request_placeholder)
def num_workers(self):
return len(self._devices)
def configure(self, message):
"""Configure heartbeat manager for all devices.
Args:
message: `event_pb2.WorkerHeartbeatRequest`
Returns: `None`
"""
logging.info('Configuring worker heartbeat: %s',
text_format.MessageToString(message))
self._session.run(self._ops,
{self._request_placeholder: message.SerializeToString()})
def ping(self, request=None, timeout_in_ms=5000):
"""Ping all workers, returning the parsed status results."""
if request is None:
request = event_pb2.WorkerHeartbeatRequest()
options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
results = self._session.run(
self._ops,
feed_dict={self._request_placeholder: request.SerializeToString()},
options=options)
parsed_results = [
event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
for res_pb in results
]
logging.debug('Ping results: %s', parsed_results)
return parsed_results
def lame_workers(self):
"""Ping all workers, returning manager containing lame workers (or None)."""
ping_results = self.ping()
lame_workers = []
for ping_response, device, op in zip(ping_results, self._devices,
self._ops):
if ping_response.health_status != event_pb2.OK:
lame_workers.append((device, op))
if not lame_workers:
return None
bad_devices, bad_ops = zip(*lame_workers)
return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
self._request_placeholder)
def __repr__(self):
return 'HeartbeatManager(%s)' % ','.join(self._devices)
def shutdown(self, timeout_ms=10000):
"""Shutdown all workers after `shutdown_timeout_secs`."""
logging.info('Shutting down %s.', self)
req = event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
self.configure(req)
# Wait for workers to shutdown. This isn't strictly required
# but it avoids triggering multiple checkpoints with the same lame worker.
logging.info('Waiting %dms for worker shutdown.', timeout_ms)
time.sleep(timeout_ms / 1000)
def all_worker_devices(session):
"""Return a list of devices for each worker in the system."""
devices = session.list_devices()
return [
device.name
for device in devices
if ':CPU:' in device.name and 'coordinator' not in device.name
]
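# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Pings every worker reachable from `session` and shuts down any worker that
# reports an unhealthy status, leaving the rest running.
def _example_check_worker_health(session):
  manager = WorkerHeartbeatManager.from_devices(session,
                                                all_worker_devices(session))
  lame = manager.lame_workers()
  if lame is not None:
    lame.shutdown(timeout_ms=10000)
  return manager.ping()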
class WatchdogManager(threading.Thread):
"""Configures worker watchdog timer and handles periodic pings.
Usage:
# Ping workers every minute, shutting down workers if they haven't received
# a ping after 1 hour.
watchdog_manager = WatchdogManager(
ping_interval=60, shutdown_timeout=3600
)
# Use as a context manager, resetting watchdog on context exit:
with watchdog_manager:
session.run(...)
# Or setup globally; watchdog will remain active until program exit.
watchdog_manager.configure_and_run()
"""
def __init__(self,
session,
devices=None,
ping_interval=60,
shutdown_timeout=3600):
"""Initialize a watchdog manager.
Args:
session: Session connected to worker devices. A cloned session and graph
will be created for managing worker pings.
devices: Set of devices to monitor. If none, all workers will be
monitored.
ping_interval: Time, in seconds, between watchdog pings.
shutdown_timeout: Time, in seconds, before watchdog timeout.
"""
threading.Thread.__init__(self)
self.ping_interval = ping_interval
self.shutdown_timeout = shutdown_timeout
self.daemon = True
self._config = session._config # pylint: disable=protected-access
self._target = session.sess_str
self._running = False
self._devices = devices
self._graph = None
self._session = None
self._worker_manager = None
def _reset_manager(self):
"""Reset the graph, session and worker manager."""
self._graph = ops.Graph()
self._session = session_lib.Session(
target=self._target,
graph=self._graph,
config=self._config,
)
if self._devices is None:
self._devices = all_worker_devices(self._session)
with self._graph.as_default():
self._worker_manager = WorkerHeartbeatManager.from_devices(
self._session, self._devices)
self._worker_manager.configure(
event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(
timeout_ms=self.shutdown_timeout * 1000,),
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
def configure_and_run(self):
logging.info(
'Enabling watchdog timer with %d second timeout '
'and %d second ping interval.', self.shutdown_timeout,
self.ping_interval)
self._reset_manager()
self._running = True
self.start()
def stop(self):
logging.info('Stopping worker watchdog.')
self._worker_manager.configure(
event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
shutdown_mode=event_pb2.NOT_CONFIGURED))
self._running = False
self.join()
def __enter__(self):
self.configure_and_run()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def run(self):
# Don't fetch logs or adjust timing: just ping the watchdog.
#
# If we hit an exception, reset our session as it is likely broken.
while self._running:
try:
self._worker_manager.ping(request=None)
time.sleep(self.ping_interval)
except errors.OpError as e:
# Catch any TF errors that occur so we don't stop sending heartbeats
logging.debug('Caught error while sending heartbeat: %s', e)
self._reset_manager()
def start_worker_watchdog(session,
devices=None,
ping_interval=60,
shutdown_timeout=3600):
"""Start global worker watchdog to shutdown workers on coordinator exit."""
global _WATCHDOG
if _WATCHDOG is None:
# Ensure we can send a few pings before we timeout!
ping_interval = min(shutdown_timeout / 10., ping_interval)
_WATCHDOG = WatchdogManager(session, devices, ping_interval,
shutdown_timeout)
_WATCHDOG.configure_and_run()
class GracefulShutdownHook(session_run_hook.SessionRunHook):
"""Session hook that watches for shutdown events.
If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a
SystemShutdown exception is raised to terminate the main session. If `saver`
is None the `SAVERS` collection will be read to find a saver.
`on_shutdown_hooks` is an optional list of functions that should be called
after checkpointing. The function is called with (`run_context`,
`all_workers`, `lame_workers`).
If `heartbeat_group` is not specified, it will default to all CPU workers
in the system.
"""
def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
self._saver = saver
self._checkpoint_prefix = checkpoint_prefix
self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
# Worker heartbeats are managed independently of the main training graph.
self._graph = ops.Graph()
self._workers = None
self._session = None
self._heartbeat_supported = False
def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
# N.B. We have to pull the global step here to avoid it being unavailable
# at checkpoint time; the graph has been frozen at that point.
if training_util.get_global_step() is None and self.saver() is not None:
raise ValueError(
'Saver defined but no global step. Run `get_or_create_global_step()`'
' in your model definition to allow checkpointing.')
with self._graph.as_default():
logging.info('Installing graceful shutdown hook.')
self._session = _clone_session(training_session, self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported:
self._workers.configure(
event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
else:
logging.warn(
'No workers support heartbeats. Failure handling will be disabled.')
def saver(self):
if self._saver:
return self._saver
savers = ops.get_collection(ops.GraphKeys.SAVERS)
if not savers:
return None
if not isinstance(savers, list):
return savers
if len(savers) > 1:
logging.error(
'Multiple savers in the SAVERS collection. On-demand checkpointing '
'will be disabled. Pass an explicit `saver` to the constructor to '
'override this behavior.')
return None
return savers[0]
def after_run(self, run_context, run_values):
del run_values
if not self._heartbeat_supported:
return
lame_workers = self._workers.lame_workers()
if lame_workers:
logging.info('ShutdownHook: lame workers found: %s', lame_workers)
if self.saver():
logging.info('ShutdownHook: saving checkpoint to %s',
self._checkpoint_prefix)
self.saver().save(
run_context.session,
self._checkpoint_prefix,
global_step=training_util.get_global_step(),
write_state=True,
)
else:
logging.info('ShutdownHook: no Saver defined.')
for fn in self._on_shutdown_hooks:
fn(run_context, self._workers, lame_workers)
class RestartComputation(object):
"""Restart the entire computation.
This hook shuts down all workers and returns control to the top-level by
throwing a CoordinatorShutdownException.
"""
def __init__(self, timeout_ms=10000):
self.timeout_ms = timeout_ms
def __call__(self, run_context, all_workers, lame_workers):
del run_context, lame_workers
all_workers.shutdown(timeout_ms=self.timeout_ms)
logging.info('Terminating coordinator.')
raise CoordinatorShutdownException()
class ShutdownLameWorkers(object):
"""Shutdown lamed workers.
Processing will continue normally (typically by waiting for the down
workers to be restarted).
"""
def __init__(self, timeout_ms=10000):
self.timeout_in_ms = timeout_ms
def __call__(self, run_context, all_workers, lame_workers):
lame_workers.shutdown(timeout_ms=self.timeout_in_ms)
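# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Wiring the shutdown machinery into a training job: checkpoint on worker
# failure, then restart the whole computation. The checkpoint prefix is a
# hypothetical path.
def _example_graceful_shutdown_hook():
  return GracefulShutdownHook(
      checkpoint_prefix='/tmp/model/model.ckpt',
      on_shutdown_hooks=[RestartComputation(timeout_ms=10000)])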
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.session_support import *
# pylint: enable=wildcard-import,unused-import

File diff suppressed because it is too large.


@ -1,220 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Defines the `Topology` class, that describes a TPU fabric topology."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf.tpu import topology_pb2
def _tpu_device_name(job, task, device):
"""Returns the device name for the TPU `device` on `task` of `job`."""
if job is None:
return "/task:%d/device:TPU:%d" % (task, device)
else:
return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
def _tpu_host_device_name(job, task):
"""Returns the device name for the CPU device on `task` of `job`."""
if job is None:
return "/task:%d/device:CPU:0" % task
else:
return "/job:%s/task:%d/device:CPU:0" % (job, task)
class Topology(object):
"""Describes a set of TPU devices.
Represents both the shape of the physical mesh, and the mapping from
TensorFlow TPU devices to physical mesh coordinates.
"""
def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
"""Builds a Topology object.
If `serialized` is not `None`, the topology is parsed from `serialized` and
the other arguments are ignored. Otherwise, the topology is computed from
`mesh_shape` and `device_coordinates`.
Args:
serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
serialized proto is parsed to discover the topology.
mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
the shape of the TPU topology, in number of cores. Ignored if
`serialized` is not `None`.
device_coordinates: A rank 3 numpy array that describes the mapping from
TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
if `serialized` is not `None`.
Raises:
ValueError: If `serialized` does not describe a well-formed topology.
ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
of 3 positive integers.
ValueError: If `serialized` is `None` and `device_coordinates` is not a
rank 3 numpy int32 array that describes a valid coordinate mapping.
"""
self._serialized = serialized
if serialized:
self._parse_topology(serialized)
else:
self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
self._device_coordinates = np.asarray(device_coordinates, np.int32)
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
raise ValueError("`mesh_shape` must be a sequence of 3 positive "
"entries; got {}".format(self._mesh_shape))
if (len(self._device_coordinates.shape) != 3 or
self._device_coordinates.shape[2] != len(self._mesh_shape)):
raise ValueError("`device_coordinates` must be a rank 3 int32 array "
"with minor dimension equal to the mesh shape rank")
self._topology_tasks, self._topology_devices = self._invert_topology()
def _parse_topology(self, serialized):
"""Parses a serialized `TopologyProto` into `self`."""
proto = topology_pb2.TopologyProto()
proto.ParseFromString(serialized)
self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
"entries; got {}".format(self._mesh_shape))
if proto.num_tasks < 0:
raise ValueError("`num_tasks` must be >= 0; got {}".format(
proto.num_tasks))
if proto.num_tpu_devices_per_task < 0:
raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
proto.num_tpu_devices_per_task))
expected_coordinates_size = (
proto.num_tasks * proto.num_tpu_devices_per_task * len(
proto.mesh_shape))
if len(proto.device_coordinates) != expected_coordinates_size:
raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
"num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
"got shape {}".format(proto.num_tasks,
proto.num_tpu_devices_per_task,
proto.mesh_shape,
len(proto.device_coordinates)))
coords = np.array(proto.device_coordinates, dtype=np.int32)
if any(coords < 0):
raise ValueError("`device_coordinates` must be >= 0")
coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
len(proto.mesh_shape)))
self._device_coordinates = coords
def _invert_topology(self):
"""Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps."""
tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
for task in xrange(self.device_coordinates.shape[0]):
for device in xrange(self.device_coordinates.shape[1]):
x, y, z = self.device_coordinates[task, device, :]
tasks[x, y, z] = task
devices[x, y, z] = device
return tasks, devices
@property
def mesh_shape(self):
"""A rank 1 int32 array describing the shape of the TPU topology."""
return self._mesh_shape
@property
def mesh_rank(self):
"""Returns the number of dimensions in the mesh."""
return len(self._mesh_shape)
@property
def device_coordinates(self):
"""Describes the mapping from TPU devices to topology coordinates.
Returns:
A rank 3 int32 array with shape `[tasks, devices, axis]`.
`tasks` is the number of tasks in the TPU cluster, `devices` is the number
of TPU devices per task, and `axis` is the number of axes in the TPU
cluster topology. Each entry gives the `axis`-th coordinate in the
topology of a task/device pair. TPU topologies are 3-dimensional, with
dimensions `(x, y, core number)`.
"""
return self._device_coordinates
def task_ordinal_at_coordinates(self, device_coordinates):
"""Returns the TensorFlow task number attached to `device_coordinates`.
Args:
device_coordinates: An integer sequence describing a device's physical
coordinates in the TPU fabric.
Returns:
Returns the TensorFlow task number that contains the TPU device with those
physical coordinates.
"""
return self._topology_tasks[tuple(device_coordinates)]
def tpu_device_ordinal_at_coordinates(self, device_coordinates):
"""Returns the TensorFlow device number at `device_coordinates`.
Args:
device_coordinates: An integer sequence describing a device's physical
coordinates in the TPU fabric.
Returns:
Returns the TensorFlow device number within the task that is attached
to the device with those physical coordinates.
"""
return self._topology_devices[tuple(device_coordinates)]
def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
"""Returns the CPU device attached to a logical core."""
return _tpu_host_device_name(
job, self._topology_tasks[tuple(device_coordinates)])
def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
return _tpu_device_name(job,
self._topology_tasks[tuple(device_coordinates)],
self._topology_devices[tuple(device_coordinates)])
@property
def num_tasks(self):
"""Returns the number of TensorFlow tasks in the TPU slice."""
return self._device_coordinates.shape[0]
@property
def num_tpus_per_task(self):
"""Returns the number of TPU devices per task in the TPU slice."""
return self._device_coordinates.shape[1]
def serialized(self):
"""Returns the serialized form of the topology."""
if self._serialized is None:
proto = topology_pb2.TopologyProto()
proto.mesh_shape[:] = list(self._mesh_shape)
proto.num_tasks = self._device_coordinates.shape[0]
proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
self._serialized = proto.SerializeToString()
return self._serialized
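# --- Illustrative sketch (editor's addition, not part of the original file) ---
# Parsing the serialized topology returned by the TPU system initialization op
# and mapping the first device's physical coordinates back to device names.
# The `job` name is a hypothetical placeholder.
def _example_topology_lookup(serialized_topology):
  topology = Topology(serialized=serialized_topology)
  coords = topology.device_coordinates[0, 0, :]  # first task, first device
  return (topology.tpu_device_name_at_coordinates(coords, job='worker'),
          topology.cpu_device_name_at_coordinates(coords, job='worker'))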
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.topology import *
# pylint: enable=wildcard-import,unused-import,redefined-builtin

File diff suppressed because it is too large.


@ -1,276 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""A RunConfig subclass with TPU support."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging
# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
# pylint: enable=protected-access
class InputPipelineConfig(object):
r"""Please see the definition of these values in TPUConfig."""
PER_SHARD_V1 = 1
PER_HOST_V1 = 2
PER_HOST_V2 = 3
BROADCAST = 4
class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
'input_partition_dims',
])):
r"""TPU related configuration required by `TPUEstimator`.
Args:
iterations_per_loop: The number of train steps run on the TPU system
before returning to the CPU host for each `Session.run`. The global step
is increased `iterations_per_loop` times in one `Session.run`. It is
recommended to set this to the number of global steps between checkpoints.
num_shards: (Deprecated, ignored by TPUEstimator).
The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer describing the number of TPU cores per model replica. This is
required for model parallelism, which partitions the model across multiple
cores. Currently num_cores_per_replica must be 1, 2, 4, 8, or 16.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
`input_fn` is invoked once on each host. With the per-core input pipeline
configuration, it is invoked once for each core.
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
the batch size for each shard is `train_batch_size` // #hosts in the
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
invoked once on host 0 and the tensors are broadcast to all other
replicas. The batch size equals `train_batch_size`. With the per-core
input pipeline configuration, the shard batch size is also
`train_batch_size` // #cores.
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
string.
initial_infeed_sleep_secs: The number of seconds the infeed thread should
wait before enqueueing the first batch. This helps avoid timeouts for
models that require a long compilation time.
input_partition_dims: A nested list describing the partition dims
for all the tensors from input_fn(). The structure of
input_partition_dims must match the structure of `features` and
`labels` from input_fn(). The total number of partitions must match
`num_cores_per_replica`. For example, if input_fn() returns two tensors,
images with shape [N, H, W, C] and labels with shape [N], then
input_partition_dims = [[1, 2, 2, 1], None] will split the images into 4
pieces and feed them to 4 TPU cores. The labels tensor is broadcast
directly to all the TPU cores since its partition dims are `None`.
Current limitations: This feature is only supported with the PER_HOST_V2
input mode.
Raises:
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None,
input_partition_dims=None):
# Check iterations_per_loop.
util_lib.check_positive_integer(iterations_per_loop,
'TPUConfig iterations_per_loop')
# Check num_shards.
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
if input_partition_dims is not None:
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
raise ValueError(
'input_partition_dims must be a list/tuple with one or two'
' elements.')
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
raise ValueError(
'input_partition_dims is only supported in PER_HOST_V2 mode.')
if num_cores_per_replica is None:
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
# Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or an integer in [1..4].
# Map legacy values (True, False) to numeric values.
if per_host_input_for_training is False:
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
elif per_host_input_for_training is True:
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
'TPUConfig initial_infeed_sleep_secs')
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
input_partition_dims=input_partition_dims)
class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
def __init__(self,
tpu_config=None,
evaluation_master=None,
master=None,
cluster=None,
**kwargs):
"""Constructs a RunConfig.
Args:
tpu_config: the TPUConfig that specifies TPU-specific configuration.
evaluation_master: a string. The address of the master to use for eval.
Defaults to master if not set.
master: a string. The address of the master to use for training.
cluster: a ClusterResolver
**kwargs: keyword config parameters.
Raises:
ValueError: if cluster is not None and the provided session_config has a
cluster_def already.
"""
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TPUConfig()
self._cluster = cluster
# If user sets master and/or evaluation_master explicitly, including empty
# string '', take it. Otherwise, take the values set by parent class.
if master is not None:
if cluster is not None:
raise ValueError('Both master and cluster are set.')
self._master = master
else:
if cluster:
self._master = cluster.master()
if evaluation_master is not None:
self._evaluation_master = evaluation_master
elif (not self._evaluation_master and
self.task_type != run_config_lib.TaskType.EVALUATOR):
# If the task type is EVALUATOR, it means some cluster manager sets the
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
#
# Otherwise, it means user executes the code without external cluster
# manager. For that, we optimize the user experience by setting
# evaluation_master to master, unless user overwrites it.
self._evaluation_master = self._master
# Set the ClusterSpec to use
if cluster:
self._cluster_spec = cluster.cluster_spec()
# Merge the cluster_def into the ConfigProto.
if self._session_config is None: # pylint: disable=access-member-before-definition
self._session_config = config_pb2.ConfigProto(
allow_soft_placement=True, isolate_session_state=True)
if self._session_config.HasField('cluster_def'):
raise ValueError(
'You cannot provide a ClusterResolver and '
'session_config.cluster_def.')
if self._cluster_spec:
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
def _maybe_overwrite_session_config_for_distributed_training(self):
# Overrides the parent class session_config overwrite for between-graph. TPU
# runs with in-graph, which should not have device filter. Doing nothing
# ("pass") basically disables it.
pass
@property
def evaluation_master(self):
return self._evaluation_master
@property
def master(self):
return self._master
@property
def tpu_config(self):
return self._tpu_config
@property
def cluster(self):
return self._cluster
def replace(self, **kwargs):
if 'tpu_config' not in kwargs:
return super(RunConfig, self).replace(**kwargs)
tpu_config = kwargs.pop('tpu_config')
new_instance = super(RunConfig, self).replace(**kwargs)
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
return new_instance
def _get_tpu_job_name_from_tf_config():
"""Extracts the TPU job name from TF_CONFIG env variable."""
# TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
# cluster spec propagation.
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
if tpu_job_name:
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
return tpu_job_name
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_config import *
# pylint: enable=wildcard-import,unused-import
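For reference, a brief, hedged sketch of how the configuration classes above are typically combined. The master address, model directory, and numeric values are placeholders; only the keyword names come from the definitions shown in this file.

from tensorflow.python.tpu.tpu_config import (
    InputPipelineConfig, RunConfig, TPUConfig)

config = RunConfig(
    master='grpc://10.0.0.2:8470',   # placeholder TPU worker address
    model_dir='/tmp/tpu_model',      # placeholder checkpoint directory
    save_checkpoints_steps=1000,
    tpu_config=TPUConfig(
        iterations_per_loop=1000,    # recommended to match checkpoint cadence
        per_host_input_for_training=InputPipelineConfig.PER_HOST_V2))
assert config.tpu_config.iterations_per_loop == 1000
assert config.master == 'grpc://10.0.0.2:8470'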


@@ -1,763 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from contextlib import contextmanager
import copy
from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
_NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
8: [2, 2, 2],
16: [4, 2, 2],
}
class TPUContext(object):
"""A context that holds the current configuration of the TPU computation."""
def __init__(self,
internal_ctx,
input_device=None,
invocation_index=None,
call_from_input_fn=True):
self._internal_ctx = internal_ctx
self._input_device = input_device
self._invocation_index = invocation_index
self._call_from_input_fn = call_from_input_fn
def current_input_fn_deployment(self):
"""The configuration of the current input_fn invocation.
The configuration depends on `TPUConfig.per_host_input_for_training`. See
`TPUConfig` for details.
Only set in the `params` dict of input_fn.
Returns:
A tuple of
1. Device spec string: String, is the current CPU host where the
input_fn is invoked.
2. Current invocation index: Int, 0-based index of the input_fn
invocation. See next item for details.
3. Total invocation count: Int, the total number of times to invoke the
input_fn on all CPU hosts. Each invocation will be passed with a new
`TPUContext` instance with current invocation index set properly.
4. Total number of replicas consumed by current_invocation: Int, the
number of replicas fed by the data returned by current input_fn. For
example, for per_core input pipeline deployment
and non-model-parallelism, total invocation count is equal to
the number of cores in the system and num replicas consumed by
current invocation is 1. For per-host v2 input pipeline deployment,
total invocation count is equal to the number of hosts in the system
and num replicas consumed by current invocation is equal to number of
cores per host.
Raises:
RuntimeError: If this method is called from model_fn instead of input_fn.
"""
if not self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' model_fn.')
if self._internal_ctx.is_input_sharded_per_core():
total_invocation_count = (self._internal_ctx.num_hosts
* self._internal_ctx.num_of_replicas_per_host)
replicas_consumed = 1
elif self._internal_ctx.is_input_broadcast_with_iterators():
total_invocation_count = 1
replicas_consumed = self._internal_ctx.num_replicas
else:
total_invocation_count = self._internal_ctx.num_hosts
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
return (self._input_device, self._invocation_index,
total_invocation_count, replicas_consumed)
@property
def num_replicas(self):
"""The total number of replicas.
For non-model-parallelism, num_replicas should be the total num of TPU
cores in the system.
Returns:
The number of replicas.
"""
return self._internal_ctx.num_replicas
@property
def num_hosts(self):
"""The number of hosts for the TPU system."""
return self._internal_ctx.num_hosts
@property
def current_host(self):
"""The current host index for the TPU system."""
return self._invocation_index
@property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
raise ValueError(
'num_of_replicas_per_host is not supported for model_parallelism')
return self._internal_ctx.num_of_replicas_per_host
@property
def device_assignment(self):
"""Returns device_assignment object."""
if self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' input_fn.')
return self._internal_ctx.device_assignment
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replication with non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
# Note: for non-model-parallelism, the mapping could be a random
# permutation. The order should not matter in most cases, as long as the
# model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function.
The place function takes host_id as the input and returns the TF device
for the correspoding host.
"""
def _placement_function(host_id):
"""Return the host device given host_id."""
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
return _placement_function
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
This immutable object holds TPUEstimator config, train/eval batch size, and
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
provides utility functions, based on the current state, to determine other
information commonly required by TPU computation, such as TPU device names,
TPU hosts, shard batch size, etc.
If eval_on_tpu is False, then execution of eval on TPU is disabled.
If eval_on_tpu is True but use_tpu is False, a warning is issued and TPU
execution is disabled for all modes.
N.B. As `mode` is not immutable state in Estimator, but essential to
distinguish between TPU training and evaluation, a common usage for
_InternalTPUContext with `mode` is as follows:
```
with _ctx.with_mode(mode) as ctx:
if ctx.is_running_on_cpu():
...
```
"""
def __init__(self,
config,
train_batch_size,
eval_batch_size,
predict_batch_size,
use_tpu,
eval_on_tpu=True,
embedding_config_spec=None):
self._config = config
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._predict_batch_size = predict_batch_size
self._use_tpu = use_tpu
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
if not use_tpu and eval_on_tpu:
logging.warning('eval_on_tpu ignored because use_tpu is False.')
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
num_cores_per_replica = config.tpu_config.num_cores_per_replica
if self._model_parallelism_enabled:
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
num_cores_per_replica]
else:
self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
self._embedding_config_spec = embedding_config_spec
self._lazy_embedding_config_dict = {} # key by master address
def _assert_mode(self):
if self._mode is None:
raise RuntimeError(
'`mode` needs to be set via contextmanager `with_mode`.')
return self._mode
@contextmanager
def with_mode(self, mode):
# NOTE(xiejw): A shallow copy is enough. It will share the lazy dictionaries,
# such as _lazy_tpu_system_metadata_dict, between the new copy and the original
# one. Note that all lazy states stored in properties _lazy_foo are effectively
# immutable as they should be the same for the process lifetime.
new_ctx = copy.copy(self)
new_ctx._mode = mode # pylint: disable=protected-access
yield new_ctx
@property
def mode(self):
return self._assert_mode()
def _get_master_address(self):
mode = self._assert_mode()
config = self._config
master = (
config.master
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
return master
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
cluster_def = None
if (self._config.session_config and
self._config.session_config.cluster_def.job):
cluster_def = self._config.session_config.cluster_def
# pylint: disable=protected-access
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
query_topology=self.model_parallelism_enabled))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_device_assignment(self):
"""Gets the (maybe cached) TPU device assignment."""
master = self._get_master_address()
device_assignment = self._lazy_device_assignment_dict.get(master)
if device_assignment is not None:
return device_assignment
tpu_system_metadata = self._get_tpu_system_metadata()
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
logging.info('num_cores_per_replica: %s',
str(self._config.tpu_config.num_cores_per_replica))
logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
logging.info('device_assignment.core_assignment: %s',
str(device_assignment.core_assignment))
self._lazy_device_assignment_dict[master] = device_assignment
return device_assignment
@property
def embedding_config(self):
"""Returns the embedding config based on current mode."""
master = self._get_master_address()
if master in self._lazy_embedding_config_dict:
embedding_config = self._lazy_embedding_config_dict[master]
else:
embedding_config = None
if self._use_tpu and self._embedding_config_spec:
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
self._embedding_config_spec, self._train_batch_size,
self._eval_batch_size, self.num_hosts, self.num_cores, master)
if not embedding_config.has_embedding_tables():
embedding_config = None
self._lazy_embedding_config_dict[master] = embedding_config
if embedding_config is not None:
mode = self._assert_mode()
# Dynamically attach tpu_embedding based on mode. With
# this, we could keep embedding_config immutable but call site always
# accesses the unified API '.tpu_embedding'.
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
return embedding_config
@property
def model_parallelism_enabled(self):
return self._model_parallelism_enabled
@property
def input_partition_dims(self):
return self._config.tpu_config.input_partition_dims
@property
def device_assignment(self):
return (self._get_device_assignment()
if self._model_parallelism_enabled else None)
@property
def num_of_cores_per_host(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_of_cores_per_host
@property
def num_cores(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_cores
@property
def num_of_replicas_per_host(self):
"""Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
return self.num_of_cores_per_host
@property
def num_replicas(self):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the total num of '
'TPU cores in the system. num_cores_per_replica: {}, num cores '
'in the system: {}'.format(num_cores_per_replica,
num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
else:
return num_cores_in_system
@property
def num_hosts(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_hosts
@property
def config(self):
return self._config
def is_input_sharded_per_core(self):
"""Return true if input_fn is invoked per-core (other than per-host)."""
mode = self._assert_mode()
return (mode == model_fn_lib.ModeKeys.TRAIN and
(self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1))
def is_input_per_host_with_iterators(self):
"""Return true if input_fn should be run in the per-host v2 config."""
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_HOST_V2)
def is_input_broadcast_with_iterators(self):
"""Return true if input_fn should be run in the full_replicae config."""
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.BROADCAST)
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
This API also validates the user-provided configuration, such as batch size,
against the lazily initialized TPU system metadata.
Args:
is_export_mode: Indicates whether the current mode is for exporting the
model, when mode == PREDICT. Only with this bool can we tell whether the
user is calling Estimator.predict or Estimator.export_savedmodel, which
run on TPU and CPU respectively. The parent class Estimator does not
distinguish these two.
Returns:
bool, whether current input_fn or model_fn should be running on CPU.
Raises:
ValueError: any configuration is invalid.
"""
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
if not is_running_on_cpu:
self._validate_tpu_configuration()
return is_running_on_cpu
def _is_running_on_cpu(self, is_export_mode):
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
mode = self._assert_mode()
if not self._use_tpu:
return True
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
return True
if is_export_mode:
return True
return False
@property
def global_batch_size(self):
mode = self._assert_mode()
if mode == model_fn_lib.ModeKeys.TRAIN:
return self._train_batch_size
elif mode == model_fn_lib.ModeKeys.EVAL:
return self._eval_batch_size
elif mode == model_fn_lib.ModeKeys.PREDICT:
return self._predict_batch_size
else:
return None
@property
def batch_size_for_input_fn(self):
"""Returns the shard batch size for `input_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU
if self.is_input_sharded_per_core() or (
self.is_input_per_host_with_iterators()):
return global_batch_size // self.num_replicas
else:
return global_batch_size // self.num_hosts
@property
def batch_size_for_model_fn(self):
"""Returns the shard batch size for `model_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU, the batch is always sharded per replica.
return global_batch_size // self.num_replicas
@property
def master_job(self):
"""Returns the job name to use to place TPU computations on.
Returns:
A string containing the job name, or None if no job should be specified.
Raises:
ValueError: If the user needs to specify a tpu_job_name, because we are
unable to infer the job name automatically, or if the user-specified job
names are inappropriate.
"""
run_config = self._config
# If the user specifies the tpu_job_name, use that.
if run_config.tpu_config.tpu_job_name:
return run_config.tpu_config.tpu_job_name
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
mode = self._assert_mode()
master = (
run_config.evaluation_master
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
if master in _LOCAL_MASTERS:
return None
if (not run_config.session_config or
not run_config.session_config.cluster_def.job):
return _DEFAULT_JOB_NAME
cluster_def = run_config.session_config.cluster_def
job_names = set([job.name for job in cluster_def.job])
if _DEFAULT_JOB_NAME in job_names:
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
raise ValueError('Currently, tpu_worker is not an allowed job name.')
if len(job_names) == 1:
return cluster_def.job[0].name
if len(job_names) == 2:
if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
return job_names.pop()
# TODO(b/67716447): Include more sophisticated heuristics.
raise ValueError(
'Could not infer TPU job name. Please specify a tpu_job_name as part '
'of your TPUConfig.')
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
master = self.master_job
def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
"""Return the host device given replica_id or host_id."""
assert _sentinal is None
if replica_id is not None and host_id is not None:
raise RuntimeError(
'replica_id and host_id can have only one non-None value.')
if master is None:
return '/replica:0/task:0/device:CPU:0'
else:
if replica_id is not None:
if self.model_parallelism_enabled:
return self.device_assignment.host_device(
replica=replica_id, job=master)
else:
host_id = replica_id // self.num_of_cores_per_host
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
return _placement_function
@property
def tpu_device_placement_function(self):
"""Returns a TPU device placement Fn."""
master = self.master_job
job_device = '' if master is None else ('/job:%s' % master)
def _placement_function(i):
if self.model_parallelism_enabled:
return self.device_assignment.tpu_device(replica=i, job=master)
else:
num_of_cores_per_host = self.num_of_cores_per_host
host_id = i // num_of_cores_per_host
ordinal_id = i % num_of_cores_per_host
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
return _placement_function
def tpu_ordinal_function(self, host_id):
"""Returns the TPU ordinal fn."""
def _tpu_ordinal_function(shard_index_in_host):
"""Return the TPU ordinal associated with a shard.
Required because the enqueue ops are placed on CPU.
Args:
shard_index_in_host: the shard index
Returns:
The ordinal of the TPU device the shard's infeed should be placed on.
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
replica = self.device_assignment.lookup_replicas(host_id,
0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
return _tpu_ordinal_function
def _validate_tpu_configuration(self):
"""Validates the configuration based on the TPU system metadata."""
mode = self._assert_mode()
if self._lazy_validation_dict.get(mode):
return
# All following information is obtained from TPU system metadata.
num_cores = self.num_cores
num_replicas = self.num_replicas
num_hosts = self.num_hosts
if not num_cores:
tpu_system_metadata = self._get_tpu_system_metadata()
raise RuntimeError(
'Cannot find any TPU cores in the system. Please double check '
'Tensorflow master address and TPU worker(s). Available devices '
'are {}.'.format(tpu_system_metadata.devices))
if self._config.tpu_config.num_shards:
user_provided_num_replicas = self._config.tpu_config.num_shards
if user_provided_num_replicas != num_replicas:
message = (
'TPUConfig.num_shards is not set correctly. According to TPU '
'system metadata for Tensorflow master ({}): num_replicas should '
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
raise ValueError(message)
if self._config.tpu_config.num_cores_per_replica:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
if num_cores_per_replica > num_cores_per_host:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the '
'num_cores_per_host. num_cores_per_replica: {}, '
'num_cores_per_host: {}'.format(num_cores_per_replica,
num_cores_per_host))
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
elif mode == model_fn_lib.ModeKeys.EVAL:
if self._eval_batch_size is None:
raise ValueError(
'eval_batch_size in TPUEstimator constructor cannot be `None` '
'if .evaluate is running on TPU.')
if (self._eval_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.evaluate should be running on single TPU'
' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
if (self._predict_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.predict should be running on single TPU worker. '
'got {}.'.format(num_hosts))
# Record the state "validated" into lazy dictionary.
self._lazy_validation_dict[mode] = True
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replication with non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
master = self.master_job
if self.model_parallelism_enabled:
return (self.device_assignment.host_device(
replica=replica_id, job=master),
self.device_assignment.tpu_ordinal(replica=replica_id))
job_device = '' if master is None else ('/job:%s' % master)
num_of_replicas_per_host = self.num_of_replicas_per_host
host_id = replica_id // num_of_replicas_per_host
ordinal_id = replica_id % num_of_replicas_per_host
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
return (host_device, ordinal_id)
class _OneCoreTPUContext(_InternalTPUContext):
"""Special _InternalTPUContext for one core usage."""
def __init__(self, config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu):
super(_OneCoreTPUContext, self).__init__(
config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
tpu_system_metadata = (
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
num_cores=1,
num_hosts=1,
num_of_cores_per_host=1,
topology=None,
devices=[]))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_tpu_context(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec):
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
config.tpu_config.num_cores_per_replica is None):
if embedding_config_spec is not None:
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
'when embedding_config_spec is not None.')
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
'Please fix as soon as possible (leaving num_shards as None.)')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec)
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_context import *
# pylint: enable=wildcard-import,unused-import
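A hedged sketch of how the TPUContext above is typically consumed from an input_fn. TPUEstimator conventionally exposes the context in the params dict; the 'context' key and the file pattern below are assumptions of this sketch, not guarantees made by this module.

import tensorflow as tf

def train_input_fn(params):
  """Shards a file-based dataset across input_fn invocations."""
  ctx = params['context']  # assumed key under which TPUEstimator passes a TPUContext
  _, invocation_index, total_invocations, _ = ctx.current_input_fn_deployment()
  dataset = tf.data.Dataset.list_files('/tmp/data/train-*', shuffle=False)
  # Give each invocation (host) a disjoint slice of the input files.
  dataset = dataset.shard(total_invocations, invocation_index)
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
  # TPUEstimator supplies the per-invocation batch size in params.
  return dataset.batch(params['batch_size'], drop_remainder=True)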

File diff suppressed because it is too large


@@ -1,153 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
def get_gradients_through_compute_gradients(optimizer, loss, activations):
"""Compute gradients to send to TPU embedding.
Args:
optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
Used to call compute_gradients().
loss: a Tensor to call optimizer.compute_gradients() on.
activations: an OrderedDict mapping feature_name to Tensors of activations.
Returns:
An OrderedDict mapping from feature name Strings to Tensors of gradients of
the loss wrt the activations of the features.
"""
activation_list = list(activations.values())
grads_and_vars = optimizer.compute_gradients(loss, activation_list)
grads = [grad for grad, _ in grads_and_vars]
feature_to_gradient_dict = collections.OrderedDict(
zip(activations.keys(), grads))
return feature_to_gradient_dict
def create_dummy_table_variables(tpu_embedding):
"""Create dummy embedding table variables.
The sole purpose of these dummy variables is to trigger gradient
calculation wrt them so that the gradients wrt the activations can be
captured and later sent to TPU embedding.
Args:
tpu_embedding: TPUEmbedding, dummy table variables will be created for use
with tpu_embedding.
Returns:
A tuple of dummy variables and their initializer.
Raises:
RuntimeError: if collection to store gradients already exists and is not
empty.
"""
dummy_table_variables = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
dummy_table_variables[table] = (
# Explicitly specifying collections prevents this variable from
# being added to the GLOBAL_VARIABLES collection, so that Saver()
# ignores it.
# However, the TensorFlow optimizer creates slot variables for these dummy
# variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
# which will be in the GLOBAL_VARIABLES collection.
variable_scope.get_variable(
'tpu_embedding_dummy_table_variable_{}'.format(table),
dtype=dtypes.float32,
shape=[1],
use_resource=True,
trainable=True,
collections=['tpu_embedding_dummy_table_variables']))
g = ops.get_default_graph()
table_gradients = g.get_collection_ref(
'tpu_embedding_gradients_table_{}'.format(table_id))
if table_gradients:
raise RuntimeError(
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
table_gradients.extend(
[None] * len(tpu_embedding.table_to_features_dict[table]))
return (dummy_table_variables,
variables.variables_initializer(
dummy_table_variables.values(),
name='tpu_embedding_dummy_table_variables_init'))
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
dummy_table_variables):
"""Have activations depend on dummy table variables for gradient intercept.
Args:
tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
tpu_embedding.
activations: An OrderedDict of feature name String to activation tensors.
dummy_table_variables: An OrderedDict of table name String to dummy table
variables.
Returns:
An OrderedDict of feature name String to activation tensors, which can be
used just as the activations input.
"""
new_activations = collections.OrderedDict()
for feature in activations:
table = tpu_embedding.feature_to_table_dict[feature]
new_activations[feature] = tpu_ops.tpu_embedding_activations(
dummy_table_variables[table],
activations[feature],
table_id=list(tpu_embedding.table_to_config_dict).index(table),
lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
return new_activations
def get_gradients_through_dummy_table_variables(tpu_embedding):
"""Get gradients wrt the activations of each feature.
Args:
tpu_embedding: TPUEmbedding, create dummy table variable to be used with
tpu_embedding.
Returns:
An OrderedDict mapping feature name to gradient.
Raises:
ValueError: if some gradients are not defined.
"""
g = ops.get_default_graph()
feature_to_gradient_dict = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
table_gradients = g.get_collection(
'tpu_embedding_gradients_table_{}'.format(table_id))
if any(gradient is None for gradient in table_gradients):
raise ValueError(
'Table {} with id {} has undefined gradients: this is probably '
'because the model asked TPUEmbedding to compute activations that '
'were not used.'.format(table, table_id))
for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
table_gradients):
feature_to_gradient_dict[feature] = gradient
return feature_to_gradient_dict
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_embedding_gradient import *
# pylint: enable=wildcard-import,unused-import
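A minimal, hedged sketch of the first helper above, get_gradients_through_compute_gradients. The activations and loss are toy placeholders standing in for real TPU embedding activations; the sketch only illustrates the OrderedDict-in, OrderedDict-out contract and assumes the optimizer accepts plain tensors in var_list, as the helper itself relies on.

import collections
import tensorflow as tf
from tensorflow.python.tpu.tpu_embedding_gradient import (
    get_gradients_through_compute_gradients)

# Toy "activations": in real use these come from TPUEmbedding lookups.
user_act = tf.constant([[0.1, 0.2]])
item_act = tf.constant([[0.3, 0.4]])
activations = collections.OrderedDict(
    [('user_id', user_act), ('item_id', item_act)])
loss = tf.reduce_sum(user_act) + 2.0 * tf.reduce_sum(item_act)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads = get_gradients_through_compute_gradients(optimizer, loss, activations)
# grads is an OrderedDict mapping 'user_id' and 'item_id' to the gradients of
# `loss` with respect to each activation tensor (here, tensors of 1s and 2s).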

File diff suppressed because it is too large


@@ -1,919 +1,25 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Helper library for handling infeed between hosts and TPUs.
"""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
def partition_or_replicate_on_host(tensor, dims):
"""Partitions or replicates the input tensor.
The ops inside this function are placed on the host side.
Args:
tensor: The input tensor which will be partitioned or replicated.
dims: A list of integers describing how to partition the input tensor.
Returns:
An iterator of `Tensor`s or a list of partitioned tensors.
"""
if dims is None:
return itertools.repeat(tensor)
dims = np.array(dims)
output = [tensor]
shape_list = np.array(tensor.shape.as_list())
quotients, remainders = np.divmod(shape_list, dims)
for axis, (quotient, remainder, dim, original_size) in enumerate(
zip(quotients, remainders, dims, shape_list)):
if dim <= 1:
continue
if remainder > 0:
# For each dimension, when it cannot be evenly partitioned, XLA assumes
# tensors are partitioned in a greedy manner by using
# ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
# are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
# [[(3, 4), (3, 4), (2, 4), (2, 2)],
# [(2, 4), (2, 4), (2, 4), (2, 2)]]
ceil_ratio = quotient + 1
num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
if len(num_or_size_splits) < dim:
num_or_size_splits += [0] * (dim - len(num_or_size_splits))
new_output = []
for x in output:
new_output.append(
array_ops.split(
x, num_or_size_splits=num_or_size_splits, axis=axis))
output = new_output
else:
output = [array_ops.split(x, dim, axis=axis) for x in output]
output = nest.flatten(output)
return output
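# A minimal usage sketch: shows the greedy ceil(size/dim) split implemented
# above when a dimension does not divide evenly. The helper name
# `_example_uneven_partition` is illustrative only and not part of this API.
def _example_uneven_partition():
  """Partitions a [5, 4] tensor with dims=[2, 1] into row blocks of 3 and 2."""
  tensor = array_ops.zeros([5, 4])
  parts = partition_or_replicate_on_host(tensor, dims=[2, 1])
  # Two partitions along axis 0, since 5 = ceil(5/2) + 2 = 3 + 2.
  return [p.shape.as_list() for p in parts]  # [[3, 4], [2, 4]]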
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
"""Tags appropriate XLA sharding attribute to the dequeued tensor.
Args:
tensor: The dequeued tensor on TPU.
dims: A list of integers describing how the tensor is partitioned.
Returns:
The same tensor with the xla_sharding attribute.
"""
if dims is None:
return xla_sharding.replicate(tensor)
elif np.prod(dims) == 1:
return xla_sharding.assign_device(tensor, 0)
else:
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
"""Tags appropriate XLA sharding attribute to the dequeued tensors.
Args:
dequeues: A list of dequeued tensors on TPU.
dims: A list of integers describing how the tensor is partitioned.
Returns:
The same dequeues with appropriate xla_sharding attribute.
"""
nest.assert_shallow_structure(dequeues, dims)
return nest.map_structure_up_to(
dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
class InfeedQueue(object):
"""A helper object to build a device infeed queue.
The InfeedQueue builds the host-side and device-side Ops to enqueue and
dequeue elements, respectively, and ensures that their types and
shapes match.
"""
def __init__(self,
number_of_tuple_elements=None,
tuple_types=None,
tuple_shapes=None,
shard_dimensions=None,
name=None):
"""Creates a new InfeedQueue with the given configuration.
The configuration need not be fully specified at creation since it
can be modified subsequently by methods that set the values
explicitly or infer them from the shapes of inputs.
Args:
number_of_tuple_elements: the number of Tensors fed atomically through the
queue, must be present unless it can be inferred from other arguments.
tuple_types: if not None, a list of types of the elements of the queue.
tuple_shapes: if not None, a list of shapes of the elements of the queue.
shard_dimensions: if not None, a list of dimensions on which the
elements of the queue should be sharded during automatic
parallelization.
name: the name of the queue.
Raises:
ValueError: if number_of_tuple_elements <= 0; or
number_of_tuple_elements, tuple_types, tuple_shapes, and
shard_dimensions are all None; or the length of tuple_types,
tuple_shapes, or shard_dimensions is not equal to
number_of_tuple_elements; or any element of shard_dimensions
can't be converted to a Dimension.
TypeError: if any element of tuple_types or tuple_shapes can't
be converted to a dtype or TensorShape, respectively.
"""
self._frozen = False
self._generated_enqueue_ops = False
self._generated_dequeue_op = False
self._name = "InfeedQueue" if name is None else name
if number_of_tuple_elements is None:
if tuple_types is not None:
number_of_tuple_elements = len(tuple_types)
elif tuple_shapes is not None:
number_of_tuple_elements = len(tuple_shapes)
elif shard_dimensions is not None:
number_of_tuple_elements = len(shard_dimensions)
else:
raise ValueError(
"number of tuple elements cannot be inferred from InfeedQueue "
"constructor")
if number_of_tuple_elements <= 0:
raise ValueError("number_of_tuple_elements %d must be > 0" %
number_of_tuple_elements)
# Make an empty sharding policy for each tuple element.
self._sharding_policies = [
tpu_sharding.ShardingPolicy()
for _ in xrange(number_of_tuple_elements)
]
if tuple_types is not None:
self.set_tuple_types(tuple_types)
else:
self._tuple_types = None
if tuple_shapes is not None:
self.set_tuple_shapes(tuple_shapes)
else:
self._tuple_shapes = None
if shard_dimensions is not None:
self.set_shard_dimensions(shard_dimensions)
self._validate()
def _validate(self):
"""Checks that the configuration is self-consistent.
Raises:
ValueError: if the shapes and sharding policies don't match.
"""
if self.tuple_shapes is not None:
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
# Raise an error if the policy is incompatible with the shape.
_ = policy.get_sharded_shape(shape)
@property
def number_of_tuple_elements(self):
"""Returns the number of InfeedQueue tuple elements."""
return len(self._sharding_policies)
@property
def tuple_types(self):
"""Returns the types of the InfeedQueue tuple elements."""
return self._tuple_types
def set_tuple_types(self, tuple_types):
"""Sets the type of each element of the queue.
tuple_types must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a dtype.
Args:
tuple_types: the types of each queue element.
Raises:
ValueError: if tuple_types is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_types cannot be converted to a
dtype.
"""
if len(tuple_types) != self.number_of_tuple_elements:
raise ValueError("tuple_types is %s, but must be a list of length %d" %
(str(tuple_types), self.number_of_tuple_elements))
if self._frozen:
for (frozen, updated) in zip(self._tuple_types, tuple_types):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible type. Frozen types are %s, updated types are %s" % (
str(self._tuple_types), str(tuple_types)))
else:
try:
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
except (TypeError) as e:
raise TypeError(
"tuple_types is %s, but must be a list of elements each "
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
@property
def tuple_shapes(self):
"""Returns the shapes of the InfeedQueue tuple elements."""
return self._tuple_shapes
def set_tuple_shapes(self, tuple_shapes):
"""Sets the shape of each element of the queue.
tuple_shapes must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a TensorShape.
Args:
tuple_shapes: the shapes of each queue element.
Raises:
ValueError: if tuple_shapes is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_shapes cannot be converted to
a TensorShape.
"""
if len(tuple_shapes) != self.number_of_tuple_elements:
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
(str(tuple_shapes), self.number_of_tuple_elements))
try:
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
except (ValueError, TypeError) as e:
raise TypeError(
"tuple_shapes is %s, but must be a list of elements each "
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
str(e)))
if self._frozen:
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
% (str(self._tuple_shapes), str(tuple_shapes)))
else:
self._tuple_shapes = tuple_shapes
self._validate()
@property
def sharding_policies(self):
"""Returns the sharding policies of the InfeedQueue tuple elements."""
return self._sharding_policies
@property
def shard_dimensions(self):
"""Gets the shard dimension of each tuple element.
Returns:
A list of length number_of_tuple_elements, where each list entry
is the shard dimension of that tuple element or None if the
shard dimension has not been set.
"""
# The number of shards is always the same for all the policies.
return [policy.shard_dimension for policy in self._sharding_policies]
def set_shard_dimensions(self, shard_dimensions):
"""Sets the shard_dimension of each element of the queue.
shard_dimensions must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a Dimension compatible with self.tuple_shapes.
Args:
shard_dimensions: the dimensions of each queue element.
Raises:
ValueError: if shard_dimensions is not of length
self.number_of_tuple_elements; or an element of
shard_dimensions cannot be converted to a Dimension; or an
element of shard_dimensions is a Dimension that is out of
range for the corresponding tuple element shape.
"""
if len(shard_dimensions) != self.number_of_tuple_elements:
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
% (str(shard_dimensions),
self.number_of_tuple_elements))
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
policy.set_shard_dimension(dimension)
self._validate()
@property
def number_of_shards(self):
"""Gets the number of shards to use for the InfeedQueue.
Returns:
Number of shards or None if the number of shards has not been set.
"""
# The number of shards is always the same for all the policies.
return self._sharding_policies[0].number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards to use for the InfeedQueue.
Args:
number_of_shards: number of ways to shard the InfeedQueue.
Raises:
ValueError: if number_of_shards is not > 0; or the policies have
been frozen and number_of_shards was already set to something
else.
"""
for policy in self._sharding_policies:
policy.set_number_of_shards(number_of_shards)
self._validate()
def set_configuration_from_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of Tensors whose types and shapes are used
to set the queue configuration.
Args:
input_tensors: list of Tensors of the same types and shapes as
the desired queue Tuple.
Raises:
ValueError: if input_tensors is not a list of length
self.number_of_tuple_elements
"""
if len(input_tensors) != self.number_of_tuple_elements:
raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
% (str(input_tensors), self.number_of_tuple_elements))
self.set_tuple_shapes([t.shape for t in input_tensors])
self.set_tuple_types([t.dtype for t in input_tensors])
def set_configuration_from_sharded_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of lists of Tensors whose types and shapes are used
to set the queue configuration. The length of the outer list is the number
of shards required, and each inner list is the tuple of Tensors to use to
determine the types and shapes of the corresponding shard. This method
depends on the shard dimension, and calling it freezes the shard policy.
Args:
input_tensors: list of lists of Tensors. The outer list length corresponds
to the desired number of shards, and each inner list contains the Tensors
whose types and shapes define the configuration of the corresponding shard.
Raises:
ValueError: if any inner list is not a list of length
self.number_of_tuple_elements; or the inner lists do not combine to
form a consistent unsharded shape.
TypeError: if the types of the Tensors in the inner lists do not match.
"""
if not self._frozen:
# Unset the tuple shapes in case the configuration becomes
# transiently inconsistent.
self._tuple_shapes = None
number_of_shards = len(input_tensors)
self.set_number_of_shards(number_of_shards)
for t in input_tensors:
if len(t) != self.number_of_tuple_elements:
raise ValueError(
"input_tensors is %s but must be a list of lists, where each inner"
" list has length number_of_tuple_elements=%d" % (
str(input_tensors), self.number_of_tuple_elements))
# Transpose the inputs to make a list of shard shapes for each tuple
# element.
sharded_shapes = [[t[i].shape for t in input_tensors]
for i in xrange(self.number_of_tuple_elements)]
# For each tuple, get the unsharded shape using that tuple's policy.
unsharded_shapes = [
policy.get_unsharded_shape(s)
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
]
self.set_tuple_shapes(unsharded_shapes)
for i in xrange(1, self.number_of_shards):
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
if t1.dtype != t2.dtype:
raise TypeError(
"types of the tuple elements of input_tensors %s are not "
"consistent" % str(input_tensors))
self.set_tuple_types([t.dtype for t in input_tensors[0]])
def freeze(self):
"""Freezes the InfeedQueue so it can no longer be modified.
The configuration is implicitly frozen before any host-side or
device-side Ops are generated. The configuration cannot be frozen
until the types and shapes of the tuple elements have been set.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set.
"""
self._frozen = True
if self._tuple_types is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple types.")
if self._tuple_shapes is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for shape in self._tuple_shapes:
if shape.dims is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for policy in self._sharding_policies:
policy.freeze()
self._validate()
def generate_dequeue_op(self, tpu_device=0):
"""Generates the device-side Op to dequeue a tuple from the queue.
Implicitly freezes the queue configuration if it is not already
frozen, which will raise errors if the shapes and types have not
been fully specified.
Args:
tpu_device: The TPU device ordinal where the infeed instruction should be
placed. If None, no explicit placement will be performed, and it is up
to the user to call this API from within a proper TPU device scope.
The XLA code will fail if the TPU dequeue instruction is not bound to
any device.
Returns:
A list of Outputs corresponding to a shard of infeed dequeued
into XLA, suitable for use within a replicated block.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set; or if a dequeue op has already been generated.
"""
self.freeze()
if self._generated_dequeue_op:
raise ValueError("Can't generate two dequeue Ops from the same queue")
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
if tpu_device is not None:
with ops.device(tpu.core(tpu_device)):
return tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
else:
return tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
def _generate_enqueue_op(self,
inputs,
name_prefix,
index,
device=None,
tpu_ordinal=-1):
"""Generate a host-side Op to enqueue a tuple to the queue.
If device is None the inputs are all required to have the same
device specification, and the enqueue Op is colocated with
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
Args:
inputs: a list of Tensors with the types and shapes of the tuple elements.
name_prefix: the base name for the Op.
index: the shard index, used to uniquify the Op name.
device: device to place the Op on, or None if it should be
colocated with the inputs.
tpu_ordinal: ordinal of the TPU device on the host to use for
infeed if device is a CPU device. Should be set to -1 if device
is a TPU device.
Returns:
An Op corresponding to a shard of infeed enqueued at the host,
suitable for use within a replicated block.
Raises:
ValueError: if device is None and inputs do not all have the
same device specification.
"""
full_name = "%s/%d" % (name_prefix, index)
shapes = [t.shape for t in inputs]
if device is None:
devices = [t.device for t in inputs]
for i in xrange(1, self.number_of_tuple_elements):
if devices[0] != devices[i]:
raise ValueError(
"input devices for shard %d are %s, but should all be the same" %
(index, str(devices)))
with ops.colocate_with(inputs[0]):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
else:
with ops.device(device):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
def generate_enqueue_ops(self,
sharded_inputs,
tpu_ordinal_function=None,
placement_function=None):
"""Generates the host-side Ops to enqueue the shards of a tuple.
sharded_inputs is a list, one for each shard, of lists of
Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
shard 0 of the queue. Returns the host-side Ops that must be run to
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
for shard i.
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of sharded_inputs, an error
will be raised.
Args:
sharded_inputs: a list of lists of Tensors. The length of the outer list
determines the number of shards. Each inner list indicates the types
and shapes of the tuples in the corresponding shard.
tpu_ordinal_function: if not None, a function that takes the
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. tpu_ordinal_function must be
set if the inputs are placed on CPU devices.
placement_function: if not None, a function that takes the shard index as
input and returns the host device on which the enqueue op should be
placed.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the shapes of the elements of sharded_inputs
don't form a consistent unsharded tuple; or if the elements of a tuple
have different device constraints.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the types of the elements of sharded_inputs
don't form a consistent unsharded tuple.
"""
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
if tpu_ordinal_function is None:
tpu_ordinal_function = lambda index: -1
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(
shard,
name_prefix,
index,
tpu_ordinal=tpu_ordinal_function(index),
device=placement_function(index) if placement_function else None)
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
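# A minimal sketch of the enqueue/dequeue pairing, assuming a TPU worker with
# two cores is available; the constant inputs and step() below are hypothetical
# stand-ins for a real input pipeline and model.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_feed

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
sharded_inputs = [
    [tf.constant([[0.0, 1.0]]), tf.constant([1])],  # shard 0
    [tf.constant([[2.0, 3.0]]), tf.constant([0])],  # shard 1
]
# The inputs live on the CPU, so a tpu_ordinal_function is required.
enqueue_ops = queue.generate_enqueue_ops(
    sharded_inputs, tpu_ordinal_function=lambda index: index)

def step(features, labels):
  # The dequeued tensors arrive as extra arguments when the queue is passed
  # to tpu.replicate() below, which generates the matching dequeue ops.
  return tf.reduce_sum(features) + tf.cast(labels, tf.float32)

outputs = tpu.replicate(step, inputs=[[]] * 2, infeed_queue=queue)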
# TODO(misard) Generalize this to the case of systems that don't
# have 8 devices per host, and figure out what to do with
# model-parallelism.
def _default_placement_function(self, index):
return "/task:%d/device:CPU:0" % (index / 8)
def _default_ordinal_function(self, index):
return index % 8
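# A small worked example of the default mapping above, assuming 8 TPU cores
# per host (purely illustrative).
for index in (0, 7, 8, 11):
  print("shard %d -> /task:%d/device:CPU:0, TPU ordinal %d"
        % (index, index // 8, index % 8))
# e.g. shard 11 -> /task:1/device:CPU:0, TPU ordinal 3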
# TODO(b/36470756) remove this from tutorials once we have a better story
# for automatic placement of input pipelines.
def split_inputs_and_generate_enqueue_ops(self,
inputs,
device_assignment=None,
placement_function=None,
tpu_ordinal_function=None):
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
Generates the host-side Ops to enqueue a tuple.
This method performs poorly because it takes an entire input on a single
host, splits it, and distributes it to all of the cores. It is present only
to simplify tutorial examples.
inputs is a list of Tensors to use to feed the queue. Each input is split
into self.number_of_shards shards. Returns an Op for each shard to enqueue
the shard. The Op for shard i is placed on device placement_function(i).
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of inputs, an error
will be raised.
Args:
inputs: a list of Tensors which indicates the types and shapes of the
queue tuple.
device_assignment: if not `None`, a TPU `DeviceAssignment`. If
device_assignment is not `None`, but `placement_function` and
`ordinal_function` are None, then `device_assignment` will be used to
place infeeds on the first k TPU shards, where k is the number of shards
in the queue. If all three are `None`, then default placement and
ordinal functions are used.
placement_function: if not None, a function that takes the shard
index as input and returns a device string indicating which
device the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
tpu_ordinal_function: if not None, a function that takes the
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of inputs are not compatible with the frozen
configuration.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of inputs are not compatible with the frozen
configuration.
"""
if device_assignment is None:
if placement_function is None:
placement_function = self._default_placement_function
if tpu_ordinal_function is None:
tpu_ordinal_function = self._default_ordinal_function
else:
def _placement_function_from_map(index):
return device_assignment.host_device(replica=index)
def _ordinal_function_from_map(index):
return device_assignment.tpu_ordinal(replica=index)
if placement_function is None:
placement_function = _placement_function_from_map
if tpu_ordinal_function is None:
tpu_ordinal_function = _ordinal_function_from_map
self.set_configuration_from_input_tensors(inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
split_name_prefix = "%s/split" % self._name
if self.number_of_shards == 1:
transposed_sharded_inputs = [[inp] for inp in inputs]
else:
def split_fn(inp, num_shards, axis, name):
with ops.colocate_with(inp):
return array_ops.split(inp, num_shards, axis=axis, name=name)
transposed_sharded_inputs = [
split_fn(
inp,
self.number_of_shards,
axis=policy.shard_dimension,
name="%s/%d" % (split_name_prefix, index))
for (inp, policy, index) in zip(inputs, self._sharding_policies,
xrange(self.number_of_tuple_elements))
]
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
for i in xrange(self.number_of_shards)]
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(
shard,
name_prefix,
index,
device=placement_function(index),
tpu_ordinal=tpu_ordinal_function(index))
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
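# A minimal sketch of this tutorial-style path, assuming a TPU worker with two
# cores; the zero-valued inputs are hypothetical placeholders for a real batch.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_feed

images = tf.zeros([8, 224, 224, 3])
labels = tf.zeros([8], dtype=tf.int32)
queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
queue.set_number_of_shards(2)
enqueue_ops = queue.split_inputs_and_generate_enqueue_ops([images, labels])
# Each returned op enqueues a [4, 224, 224, 3] / [4] slice for one shard.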
class _PartitionedInfeedQueue(InfeedQueue):
"""A helper object to build a device infeed queue with input partition.
Args:
number_of_tuple_elements: the number of Tensors fed atomically through the
queue, must be present unless it can be inferred from other arguments.
device_assignment: A TPU `DeviceAssignment` which is used to place all the
partitions to different TPU infeed queues.
host_id: The id of the host machine.
input_partition_dims: A nested list/tuple of integers. Each inner
list/tuple describes how to partition the corresponding input tensor.
tuple_types: If not None, a list of types of the elements of the queue.
tuple_shapes: If not None, a list of shapes of the elements of the queue.
name: The name of the queue.
"""
def __init__(self,
number_of_tuple_elements,
device_assignment,
host_id,
input_partition_dims=None,
tuple_types=None,
tuple_shapes=None,
name=None):
super(_PartitionedInfeedQueue, self).__init__(
number_of_tuple_elements=number_of_tuple_elements,
tuple_types=tuple_types,
tuple_shapes=None,
shard_dimensions=None,
name="PartitionedInfeedQueue" if name is None else name)
self._input_partition_dims = input_partition_dims
self._host_id = host_id
self._device_assignment = device_assignment
def generate_dequeue_op(self, tpu_device=0):
"""Generate TPU dequeue ops.
Args:
tpu_device: The TPU device ordinal where the infeed instruction should be
placed.
Returns:
A list of Outputs corresponding to a partition of infeed dequeued
into XLA, suitable for use within a replicated block.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set; or if a dequeue op has already been generated.
"""
self.freeze()
if self._generated_dequeue_op:
raise ValueError("Can't generate two dequeue Ops from the same queue")
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
with ops.device(tpu.core(tpu_device)):
values = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
return tag_sharding_attribute_for_dequeued_tensors(
values, self._input_partition_dims)
def generate_enqueue_ops(self, per_host_sharded_inputs):
"""Generates the host-side Ops to enqueue the partitioned inputs.
per_host_sharded_inputs is a list, one for each replica, of lists of
Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
replica i.
per_host_sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].
For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
[[A, B, C, D],
[E, F, G, H]]
and self._input_partition_dims[j] is [2, 4], then
per_host_sharded_inputs[i][j] will be partitioned and flattened into:
[A, B, C, D, E, F, G, H] and fed to logical core ids
[0, 1, 2, 3, 4, 5, 6, 7] respectively.
Args:
per_host_sharded_inputs: a list of lists of Tensors. The length of the
outer list determines the number of shards. Each inner list indicates
the types and shapes of the tuples in the corresponding shard.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the shapes of the elements of sharded_inputs
don't form a consistent unsharded tuple; or if the elements of a tuple
have different device constraints; or if the partition dims are invalid.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the types of the elements of sharded_inputs
don't form a consistent unsharded tuple.
"""
self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
number_of_replicas_per_host = len(per_host_sharded_inputs)
number_of_tuple_elements = len(per_host_sharded_inputs[0])
assert len(self._input_partition_dims) == number_of_tuple_elements
per_host_enqueue_ops = []
for replica_index in range(number_of_replicas_per_host):
flattened_inputs = per_host_sharded_inputs[replica_index]
inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
self._input_partition_dims)
inputs_parted_iters = [
iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
for x, dims in zip(per_host_sharded_inputs[replica_index],
inputs_part_dims_flat)
]
for logical_core in xrange(self._device_assignment.num_cores_per_replica):
# Place different partitions on different logical cores.
replica_id = self._device_assignment.lookup_replicas(
self._host_id, logical_core)[replica_index]
ordinal = self._device_assignment.tpu_ordinal(
replica=replica_id, logical_core=logical_core)
infeed_inputs = []
for it in inputs_parted_iters:
input_for_device = next(it, None)
if input_for_device is not None:
infeed_inputs.append(input_for_device)
if infeed_inputs:
per_host_enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
inputs=infeed_inputs,
shapes=[x.shape for x in infeed_inputs],
name="enqueue/replica_{0}/input_{1}".format(
replica_index, logical_core),
device_ordinal=ordinal))
return per_host_enqueue_ops
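# A worked example of the [2, 4] partitioning described in the docstring above,
# using numpy only to illustrate the flattening order fed to logical cores 0-7.
import numpy as np

x = np.array([["A", "B", "C", "D"],
              ["E", "F", "G", "H"]])
rows = np.split(x, 2, axis=0)                                  # partition dim 0 into 2
parts = [p for row in rows for p in np.split(row, 4, axis=1)]  # then dim 1 into 4
print([p.item() for p in parts])  # ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']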
def _check_input_partition_dims(self, tensor, dims):
"""Checks that input partition dims are valid for the `Tensor`.
Args:
tensor: Input tensor for partitioning.
dims: A list of integers that describes how to partition the input tensor.
Raises:
ValueError: If the tensor can't be partitioned by dims or the
num_cores_per_replica doesn't match the number of
partitions (dims.prod()).
"""
# No partitioning specified, so don't perform further checks.
if dims is None:
return
dims = np.array(dims)
if (dims < 1).any():
raise ValueError("All input partition dims must be >= 1.")
# No partitioning, so don't perform further checks.
if dims.prod() == 1:
return
if dims.prod() != self._device_assignment.num_cores_per_replica:
raise ValueError(
"The product of each input parition dim should equal to "
"num_cores_per_replica. (dim = {}, num_cores_per_replica "
"= {})".format(dims, self._device_assignment.num_cores_per_replica))
if dims.shape[0] != tensor.shape.ndims:
raise ValueError(
"Input partition dims must have the same number of dimensions "
"as the `Tensor` to be partitioned. (tensor shape = {}, input "
"partition dims = {}).".format(tensor.shape.as_list(), dims))
tensor.shape.assert_is_fully_defined()
def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
"""Checks dims and partitions or replicates the input tensor.
The ops inside this function are placed on the host side.
Args:
tensor: The input tensor which will be partitioned or replicated.
dims: A list of integers that describes how to partition the input tensor.
Returns:
An iterator of `Tensor`s or a list of partitioned tensors.
"""
self._check_input_partition_dims(tensor, dims)
return partition_or_replicate_on_host(tensor, dims)
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.tpu_feed import *
# used by tests
from tensorflow.python.tpu.tpu_feed import _PartitionedInfeedQueue
# pylint: enable=wildcard-import,unused-import,redefined-builtin


@ -1,66 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for functions used during TPU compilation."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
class TpuContext(object):
"""A context object holding state about the TPU computation being built."""
def __init__(self):
"""Creates a new TpuContext."""
self._number_of_shards = None
@property
def number_of_shards(self):
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
self._number_of_shards = number_of_shards
# The Tpu context holds the number of shards when a sharded computation is
# being built, or None if no computation is being built.
_current_tpu_context = TpuContext()
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
if _current_tpu_context.number_of_shards is not None:
raise NotImplementedError("tpu_shard_context cannot be nested.")
try:
_current_tpu_context.set_number_of_shards(number_of_shards)
yield
finally:
_current_tpu_context.set_number_of_shards(None)
def get_tpu_context():
return _current_tpu_context
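# A minimal sketch of how this context is used: tpu.shard()/tpu.replicate()
# enter tpu_shard_context() so that code built inside the computation can query
# the number of shards.
from tensorflow.contrib.tpu.python.tpu import tpu_function

with tpu_function.tpu_shard_context(8):
  assert tpu_function.get_tpu_context().number_of_shards == 8
# Outside the context the shard count is reset to None.
assert tpu_function.get_tpu_context().number_of_shards is None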
# Decorator function for tpu computation func that was passed to tpu.rewrite()
# if there is an embedded training loop in this func, trace tools will generate
# step markers for each iteration.
def on_device_training_loop(func):
# Value for this attribute is from xla.DebugOptions.StepMarkerLocation.
setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")
return func
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_function import *
# pylint: enable=wildcard-import,unused-import


@ -1,203 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Optimizer that implements cross-shard gradient reduction for TPU."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer
class CrossShardOptimizer(optimizer.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
def __init__(self,
opt,
reduction=losses.Reduction.MEAN,
name="CrossShardOptimizer",
group_assignment=None):
"""Construct a new cross-shard optimizer.
Args:
opt: An existing `Optimizer` to encapsulate.
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
group_assignment: Optional 2d int32 lists with shape
[num_groups, num_replicas_per_group] which describes how to apply the
optimizer to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
"""
if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
raise ValueError("Unsupported reduction: %s." % reduction)
super(CrossShardOptimizer, self).__init__(False, name)
self._opt = opt
self._reduction = reduction
self._group_assignment = group_assignment
def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
"""Verify group_assignment and get the subgroup size".
Args:
group_assignment: list of group ids for applying the optimizer
to subgroups.
num_shards: The number of TPU shards.
Returns:
The size of one subgroup in group_assignment.
Raises:
ValueError: If group_assignment is invalid.
"""
if not group_assignment:
return None
if not (isinstance(group_assignment, list) and
all(isinstance(i, list) for i in group_assignment)):
raise ValueError("group_assignment must be a list of list. Got {}".format(
group_assignment))
replica_ids = set()
for g in group_assignment:
for i in g:
replica_ids.add(i)
if set(range(num_shards)) != replica_ids:
raise ValueError("group_assignment must be a permutation of range({0})."
" Got group_assignment={1}".format(
num_shards, group_assignment))
subgroup_size_list = [len(group) for group in group_assignment]
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
return subgroup_size_list[0]
else:
raise ValueError("The size of each subgroup in group_assignment must "
"be equal. Got group_assignment={}".format(
self._group_assignment))
def compute_gradients(self, loss, var_list=None, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
This simply wraps compute_gradients() from the real optimizer. The
gradients are aggregated in apply_gradients() so that the user can first
modify them if needed, for example by clipping with a per-replica global
norm. Computing a global norm over the aggregated gradients can be
misleading, since one replica's huge gradients can swamp the gradients
from the other replicas.
Args:
loss: A Tensor containing the value to minimize.
var_list: Optional list or tuple of `tf.Variable` to update to minimize
`loss`. Defaults to the list of variables collected in the graph
under the key `GraphKey.TRAINABLE_VARIABLES`.
**kwargs: Keyword arguments for compute_gradients().
Returns:
A list of (gradient, variable) pairs.
Raises:
ValueError: If not within a tpu_shard_context or group_assignment is
invalid.
"""
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
logging.warning(
"CrossShardOptimizer should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
num_shards)
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
if self._group_assignment:
scale = 1.0 / subgroup_size
else:
scale = 1.0 / num_shards
loss *= scale
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
replicas, and then applies the real optimizer.
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
compute_gradients().
global_step: Optional Variable to increment by one after the
variables have been updated.
name: Optional name for the returned operation. Default to the
name passed to the Optimizer constructor.
Returns:
An `Operation` that applies the gradients. If `global_step` was not None,
that operation also increments `global_step`.
Raises:
ValueError: If the grads_and_vars is malformed.
"""
summed_grads_and_vars = []
for (grad, var) in grads_and_vars:
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
with ops.colocate_with(grad):
summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
grad, self._group_assignment), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
"""Return a slot named "name" created for "var" by the Optimizer.
This simply wraps the get_slot() from the actual optimizer.
Args:
*args: Arguments for get_slot().
**kwargs: Keyword arguments for get_slot().
Returns:
The `Variable` for the slot if it was created, `None` otherwise.
"""
return self._opt.get_slot(*args, **kwargs)
def get_slot_names(self, *args, **kwargs):
"""Return a list of the names of slots created by the `Optimizer`.
This simply wraps the get_slot_names() from the actual optimizer.
Args:
*args: Arguments for get_slot_names().
**kwargs: Keyword arguments for get_slot_names().
Returns:
A list of strings.
"""
return self._opt.get_slot_names(*args, **kwargs)
def variables(self):
"""Forwarding the variables from the underlying optimizer."""
return self._opt.variables()
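# A minimal sketch of wrapping a standard optimizer with CrossShardOptimizer,
# assuming an 8-core TPU; the model, data, and learning rate are hypothetical.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

features = tf.zeros([64, 16])
labels = tf.zeros([64], dtype=tf.int32)

def train_step(x, y):
  logits = tf.layers.dense(x, 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
  opt = tpu_optimizer.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=0.01))
  return opt.minimize(loss)

# tpu.shard() splits the inputs 8 ways and runs train_step inside a
# tpu_shard_context, which is how CrossShardOptimizer learns the shard count.
train_op = tpu.shard(train_step, inputs=[features, labels], num_shards=8)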
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_optimizer import *
# pylint: enable=wildcard-import,unused-import


@ -1,253 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import tensor_shape
_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0
# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
"""An object use to hold the sharding policy for a Tensor.
"""
def __init__(self):
self._number_of_shards = None
self._shard_dimension = None
self._frozen = False
def __str__(self):
if self.number_of_shards is None or self.shard_dimension is None:
return "ShardingPolicy(unset)"
else:
return ("ShardingPolicy(%d shards dimension %d)" %
(self.number_of_shards, self.shard_dimension))
def _fill_default_values(self):
if self._number_of_shards is None:
self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
if self._shard_dimension is None:
self._shard_dimension = tensor_shape.as_dimension(
_DEFAULT_SHARD_DIMENSION)
def freeze(self):
"""Prevents further modification to the sharding policy.
Any values that have not been set when freeze is called are set to
defaults. If the ShardingPolicy is already frozen, this is a NoOp.
"""
if not self._frozen:
self._fill_default_values()
self._frozen = True
@property
def number_of_shards(self):
"""Returns the number of shards in the policy or None if unspecified."""
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards for the current policy.
If the policy has been frozen then number_of_shards must match the
existing setting.
Args:
number_of_shards: The number of shards to use in the policy.
Raises:
ValueError: If the policy has been frozen and number_of_shards
differs from the frozen value; or number_of_shards <= 0.
"""
if self._frozen:
if self._number_of_shards != number_of_shards:
raise ValueError(
"Can't set sharding policy to use %d shards since it has been "
"frozen to use %d." % (number_of_shards, self._number_of_shards))
else:
if number_of_shards > 0:
self._number_of_shards = number_of_shards
else:
raise ValueError(
"Can't set sharding policy to use %s shards; value must be >0",
str(number_of_shards))
@property
def shard_dimension(self):
"""Returns the shard dimension of the policy or None if unspecified."""
return self._shard_dimension
def set_shard_dimension(self, shard_dimension):
"""Sets the shard dimension for the current policy.
If the policy has been frozen then shard_dimension must match the
existing setting.
Args:
shard_dimension: The shard dimension to use in the policy.
Raises:
ValueError: If the policy has been frozen and shard_dimension
differs from the frozen value, or shard_dimension can't be
interpreted as a Dimension.
"""
if self._frozen:
if self._shard_dimension != shard_dimension:
raise ValueError(
"Can't set shard dimension to %d since it has been frozen to "
"use %d." % (shard_dimension, self._shard_dimension))
else:
self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
def merge(self, other):
"""Merges the policy of another policy into the current policy.
Args:
other: The policy to merge into this one.
Raises:
ValueError: If this policy has been frozen and the merge conflicts with
the frozen policy.
"""
if other.number_of_shards is not None:
self.set_number_of_shards(other.number_of_shards)
if other.shard_dimension is not None:
self.set_shard_dimension(other.shard_dimension)
def get_sharded_shape(self, shape, shard_index=None):
"""Returns the shape of a shard of a full Tensor.
When given the shape of a 'full-size' Tensor, returns the shape of
the sub-Tensor after it has been sharded. Freezes the policy if it
has not yet been frozen.
Args:
shape: The shape of the full-size Tensor to be sharded.
shard_index: The index of the shard whose shape should be returned.
shard_index can be None for sharding policies that use the same
shape for every shard.
Returns:
The shape of the sharded version of the Tensor.
Raises:
ValueError: If shard_index is None when shards are of different
shapes; or shard_index is not None and
!(0<=shard_index<number_of_shards); or shape does not have at
least self.shard_dimension+1 dimensions; or the value of
shape's shard dimension is not a multiple of
self.number_of_shards
"""
if self._shard_dimension is None or self._number_of_shards is None:
# Don't raise an error if the config is unset.
return None
if shard_index is not None:
if shard_index < 0 or shard_index >= self.number_of_shards:
raise ValueError("shard_index %d, but must be in [0,%d)." %
(shard_index, self._number_of_shards))
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
if dims[self._shard_dimension] is None:
raise ValueError("shape %s must have a fixed size for dimension %d "
"that is known at graph construction time." %
(shape.as_list(), self._shard_dimension))
if (dims[self._shard_dimension] % self._number_of_shards) != 0:
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
(shape.as_list(), self._number_of_shards,
self._shard_dimension))
dims[self._shard_dimension] //= self._number_of_shards
return tensor_shape.as_shape(dims)
def _unshard_shape(self, shape):
"""Return the unsharded shape that would generate a given sharded shape.
Args:
shape: the sharded shape to unshard
Returns:
The unsharded shape.
Raises:
ValueError: if shape is unknown or does not contain
self.shard_dimension
TypeError: if shape is not convertible to a TensorShape
"""
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
dims[self._shard_dimension] *= self._number_of_shards
return tensor_shape.as_shape(dims)
def get_unsharded_shape(self, shapes):
"""Returns the shape of an unsharded Tensor given a list of shards.
When given a list of shapes of shards, returns the shape of the
unsharded Tensor that would generate the shards. Sets defaults for the
policy if number_of_shards or shard_dimension is None.
Args:
shapes: The shapes of the Tensor shards to be combined.
Returns:
The shape of the unsharded version of the Tensor.
Raises:
ValueError: if shapes is not a list of length
self.number_of_shards; or any element of shapes is not a valid
shape consistent with the sharding policy; or the list of
shapes is not a valid sharding of a full shape.
TypeError: if an element of shapes is not convertible to a
TensorShape
"""
self._fill_default_values()
if len(shapes) != self.number_of_shards:
raise ValueError(
"shapes is %s but must be a list of length number_of_shards=%d" % (
str(shapes), self.number_of_shards))
unsharded_shapes = [self._unshard_shape(s) for s in shapes]
for i in xrange(self.number_of_shards - 1):
if not unsharded_shapes[i].is_compatible_with(
unsharded_shapes[self.number_of_shards - 1]):
raise ValueError(
"sharded shapes %s are not consistent shards of a full shape "
"sharded %d ways along dimension %d" % (
str(shapes), self.number_of_shards, self.shard_dimension))
return unsharded_shapes[0]
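# A small worked example of the sharded/unsharded shape round trip (the shapes
# are illustrative only).
from tensorflow.contrib.tpu.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(4)
policy.set_shard_dimension(0)
print(policy.get_sharded_shape([128, 10]))         # (32, 10)
print(policy.get_unsharded_shape([[32, 10]] * 4))  # (128, 10)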
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.tpu_sharding import *
# pylint: enable=wildcard-import,unused-import,redefined-builtin


@ -1,156 +1,25 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min
_RETRY_TIMES = 120
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
'num_cores',
'num_hosts',
'num_of_cores_per_host',
'topology',
'devices',
])
def _query_tpu_system_metadata(master_address, cluster_def=None,
query_topology=False):
"""Automatically detects the TPU system metadata in the system."""
tpu_core_count = 0
devices = []
device_dict = collections.defaultdict(list)
# TODO(b/120564445): Replace with standard library for retries.
retry_count = 1
while True:
logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
master_address)
try:
with ops.Graph().as_default():
with session_lib.Session(
master_address,
config=get_session_config_with_timeout(
_PINGING_MASTER_TIMEOUT_IN_MS,
cluster_def)) as sess:
devices = sess.list_devices()
for device in devices:
match = _TPU_DEVICE_REG.match(device.name)
if match:
host_id = match.group(1)
core_id = match.group(2)
device_dict[host_id].append(core_id)
tpu_core_count += 1
break
except errors.DeadlineExceededError:
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
'not be ready (still scheduling) or the Tensorflow master address '
'is incorrect: got (%s).' %
(master_address))
# TODO(xiejw): For local or grpc master we might not need retry logic
# here.
if retry_count <= _RETRY_TIMES:
logging.warning('%s', msg)
logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
retry_count += 1
else:
raise ValueError(msg)
num_of_cores_per_host = 0
if tpu_core_count:
num_cores_per_host_set = set(
[len(core_ids) for core_ids in device_dict.values()])
if len(num_cores_per_host_set) != 1:
raise RuntimeError(
'The number of TPU cores on each host is not the same. This should not happen. '
'devices: {}'.format(devices))
num_of_cores_per_host = num_cores_per_host_set.pop()
topology = None
if query_topology:
if not tpu_core_count:
raise RuntimeError(
'Cannot find any TPU cores in the system (master address {}). '
'This usually means the master address is incorrect or the '
'TPU worker has some problems. Available devices: {}'.format(
master_address, devices))
topology = _obtain_topology(master_address, cluster_def)
metadata = _TPUSystemMetadata(
num_cores=tpu_core_count,
num_hosts=len(device_dict),
num_of_cores_per_host=num_of_cores_per_host,
topology=topology,
devices=devices)
if tpu_core_count:
logging.info('Found TPU system:')
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
logging.info('*** Num TPU Cores Per Worker: %d',
metadata.num_of_cores_per_host)
for device in metadata.devices:
logging.info('*** Available Device: %s', device)
else:
logging.info('Failed to find TPU: %s', metadata)
return metadata
def _obtain_topology(master_address, cluster_def):
"""Obtains TPU fabric topology."""
try:
logging.info('Initializing TPU system (master: %s) to fetch topology '
'for model parallelism. This might take a while.',
master_address)
with ops.Graph().as_default():
session_config = get_session_config_with_timeout(
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
with session_lib.Session(
master_address, config=session_config) as sess:
topology = sess.run(tpu.initialize_system())
return topology
except errors.DeadlineExceededError:
raise ValueError(
'Failed to initialize TPU system with master (%s). '
'Please double check the TPU system is functional.' % (
master_address))
def get_session_config_with_timeout(timeout_in_secs, cluster_def):
"""Returns a session given a timeout and a cluster configuration."""
config = config_pb2.ConfigProto(
operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
return config
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_system_metadata import *
# used by tests
from tensorflow.python.tpu.tpu_system_metadata import _query_tpu_system_metadata
# pylint: enable=wildcard-import,unused-import


@ -1,223 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Library for constructing a training loop, suitable for TPUs."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.compiler import xla
from tensorflow.contrib.tpu.python.tpu import tensor_tracer
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
`condition` and `body` take the current value of the loop-carried
tensors. `body` additionally takes a tuple of infeed from
infeed_queue if infeed_queue is not None. `condition` must return a
single boolean value that determines whether iteration
continues. `body` must return an updated list of values for the
loop-carried tensors.
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to the loop body.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
TypeError: if body or condition has the wrong signature.
"""
del name
# Converts inputs to Tensors.
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
x in inputs]
input_types = [x.dtype for x in inputs]
input_arity = len(inputs)
body_arg_error = xla.check_function_argument_count(
body, input_arity, infeed_queue)
if body_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s, but the loop body needs %s" % (
input_arity, str([i.name for i in inputs]), body_arg_error))
else:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s and %d additional inputs from "
"infeed, but the computation needs %s" % (input_arity, str(
[i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
body_arg_error))
condition_arg_error = xla.check_function_argument_count(
condition, input_arity, None)
if condition_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s" % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
else:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s. Note that infeed is not passed to the loop "
"condition." % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
def condition_wrapper(*inputs):
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
return condition(*inputs)
def body_wrapper(*inputs):
"""Wrapper around `body` that handles infeed queues and control deps."""
inputs = list(inputs)
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
# Runs `body` with the dequeue_ops appended.
if infeed_queue:
number_of_shards = tpu_function.get_tpu_context().number_of_shards
if number_of_shards is None:
raise ValueError("Can't build training loop with infeed when there is "
"no tpu_shard_context. Are you building a loop or "
"graph directly rather than from inside tpu.rewrite, "
"tpu.batch_parallel, tpu.shard, or tpu.replicate?")
infeed_queue.set_number_of_shards(number_of_shards)
dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
else:
dequeue_ops = []
outputs = body(*(inputs + dequeue_ops))
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs
if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
"TPU training loop body must return zero or more Tensor values "
"followed by zero or more Operations.")
output_types = [op.dtype for op in output_tensors]
if input_types != output_types:
raise TypeError(
"Mismatch between input types and output types for training loop "
"body: {} vs {}".format(input_types, output_types))
# Add the dequeue operations to output_operations to ensure they are run
# by the loop, even if the programmer's loop body does not use them.
output_operations += dequeue_ops
# Add a dummy output, if needed.
if not output_tensors:
output_tensors = array_ops.constant(0)
if output_operations:
# TODO(phawkins): in principle this is too restrictive since it serializes
# the training loop steps. In practice it does not matter since this loop
# will be compiled by XLA.
output_tensors = control_flow_ops.tuple(output_tensors,
control_inputs=output_operations)
if tensor_tracer.TensorTracer.is_enabled():
num_replicas = tpu_function.get_tpu_context().number_of_shards
if num_replicas is None:
num_replicas = 1
tt = tensor_tracer.TensorTracer()
output_tensors = tt.trace_tpu(ops.get_default_graph(),
output_tensors, None,
num_replicas)
return output_tensors
# If the body has arity 0, add a dummy loop-carried value to which we can add
# control dependencies from any side-effecting operations.
if input_arity == 0:
inputs = [array_ops.constant(0)]
return control_flow_ops.while_loop(
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop that executes a fixed number of iterations.
The set of loop-carried tensors corresponds to `inputs`.
`body` must be a function that takes and returns the values of the
loop-carried tensors.
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to the loop body.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
ValueError: if there is a type error.
"""
def _convert_to_list(xs):
if not isinstance(xs, (list, tuple)):
return [xs]
else:
return list(xs)
def cond(i, *args):
del args
return i < n
def body_wrapper(i, *args):
return [i + 1] + _convert_to_list(body(*args))
inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
outputs = while_loop(
cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
outputs = _convert_to_list(outputs)
if len(outputs) == 1:
# Returns the Op rather than an empty list.
return outputs[0].op
else:
return outputs[1:]
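# A minimal sketch of building a fixed-length loop with repeat(); the body just
# accumulates a counter, purely for illustration, and tpu.rewrite() compiles it
# for a single core.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import training_loop

def computation():
  def body(total):
    return total + 1.0
  return training_loop.repeat(10, body, inputs=[0.0])

result = tpu.rewrite(computation)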
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.training_loop import *
# pylint: enable=wildcard-import,unused-import


@ -1,51 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Utilities for the functionalities."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import six
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
def check_positive_integer(value, name):
"""Checks whether `value` is a positive integer."""
if not isinstance(value, six.integer_types):
raise TypeError('{} must be int, got {}'.format(name, type(value)))
if value <= 0:
raise ValueError('{} must be positive, got {}'.format(name, value))
# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
# python/estimator/util.py.
class MultiHostDatasetInitializerHook(training.SessionRunHook):
"""Creates a SessionRunHook that initializes all passed iterators."""
def __init__(self, dataset_initializers):
self._initializers = dataset_initializers
def after_create_session(self, session, coord):
del coord
start = time.time()
session.run(self._initializers)
logging.info('Initialized dataset iterators in %d seconds',
time.time() - start)
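# A minimal sketch of using the hook above to initialize a dataset iterator
# when the session is created; the dataset here is an arbitrary example.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import util

dataset = tf.data.Dataset.range(100).batch(8)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()
hook = util.MultiHostDatasetInitializerHook([iterator.initializer])
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
  print(sess.run(next_batch))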
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.util import *
# pylint: enable=wildcard-import,unused-import


@ -2010,6 +2010,7 @@ tf_gen_op_wrapper_private_py(
visibility = [
"//smartass/brain/configure/python:__pkg__",
"//tensorflow/contrib/tpu:__pkg__",
"//tensorflow/python/tpu:__pkg__",
],
deps = [
"//tensorflow/core:tpu_configuration_ops_op_lib",

tensorflow/python/tpu/BUILD (new file)

@ -0,0 +1,334 @@
# Description: Operations defined for Cloud TPUs
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_py_test",
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//knowledge/cerebra/sense/im2query:__subpackages__",
"//learning/brain:__subpackages__",
"//learning/deepmind:__subpackages__",
"//medical/pathology:__subpackages__",
"//tensorflow:__subpackages__",
"//vr/perception:__subpackages__",
],
)
py_library(
name = "tpu_py",
srcs = ["ops/tpu_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:tpu_ops_gen",
],
)
py_library(
name = "async_checkpoint",
srcs = ["async_checkpoint.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
],
)
py_library(
name = "tpu_estimator",
srcs = [
"_tpu_estimator_embedding.py",
"error_handling.py",
"tpu_config.py",
"tpu_context.py",
"tpu_estimator.py",
"util.py",
],
srcs_version = "PY2AND3",
deps = [
":async_checkpoint",
":feature_column",
":functional",
":tpu_embedding",
":tpu_lib",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:function",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:util",
"@six_archive//:six",
],
)
py_library(
name = "functional",
srcs = ["functional.py"],
srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
deps = [
"//tensorflow/python:tpu_ops_gen",
],
)
py_library(
name = "tpu",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":feature_column",
":tpu_embedding",
":tpu_estimator",
":tpu_lib",
],
)
py_library(
name = "tpu_lib",
srcs = [
"__init__.py",
"bfloat16.py",
"device_assignment.py",
"session_support.py",
"tensor_tracer.py",
"topology.py",
"tpu.py",
"tpu_feed.py",
"tpu_function.py",
"tpu_optimizer.py",
"tpu_sharding.py",
"tpu_system_metadata.py",
"training_loop.py",
"xla.py",
],
srcs_version = "PY2AND3",
deps = [
":datasets",
":functional",
":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/core:protos_all_py",
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
"//tensorflow/core/protobuf/tpu:topology_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/losses",
"//tensorflow/python/tpu/profiler",
],
)
py_library(
name = "datasets",
srcs = [
"datasets.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
tf_py_test(
name = "datasets_test",
size = "medium",
srcs = ["datasets_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
":datasets",
],
grpc_enabled = True,
shard_count = 4,
tags = ["no_oss"],
)
tf_py_test(
name = "tpu_test",
size = "small",
srcs = ["tpu_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:layers",
],
tags = ["no_windows"], # TODO: needs investigation on Windows
)
tf_py_test(
name = "tpu_sharding_test",
size = "small",
srcs = ["tpu_sharding_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "bfloat16_test",
size = "small",
srcs = ["bfloat16_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "tpu_infeed_test",
size = "small",
srcs = ["tpu_infeed_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_config_test",
size = "small",
srcs = ["tpu_config_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_estimator_signals_test",
size = "small",
srcs = ["tpu_estimator_signals_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
# TODO(jhseu): Remove. Fails in OSS on Python 3.
tags = ["no_oss"],
)
tf_py_test(
name = "topology_test",
size = "medium",
srcs = ["topology_test.py"],
additional_deps = [
":tpu",
"//tensorflow/python:framework_test_lib",
],
)
py_library(
name = "tpu_embedding",
srcs = [
"tpu_embedding.py",
"tpu_embedding_gradient.py",
],
srcs_version = "PY2AND3",
deps = [
":tpu_lib",
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:partitioned_variables",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"@six_archive//:six",
],
)
py_library(
name = "feature_column",
srcs = ["feature_column.py"],
deps = [
":tpu_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
],
)
tf_py_test(
name = "feature_column_test",
srcs = [
"feature_column_test.py",
],
additional_deps = [
":feature_column",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:variables",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
],
main = "feature_column_test.py",
)
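As a quick orientation for the targets above: the srcs now live under tensorflow/python/tpu, so downstream Python code can import the moved modules from core. A minimal sketch, assuming only the module names taken from the srcs lists above; whether additional symbols are re-exported elsewhere is not covered by this BUILD file:

# Hedged sketch: new core import locations for two of the modules listed above.
from tensorflow.python.tpu import datasets            # from datasets.py
from tensorflow.python.tpu import device_assignment   # from device_assignment.py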

View File

@@ -0,0 +1,20 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Ops related to Tensor Processing Units."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -0,0 +1,334 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tooling for support TPU embedding in TPUEstimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.feature_column import feature_column as core_fc
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
from tensorflow.python.tpu import feature_column as tpu_fc
from tensorflow.python.tpu import tpu_embedding
# pylint: disable=protected-access
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
tpu_fc._TPUSharedEmbeddingColumn)
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
core_fc_lib.EmbeddingColumn,
core_fc._SharedEmbeddingColumn)
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
# pylint: enable=protected-access
_TABLE_NAME_PREFIX = 'tbl_'
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
def _get_table_name_from_embedding_var_name(embedding_var_name):
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
def _get_embedding_var_name_from_table_name(table_name):
return table_name[_LEN_TABLE_NAME_PREFIX:]
def _get_embedding_variable_name(scope_name, var_name):
return '{}/{}'.format(scope_name, var_name)
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
"""Return embedding variable names which are consistent with CPU runs."""
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
return tpu_embedding.AdagradSlotVariableName(
'{}/{}/Adagrad'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
return tpu_embedding.AdamSlotVariableNames(
'{}/{}/Adam/m'.format(scope_name, var_name),
'{}/{}/Adam/v'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters,
tpu_embedding.StochasticGradientDescentParameters):
return None
else:
raise ValueError('Support for inferring the full variable name '
'for optimization_parameter {} has not been added.'
.format(optimization_parameters))
def get_full_variable_names(
graph, table_to_config_dict, optimization_parameters):
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
if not collection:
raise RuntimeError(
'Embedding feature column did not capture anything. Make sure the '
'feature columns passed to the TPUEstimator constructor are '
'actually used in model_fn.')
embedding_variable_name_by_table = {}
slot_variable_names_by_table = {}
for table_name in table_to_config_dict:
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
(scope_name, var_name) = collection[0][embedding_var_name]
embedding_variable_name_by_table[table_name] = (
_get_embedding_variable_name(scope_name, var_name))
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
scope_name, var_name, optimization_parameters)
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
return embedding_variable_name_by_table, slot_variable_names_by_table
def get_tpu_embedding_config_from_feature_columns(feature_columns):
"""Create configs for TPUEmbedding from a list of feature columns.
This function places one embedding tensor per table; the return value is
intended to be used as input to TPUEmbedding.
Args:
feature_columns: a list of supported feature columns.
Returns:
A pair of dicts, the first maps tables to their config, the second maps
features to tables.
"""
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
for column in feature_columns:
if not isinstance(column, allowed):
raise TypeError(
'Unsupported feature column {}. Supported types are {}.'.format(
type(column), allowed))
table_to_config = {}
feature_to_table = {}
for column in feature_columns:
feature_name = column.get_feature_key_name()
table_name = _get_table_name_from_embedding_var_name(
column.get_embedding_var_name())
if feature_name in feature_to_table:
raise ValueError(
'Feature column {} is used with multiple embeddings and this is '
'not supported.'.format(feature_name))
feature_to_table[feature_name] = table_name
vocabulary_size, dimension = column.get_embedding_table_size()
table_to_config[table_name] = tpu_embedding.TableConfig(
vocabulary_size=vocabulary_size,
dimension=dimension,
initializer=column.get_initializer(),
combiner=column.get_combiner())
return table_to_config, feature_to_table
def _get_tpu_embedding_optimization_parameters(embedding_config_spec):
"""Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec."""
if embedding_config_spec.optimizer_type == 'adagrad':
return tpu_embedding.AdagradParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.adagrad_initial_accumulator,
embedding_config_spec.use_gradient_accumulation)
elif embedding_config_spec.optimizer_type == 'sgd':
return tpu_embedding.StochasticGradientDescentParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.use_gradient_accumulation)
elif embedding_config_spec.optimizer_type == 'adam':
return tpu_embedding.AdamParameters(
embedding_config_spec.learning_rate,
embedding_config_spec.adam_parameters.beta1,
embedding_config_spec.adam_parameters.beta2,
embedding_config_spec.adam_parameters.epsilon,
use_gradient_accumulation=embedding_config_spec
.use_gradient_accumulation)
else:
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
AdamParameters = collections.namedtuple('AdamParameters',
['beta1', 'beta2', 'epsilon'])
# TODO(shizhiw): Improve the API to support more optimizer parameters.
class EmbeddingConfigSpec(
collections.namedtuple('EmbeddingConfigSpec', [
'feature_columns', 'learning_rate', 'optimizer_type',
'adagrad_initial_accumulator', 'clipping_limit',
'use_gradient_accumulation', 'adam_parameters'
])):
"""Class to keep track of embedding config specification."""
def __new__(cls,
feature_columns,
learning_rate,
optimizer_type='adagrad',
adagrad_initial_accumulator=None,
clipping_limit=None,
use_gradient_accumulation=False,
adam_parameters=None):
"""Creates an EmbeddingConfigSpec instance.
Args:
feature_columns: All `FeatureColumn`s used by model.
learning_rate: embedding optimizer learning rate.
optimizer_type: (String) Name of the optimizer for embedding gradient
updates. Must be either 'adagrad' (`tf.train.AdagradOptimizer`, the default
value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam'
(`tf.contrib.opt.LazyAdamOptimizer`, i.e. lazy Adam). This optimizer will
be applied to all embedding variables specified by `feature_columns`.
adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when
optimizer_type is 'adagrad'. Default is `0.1`.
clipping_limit: (Optional) Clipping limit (absolute value).
use_gradient_accumulation: (Experimental) Whether to accumulate the
gradients across TPU embedding mini-batches. Gradient accumulation does
not affect SGD and therefore this is applicable only for Adagrad.
adam_parameters: AdamParameters. Used when optimizer_type is 'adam'.
Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon.
Returns:
An EmbeddingConfigSpec instance.
Raises:
ValueError: If the feature_columns are not specified.
TypeError: If the feature columns are not of the correct type (one of
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or
_EMBEDDING_COLUMN_CLASSES).
ValueError: If use_gradient_accumulation is True for SGD.
ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or
"adam".
"""
if not feature_columns:
raise ValueError('`feature_columns` cannot be `None` or empty.')
# It is unknown at this point whether the TPUEstimator is running in CPU
# or TPU mode, so allow non-TPU embedding columns as well.
supported_classes = tuple(
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
list(_EMBEDDING_COLUMN_CLASSES))
for column in feature_columns:
if not isinstance(column, supported_classes):
raise TypeError(
'All feature columns must be supported types in {}. Got {}'.format(
supported_classes, type(column)))
if optimizer_type == 'adagrad':
if adagrad_initial_accumulator is None:
adagrad_initial_accumulator = 0.1
if adagrad_initial_accumulator <= 0:
raise ValueError('Adagrad initial_accumulator must be positive')
elif optimizer_type == 'sgd':
if use_gradient_accumulation:
raise ValueError('Gradient accumulation makes sense for Adagrad only.')
elif optimizer_type == 'adam':
if adam_parameters is None:
adam_parameters = AdamParameters(0.9, 0.999, 1e-8)
if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.:
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(
adam_parameters.beta1))
if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.:
raise ValueError('beta2 must be between 0. and 1; got {}.'.format(
adam_parameters.beta2))
if adam_parameters.epsilon <= 0.:
raise ValueError('epsilon must be positive; got {}.'.format(
adam_parameters.epsilon))
else:
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
return super(EmbeddingConfigSpec, cls).__new__(
cls,
feature_columns=feature_columns,
learning_rate=learning_rate,
optimizer_type=optimizer_type,
adagrad_initial_accumulator=adagrad_initial_accumulator,
clipping_limit=clipping_limit,
use_gradient_accumulation=use_gradient_accumulation,
adam_parameters=adam_parameters)
class EmbeddingConfig(object):
"""This is the internal immutable object for embedding config.
`_EmbeddingConfig` is responsible to _translate_ user provided
`EmbeddingConfigSpec` to internal data structures, mostly constructor
arguments of `TPUEmbedding`.
"""
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
num_hosts, num_cores, master):
self._embedding_config_spec = embedding_config_spec
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._num_hosts = num_hosts
self._num_cores = num_cores
self._master = master
self._table_to_config_dict, self._feature_to_table_dict = (
get_tpu_embedding_config_from_feature_columns(
embedding_config_spec.feature_columns))
self._optimization_parameters = _get_tpu_embedding_optimization_parameters(
self._embedding_config_spec)
self._mode_to_tpu_embedding_dict = {}
self.dummy_table_variables = None
def has_embedding_tables(self):
return bool(self._table_to_config_dict)
def _create_tpu_embedding(self, mode):
"""Create tpu_embedding.TPUEmbedding based on mode."""
if mode == model_fn_lib.ModeKeys.TRAIN:
batch_size = self._train_batch_size
else:
batch_size = self._eval_batch_size
if mode == model_fn_lib.ModeKeys.TRAIN:
tpu_embedding_mode = tpu_embedding.TRAINING
elif (mode == model_fn_lib.ModeKeys.EVAL or
mode == model_fn_lib.ModeKeys.PREDICT):
tpu_embedding_mode = tpu_embedding.INFERENCE
else:
raise ValueError('Mode {} is not supported.'.format(mode))
tpu_embedding_ = tpu_embedding.TPUEmbedding(
self._table_to_config_dict,
self._feature_to_table_dict,
batch_size,
tpu_embedding_mode,
self._master,
self._optimization_parameters,
)
return tpu_embedding_
def get_tpu_embedding(self, mode):
if mode not in self._mode_to_tpu_embedding_dict:
self._mode_to_tpu_embedding_dict[mode] = (
self._create_tpu_embedding(mode))
return self._mode_to_tpu_embedding_dict[mode]
def split_inputs(ctx, features, labels):
"""Splits the dense and sparse tensors inside the features and labels."""
sparse_features = collections.OrderedDict()
if ctx.embedding_config:
tpu_embedding_ = ctx.embedding_config.tpu_embedding
for feature_key in tpu_embedding_.feature_to_table_dict:
sparse_features[feature_key] = features.pop(feature_key)
return features, labels, sparse_features
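A minimal usage sketch of the spec defined above, showing only how `EmbeddingConfigSpec` ties feature columns to optimizer settings. The column names, dimension, learning rate, and the `tensorflow.python.tpu._tpu_estimator_embedding` module path are assumptions for illustration; wiring the spec into `TPUEstimator` is done by the estimator code, not by this file:

from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.tpu import feature_column as tpu_fc
# Module path assumed from this change's BUILD layout; the module is private.
from tensorflow.python.tpu import _tpu_estimator_embedding as est_embedding

# Illustrative column: a 10k-id vocabulary embedded into 16 dimensions.
video_id = fc_lib.categorical_column_with_identity('video_id', num_buckets=10000)
emb_col = tpu_fc.embedding_column(video_id, dimension=16)

# Adagrad with the 0.1 initial accumulator that __new__ fills in by default.
spec = est_embedding.EmbeddingConfigSpec(
    feature_columns=[emb_col],
    learning_rate=0.05,
    optimizer_type='adagrad')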

View File

@@ -0,0 +1,212 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
allow execution to continue on the main thread.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None):
"""Initializes a `CheckpointSaverHook`.
Args:
checkpoint_dir: `str`, base directory for the checkpoint files.
save_secs: `int`, save every N secs.
save_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object.
listeners: List of `CheckpointSaverListener` subclass instances. Used for
callbacks that run immediately before or after this hook saves the
checkpoint.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
self._save_thread = None
self._write_graph_thread = None
self._checkpoint_dir = checkpoint_dir
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold
self._timer = basic_session_run_hooks.SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
self._listeners = listeners or []
self._steps_per_run = 1
self._summary_writer = None
self._global_step_tensor = None
self._last_checkpoint_step = None
def _set_steps_per_run(self, steps_per_run):
self._steps_per_run = steps_per_run
def begin(self):
self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use CheckpointSaverHook.")
for l in self._listeners:
l.begin()
def after_create_session(self, session, coord):
global_step = session.run(self._global_step_tensor)
# We write the graph and saver_def here in after_create_session rather than
# in begin, since other hooks may still change the graph and add variables
# in begin; the graph is finalized only after all begin calls.
def _write_graph_fn(self):
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
self._write_graph_thread = threading.Thread(target=_write_graph_fn,
args=[self])
self._write_graph_thread.start()
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
logging.info("Triggering checkpoint. %s", global_step)
if self._save(run_context.session, global_step):
run_context.request_stop()
def end(self, session):
if self._save_thread:
logging.info("Waiting for any pending checkpoints to finish.")
self._save_thread.join()
if self._write_graph_thread:
logging.info("Waiting for any pending write_graph to finish.")
self._write_graph_thread.join()
last_step = session.run(self._global_step_tensor)
if self._last_checkpoint_step != last_step:
self._save(session, last_step, asynchronous=False)
for l in self._listeners:
l.end(session, last_step)
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
# Skip saving on step 0
if step == 0:
return
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
start_time = time.time()
for l in self._listeners:
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
self._summary_writer.add_session_log(
SessionLog(
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
step)
for l in self._listeners:
l.after_save(session, step)
end_time = time.time()
logging.info("Checkpoint actual writing time: (%.3f sec)",
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
if not asynchronous:
self._last_checkpoint_step = step
_save_fn()
return
if self._save_thread is not None:
self._save_thread.join(timeout=0.1)
if self._save_thread.is_alive():
logging.info("Saver thread still in progress, skipping checkpoint.")
return
self._last_checkpoint_step = step
self._save_thread = threading.Thread(target=_save_fn)
self._save_thread.start()
def _get_saver(self):
if self._saver is not None:
return self._saver
elif self._scaffold is not None:
return self._scaffold.saver
# Get saver from the SAVERS collection if present.
collection_key = ops.GraphKeys.SAVERS
savers = ops.get_collection(collection_key)
if not savers:
raise RuntimeError(
"No items in collection {}. Please add a saver to the collection "
"or provide a saver or scaffold.".format(collection_key))
elif len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
self._saver = savers[0]
return savers[0]
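A short usage sketch for the hook above in a TF 1.x training loop. The checkpoint directory, save interval, and the trivial stand-in train op are illustrative, and the tensorflow.python.tpu.async_checkpoint import path is assumed from this change's layout:

from tensorflow.python.ops import state_ops
from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util

global_step = training_util.get_or_create_global_step()
train_op = state_ops.assign_add(global_step, 1)  # stand-in for a real train op

hooks = [
    async_checkpoint.AsyncCheckpointSaverHook('/tmp/model_dir', save_steps=1000),
    basic_session_run_hooks.StopAtStepHook(last_step=5000),
]
# checkpoint_dir is left unset so the session does not add its own
# synchronous CheckpointSaverHook alongside the asynchronous one.
with monitored_session.MonitoredTrainingSession(hooks=hooks) as sess:
  while not sess.should_stop():
    sess.run(train_op)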

View File

@@ -0,0 +1,77 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper context for running models with bfloat16."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_contextlib
def _get_custom_getter():
"""Returns a custom getter that this class's methods must be called under.
All methods of this class must be called under a variable scope that was
passed this custom getter. Example:
```python
network = ConvNetBuilder(...)
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
network.conv(...)
# Call more methods of network here
```
Currently, this custom getter only does anything if self.use_tf_layers is
True. In that case, it causes variables to be stored as dtype
self.variable_type, then casted to the requested dtype, instead of directly
storing the variable as the requested dtype.
"""
def inner_custom_getter(getter, *args, **kwargs):
"""Custom getter that forces variables to have type self.variable_type."""
cast_to_bfloat16 = False
requested_dtype = kwargs['dtype']
if requested_dtype == dtypes.bfloat16:
# Only change the variable dtype if doing so does not decrease variable
# precision.
kwargs['dtype'] = dtypes.float32
cast_to_bfloat16 = True
var = getter(*args, **kwargs)
# This if statement is needed to guard the cast, because batch norm
# assigns directly to the return value of this custom getter. The cast
# makes the return value not a variable so it cannot be assigned. Batch
# norm variables are always in fp32 so this if statement is never
# triggered for them.
if cast_to_bfloat16:
var = math_ops.cast(var, dtypes.bfloat16)
return var
return inner_custom_getter
@tf_contextlib.contextmanager
def bfloat16_scope():
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
"""
with variable_scope.variable_scope(
'', custom_getter=_get_custom_getter()) as varscope:
yield varscope
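A minimal sketch of using the scope above in graph mode; the variable name and shape are illustrative, and the tensorflow.python.tpu.bfloat16 import path matches the test change that follows:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import bfloat16

with bfloat16.bfloat16_scope():
  # Stored as a float32 variable, returned here already cast to bfloat16.
  w = variable_scope.get_variable('w', shape=[128], dtype=dtypes.bfloat16)
assert w.dtype == dtypes.bfloat16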

View File

@@ -19,11 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import bfloat16
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import bfloat16
class BFloat16ScopeTest(test.TestCase):

View File

@@ -0,0 +1,191 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of Cloud TPU helper functions for data loading."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops
def _TextLineDataset(filename):
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
return dataset
def _TFRecordDataset(filename):
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
return dataset
_FILETYPE_MAP = {
'tfrecord': _TFRecordDataset,
'textline': _TextLineDataset,
'text': _TextLineDataset,
}
def StreamingFilesDataset(files,
filetype=None,
file_reader_job=None,
worker_job=None,
num_epochs=None,
filename_shuffle_buffer_size=None,
num_parallel_reads=None,
batch_transfer_size=None,
sloppy=None):
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
files local to your GCE VM. In order to train using files stored on your local
VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
helper to generate a dataset to feed your Cloud TPU with files from your GCE
VM.
The resulting dataset may return an OutOfRangeError if there are no files
found as a result of the fileglob expansion.
Note: StreamingFilesDataset assumes that the session is using a
TPUClusterResolver and therefore has a worker and a coordinator job. File
loading will be done on the coordinator job.
Args:
files: A string glob to match files, or a `tf.data.Dataset` generating file
names.
filetype: A string (one of 'tfrecord', or 'textline') or a single-argument
TensorFlow function that when given a filename returns a dataset.
file_reader_job: An optional string that corresponds to the job that should
perform the file reads.
worker_job: An optional string that corresponds to the job that should
process the tensors (i.e. your GPU or TPU worker).
num_epochs: The number of epochs through the training set that should be
generated. By default, it will repeat infinitely.
filename_shuffle_buffer_size: An optional integer whose value controls the
shuffling of the file names. If you would like to read from the files in
the same order, set to 0 or False.
num_parallel_reads: An optional integer controlling the number of files to
read from concurrently. (Set to 1 for no parallelism.)
batch_transfer_size: An optional integer controlling the batching used to
amortize the remote function invocation overhead. Set to a very large
number to increase throughput. Set to a very small number to reduce memory
consumption. Set to False to skip batching.
sloppy: (Optional.) If `False`, read input data while maintaining a
deterministic order. (This may have significant performance impacts.)
Defaults to `True`.
Returns:
A `tf.data.Dataset` with an infinite stream of elements generated by a
parallel interleaving of the set of files matched (or generated) by `files`,
with the element type determined by the dataset specified by `filetype`.
Raises:
ValueError: if any argument is not of the expected type.
"""
if filetype is None:
filetype = 'tfrecord'
if isinstance(filetype, str):
if filetype not in _FILETYPE_MAP:
raise ValueError('Unexpected filetype: %s' % filetype)
reader_fn = _FILETYPE_MAP[filetype]
elif callable(filetype):
reader_fn = filetype
else:
raise ValueError('filetype should be a string or a callable')
file_reader_job = file_reader_job or 'coordinator'
worker_job = worker_job or 'worker'
if filename_shuffle_buffer_size is None:
filename_shuffle_buffer_size = 4096
num_parallel_reads = num_parallel_reads or 8
if batch_transfer_size is None:
batch_transfer_size = 256
if sloppy is None:
sloppy = True
with ops.device('/job:%s' % file_reader_job):
if isinstance(files, str):
source_dataset = dataset_ops.Dataset.list_files(files)
elif isinstance(files, dataset_ops.DatasetV2):
source_dataset = files
else:
raise ValueError('files was not a string or a dataset: %s' % files)
if filename_shuffle_buffer_size:
source_dataset = source_dataset.shuffle(
buffer_size=filename_shuffle_buffer_size)
source_dataset = source_dataset.apply(
interleave_ops.parallel_interleave(
reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
source_dataset = source_dataset.repeat(num_epochs)
if batch_transfer_size:
source_dataset = source_dataset.batch(batch_transfer_size)
source_dataset = source_dataset.prefetch(1)
source_iterator = dataset_ops.make_one_shot_iterator(source_dataset)
source_handle = source_iterator.string_handle()
@function.Defun(dtypes.string)
def LoadingFunc(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, source_dataset.output_types, source_dataset.output_shapes)
return remote_iterator.get_next()
def MapFn(unused_input):
if isinstance(source_dataset.output_types, dtypes.DType):
output_types = [source_dataset.output_types]
elif isinstance(source_dataset.output_types, (list, tuple)):
output_types = source_dataset.output_types
else:
raise ValueError('source dataset has invalid output types')
remote_calls = functional_ops.remote_call(
args=[source_handle],
Tout=output_types,
f=LoadingFunc,
target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
if len(remote_calls) == 1:
return remote_calls[0]
else:
return remote_calls
with ops.device('/job:%s' % worker_job):
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
MapFn, num_parallel_calls=4 if sloppy else None)
output_dataset = output_dataset.prefetch(1)
if batch_transfer_size:
# Undo the batching used during the transfer.
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
return output_dataset
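A hedged usage sketch for the helper above. The file pattern and job names are illustrative and assume the coordinator/worker setup described in the docstring; the tensorflow.python.tpu.datasets import path matches the test change that follows:

from tensorflow.python.tpu import datasets

# Stream TFRecord files that live on the coordinator VM's local disk.
dataset = datasets.StreamingFilesDataset(
    '/mnt/disks/ssd0/train-*.tfrecord',
    filetype='tfrecord',
    file_reader_job='coordinator',
    worker_job='worker')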

View File

@@ -20,7 +20,6 @@ from __future__ import print_function
import os
from tensorflow.contrib.tpu.python.tpu import datasets
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -31,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.tpu import datasets
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

View File

@@ -0,0 +1,313 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.tpu.topology import Topology
SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]]
def _compute_task_and_cores_to_replicas(core_assignment, topology):
"""Computes a nested dict which maps task and logical core to replicas."""
task_and_cores_to_replicas = {}
for replica in xrange(core_assignment.shape[0]):
for logical_core in xrange(core_assignment.shape[1]):
coordinates = core_assignment[replica, logical_core, :]
task_id = topology.task_ordinal_at_coordinates(coordinates)
if task_id not in task_and_cores_to_replicas:
task_and_cores_to_replicas[task_id] = {}
if logical_core not in task_and_cores_to_replicas[task_id]:
task_and_cores_to_replicas[task_id][logical_core] = set()
task_and_cores_to_replicas[task_id][logical_core].add(replica)
task_to_sorted_replica_id = {}
for task, core_to_replicas in task_and_cores_to_replicas.items():
core_to_sorted_replicas = {}
for core, replicas in core_to_replicas.items():
core_to_sorted_replicas[core] = sorted(replicas)
task_to_sorted_replica_id[task] = core_to_sorted_replicas
return task_to_sorted_replica_id
class DeviceAssignment(object):
"""Mapping from logical cores in a computation to the physical TPU topology.
Prefer to use the `device_assignment()` helper to construct a
`DeviceAssignment`; it is easier, though less flexible, than constructing a
`DeviceAssignment` directly.
"""
def __init__(self, topology, core_assignment):
"""Constructs a `DeviceAssignment` object.
Args:
topology: A `Topology` object that describes the physical TPU topology.
core_assignment: A logical to physical core mapping, represented as a
rank 3 numpy array. See the description of the `core_assignment`
property for more details.
Raises:
ValueError: If `topology` is not a `Topology` object.
ValueError: If `core_assignment` is not a rank 3 numpy array.
"""
if not isinstance(topology, Topology):
raise ValueError("topology must be a Topology object, got {}".format(
type(topology)))
core_assignment = np.asarray(core_assignment, dtype=np.int32)
self._topology = topology
if core_assignment.ndim != 3:
raise ValueError("core_assignment must be a rank 3 numpy array, "
"got shape {}".format(core_assignment.shape))
self._num_replicas = core_assignment.shape[0]
self._num_cores_per_replica = core_assignment.shape[1]
if core_assignment.shape[-1] != topology.mesh_rank:
raise ValueError(
"minor dimension of core_assignment must have size equal to topology "
"rank ({}), got shape {}".format(topology.mesh_rank,
core_assignment.shape))
self._core_assignment = core_assignment
self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
self._core_assignment, topology)
@property
def topology(self):
"""A `Topology` that describes the TPU topology."""
return self._topology
@property
def num_cores_per_replica(self):
"""The number of cores per replica."""
return self._num_cores_per_replica
@property
def num_replicas(self):
"""The number of replicas of the computation."""
return self._num_replicas
@property
def core_assignment(self):
"""The logical to physical core mapping.
Returns:
An integer numpy array of rank 3, with shape
`[num_replicas, num_cores_per_replica, topology_rank]`. Maps
(replica, logical core) pairs to physical topology coordinates.
"""
return self._core_assignment
def coordinates(self, replica, logical_core):
"""Returns the physical topology coordinates of a logical core."""
return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
"""Lookup replica ids by task number and logical core.
Args:
task_id: TensorFlow task number.
logical_core: An integer, identifying a logical core.
Returns:
A sorted list of the replicas that are attached to that task and
logical_core.
Raises:
ValueError: If no replica exists in the task which contains the logical
core.
"""
try:
return self._task_and_cores_to_replicas[task_id][logical_core]
except KeyError:
raise ValueError(
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
def tpu_ordinal(self, replica=0, logical_core=0):
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
def host_device(self, replica=0, logical_core=0, job=None):
"""Returns the CPU device attached to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
def tpu_device(self, replica=0, logical_core=0, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
def device_assignment(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1):
"""Computes a device_assignment of a computation across a TPU topology.
Attempts to choose a compact grid of cores for locality.
Returns a `DeviceAssignment` that describes the cores in the topology assigned
to each core of each replica.
`computation_shape` and `computation_stride` values should be powers of 2 for
optimal packing.
Args:
topology: A `Topology` object that describes the TPU cluster topology.
To obtain a TPU topology, evaluate the `Tensor` returned by
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
computation_shape: A rank 1 int32 numpy array with size equal to the
topology rank, describing the shape of the computation's block of cores.
If None, the `computation_shape` is `[1] * topology_rank`.
computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
describing the inter-core spacing of the `computation_shape` cores in the
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
Returns:
A DeviceAssignment object, which describes the mapping between the logical
cores in each computation replica and the physical cores in the TPU
topology.
Raises:
ValueError: If `topology` is not a valid `Topology` object.
ValueError: If `computation_shape` or `computation_stride` are not 1D int32
numpy arrays with shape [3] where all values are positive.
ValueError: If computation's replicas cannot fit into the TPU topology.
"""
# Deserialize the Topology proto, if it is a string.
if isinstance(topology, bytes):
topology = Topology(serialized=topology)
if not isinstance(topology, Topology):
raise ValueError("`topology` is not a Topology object; got {}".format(
type(topology)))
topology_rank = len(topology.mesh_shape)
mesh_shape = topology.mesh_shape
if computation_shape is None:
computation_shape = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_shape = np.asarray(computation_shape, dtype=np.int32)
if computation_stride is None:
computation_stride = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_stride = np.asarray(computation_stride, dtype=np.int32)
if computation_shape.shape != (topology_rank,):
raise ValueError("computation_shape must have shape [{}]; got {}".format(
topology_rank, computation_shape.shape))
if computation_stride.shape != (topology_rank,):
raise ValueError("computation_stride must have shape [{}]; got {}".format(
topology_rank, computation_stride.shape))
if any(computation_shape < 1):
raise ValueError(
"computation_shape must be positive; got computation_shape={}".format(
computation_shape))
if any(computation_stride < 1):
raise ValueError(
"computation_stride must be positive; got computation_stride={}".format(
computation_stride))
# Computes the physical size of one computation instance.
computation_footprint = computation_shape * computation_stride
if any(computation_footprint > mesh_shape):
raise ValueError(
"computation footprint {} does not fit in TPU topology shape {}".format(
computation_footprint, mesh_shape))
# Computes how many copies of the computation footprint fit in the mesh.
block_counts = mesh_shape // computation_footprint
replica_counts = block_counts * computation_stride
max_replicas = np.prod(replica_counts)
if num_replicas > max_replicas:
raise ValueError(
"requested {} replicas but only {} replicas with shape {} and "
"computation_stride {} fit in a TPU mesh of shape {}".format(
num_replicas, max_replicas, computation_shape, computation_stride,
mesh_shape))
def ceil_of_ratio(n, m):
return (n + m - 1) // m
replica_shape = [0] * topology_rank
if num_replicas > 0:
remaining_replicas = num_replicas
remaining_dims = topology_rank
# Choose dimensions as close to an equal cube as possible, in order of
# increasing dimension size. By visiting dimensions in increasing size, we
# assign the most constrained dimension first, so we won't make infeasible
# choices.
#
# As a secondary sort order, visit the dimensions in reverse order. This
# means we try to use both cores on the same chip in preference to two cores
# on different chips.
for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
i = -ni
target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
replica_shape[i] = min(target_size, x)
remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
remaining_dims -= 1
assert remaining_replicas == 1 and remaining_dims == 0
# Assigns an offset to each replica such that no two replicas overlap.
replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
for replica in xrange(num_replicas):
# Chooses a replica number in each axis.
t = replica
pos = []
for dim in replica_shape[::-1]:
pos.append(t % dim)
t //= dim
replica_pos = np.array(pos[::-1], dtype=np.int32)
# Determines where that replica starts in each axis.
outer = replica_pos // computation_stride
inner = replica_pos % computation_stride
replica_offsets[replica, :] = outer * computation_footprint + inner
# Computes a complete logical core -> physical core mapping for each replica.
indices = [
np.arange(0, computation_shape[i] * computation_stride[i],
computation_stride[i]) for i in xrange(topology_rank)
]
indices = np.concatenate(
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
axis=-1)
indices = indices.reshape((-1, topology_rank))
assignment = indices + replica_offsets[:, np.newaxis, :]
return DeviceAssignment(topology, core_assignment=assignment)
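A short sketch of building a `DeviceAssignment` from a live topology. The session address and shapes are illustrative, and the assumption that `initialize_system` lives in `tensorflow.python.tpu.tpu` follows the docstring above rather than anything defined in this file:

from tensorflow.python.client import session
from tensorflow.python.tpu import device_assignment as da_lib
from tensorflow.python.tpu import tpu  # assumed to provide initialize_system

with session.Session('grpc://<tpu-worker>:8470') as sess:  # address illustrative
  serialized_topology = sess.run(tpu.initialize_system())

# Two cores per replica, four replicas packed into the mesh.
da = da_lib.device_assignment(
    serialized_topology, computation_shape=[1, 1, 2], num_replicas=4)
print(da.tpu_device(replica=0, logical_core=0))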

View File

@@ -0,0 +1,132 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""ErrorRendezvous handler for collecting errors from multiple threads."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import sys
import threading
import time
import six
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
_UNINTERESTING_ERRORS = (errors.CancelledError,)
class ErrorRendezvous(object):
"""Resolve errors from multiple threads during TPU execution.
TPU errors can occur on the infeed or outfeed threads as well as the main
training thread.
Depending on which thread "wins" and receives the session error first, we may
end up showing users a confusing and non-actionable error message (session
cancelled) instead of a root cause (e.g. a bad filename).
The rendezvous object provides a location to capture these errors until all
threads terminate. At that point we can choose the most informative error
to report.
"""
def __init__(self, num_sources):
# string -> (message, traceback)
self._errors = {}
self._num_sources = num_sources
self._session_cancel_timer = None
def record_error(self, source, exc_info, session=None):
"""Report an exception from the given source.
If a session is passed, a timer will be registered to close it after a few
seconds. This is necessary to ensure the main training loop does not hang
if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
interesting error from another thread to propagate.
Args:
source: string, source of the error
exc_info: Output from `sys.exc_info` (type, value, traceback)
session: Session to close after delay.
"""
_, value, _ = exc_info
self._errors[source] = exc_info
logging.info('Error recorded from %s: %s', source, value)
if session is not None and self._session_cancel_timer is None:
def _cancel_session():
time.sleep(5)
try:
session.close()
except: # pylint: disable=bare-except
pass
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
self._session_cancel_timer.daemon = True
self._session_cancel_timer.start()
def record_done(self, source):
"""Mark execution source `source` as done.
If an error was originally reported from `source` it is left intact.
Args:
source: `str`, source being recorded
"""
logging.info('%s marked as finished', source)
if source not in self._errors:
self._errors[source] = None
@contextlib.contextmanager
def catch_errors(self, source, session=None):
"""Context manager to report any errors within a block."""
try:
yield
except Exception: # pylint: disable=broad-except
self.record_error(source, sys.exc_info(), session)
def raise_errors(self, timeout_sec=0):
"""Wait for up to `timeout` seconds for all error sources to finish.
Preferentially raise "interesting" errors (errors not in the
_UNINTERESTING_ERRORS) set.
Args:
timeout_sec: Seconds to wait for other error sources.
"""
for _ in range(timeout_sec):
if len(self._errors) == self._num_sources:
break
time.sleep(1)
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
# First check for any interesting errors, then fall back on the session
# cancelled errors etc.
for k, (typ, value, traceback) in kept_errors:
if isinstance(value, _UNINTERESTING_ERRORS):
continue
else:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
for k, (typ, value, traceback) in kept_errors:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
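A hedged sketch of how the rendezvous above is meant to wrap per-thread work. The source names, the session object, and the loop helpers are illustrative, as is the tensorflow.python.tpu.error_handling import path:

import sys
from tensorflow.python.tpu import error_handling

rendezvous = error_handling.ErrorRendezvous(num_sources=2)

def infeed_thread_body(sess):
  # Errors raised inside the block are captured instead of propagating.
  with rendezvous.catch_errors(source='infeed', session=sess):
    run_infeed_loop(sess)  # hypothetical helper
  rendezvous.record_done('infeed')

def train_thread_body(sess):
  try:
    run_training_loop(sess)  # hypothetical helper
    rendezvous.record_done('training')
  except Exception:  # pylint: disable=broad-except
    rendezvous.record_error('training', sys.exc_info(), sess)

# After both threads have finished (or errored), surface the best error.
rendezvous.raise_errors(timeout_sec=30)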

View File

@@ -0,0 +1,435 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
# pylint: disable=protected-access
_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
fc._VocabularyFileCategoricalColumn,
fc._VocabularyListCategoricalColumn,
fc._WeightedCategoricalColumn,
fc_lib.IdentityCategoricalColumn,
fc_lib.VocabularyFileCategoricalColumn,
fc_lib.VocabularyListCategoricalColumn,
fc_lib.WeightedCategoricalColumn)
def embedding_column(categorical_column,
dimension,
combiner='mean',
initializer=None):
"""TPU embedding_column for `tf.feature_column.embedding_column`.
Note that the interface for TPU embedding_column is different from the non-TPU
version. The following args available for the non-TPU version are NOT
supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.
Args:
categorical_column: A categorical_column returned from
categorical_column_with_identity, weighted_categorical_column,
categorical_column_with_vocabulary_list or
categorical_column_with_vocabulary_file.
dimension: An integer specifying dimension of the embedding, must be > 0.
combiner: A string specifying how to reduce if there are multiple entries
in a single row. For more information, see
`tf.feature_column.embedding_column`.
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean `0.0` and standard deviation
`1/sqrt(dimension)`.
Returns:
A _TPUEmbeddingColumn.
Raises:
ValueError: if `dimension` not > 0.
ValueError: if `initializer` is specified but not callable.
"""
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '
' embedding_column must be type %s, got %s.' % (' or '.join([
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
]), type(categorical_column)))
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified. '
'Embedding of column_name: {}'.format(
categorical_column.name))
if initializer is None:
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
def _creator(weight_collections, scope):
embedding_column_layer = fc._EmbeddingColumnLayer(
embedding_shape=embedding_shape,
initializer=initializer,
weight_collections=weight_collections,
trainable=True,
name='embedding_column_layer')
return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
column = _TPUEmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
layer_creator=_creator,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
# For the embedding column, the initializer is hidden inside the creator Fn,
# which is not accessible later. So we attach it to a special field. Also note
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
# initializer differently. See shared_embedding_columns for details.
column._tpu_initializer = initializer
return column
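# A hedged usage sketch for embedding_column (names are illustrative; this
# mirrors the docstring above rather than adding new behavior):
#
#   watched = fc_lib.categorical_column_with_identity('watched', num_buckets=1000)
#   watched_emb = embedding_column(watched, dimension=32, combiner='mean')
#
# The resulting _TPUEmbeddingColumn can be handed to the TPUEstimator embedding
# support (see _tpu_estimator_embedding.py) or used like a core embedding
# column when running on CPU.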
def shared_embedding_columns(categorical_columns,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None):
"""List of dense columns that convert from sparse, categorical input."""
for categorical_column in categorical_columns:
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '
' shared_embedding_columns must be type %s, got %s.' % (' or '.join([
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
]), type(categorical_column)))
columns = fc_lib.shared_embedding_columns(
categorical_columns,
dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
# Use the initializer and shared_embedding_collection_name to create TPU
# version
initializer = columns[0].initializer
shared_embedding_collection_name = columns[0].shared_embedding_collection_name
tpu_columns = []
# Create the state (_SharedEmbeddingColumnLayer) here.
for categorical_column in categorical_columns:
column = _TPUSharedEmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True)
tpu_columns.append(column)
return tpu_columns
class _TPUBaseEmbeddingColumn(object):
"""Base class for TPU Embedding Column."""
def __init__(self, categorical_column):
self._tpu_categorical_column = categorical_column
def get_combiner(self):
"""Returns the embedding combiner."""
raise NotImplementedError('not implemented')
def get_embedding_table_size(self):
"""Returns the embedding table size, tuple of vocab size and dimension."""
raise NotImplementedError('not implemented')
def get_feature_key_name(self):
"""Returns the feature key name in the features dict."""
raise NotImplementedError('not impl')
def get_weight_key_name(self):
"""Return the key name for weights."""
raise NotImplementedError('not impl')
def get_embedding_var_name(self):
"""Returns the embedding variable name.
Feature key name and embedding variable name are usually one-to-one mapping.
But for shared embedding columns, it is many-to-one mapping.
"""
raise NotImplementedError('not impl')
def get_initializer(self):
"""Returns the initializer."""
raise NotImplementedError('not impl')
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
raise NotImplementedError('not impl')
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
"""Core Embedding Column."""
def __new__(cls,
categorical_column,
dimension,
combiner='mean',
layer_creator=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
# are not supported on TPU. They are solely for matching the signature of
# __new__ of parent class fc._EmbeddingColumn.
return fc._EmbeddingColumn.__new__(
cls,
categorical_column,
dimension,
combiner=combiner,
layer_creator=layer_creator,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable)
def __init__(self,
categorical_column,
dimension,
combiner='mean',
layer_creator=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
self._key = None
def get_combiner(self):
return self.combiner
def get_embedding_table_size(self):
"""Returns num_ids and width."""
return (self.categorical_column._num_buckets, self.dimension)
def get_feature_key_name(self):
"""get_feature_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.categorical_column.name
return self.categorical_column.name
def get_weight_key_name(self):
"""get_weight_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.weight_feature_key
return None
def get_embedding_var_name(self):
"""get_embedding_var_name."""
return self.categorical_column.name
def get_initializer(self):
return self._tpu_initializer
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
if isinstance(
self.categorical_column,
(
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
fc_lib.WeightedCategoricalColumn)):
return True
return False
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
def host_computation():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
# TPU mode
# Get the embeddings from the LazyBuilder.
tensor = inputs.get(self.get_feature_key_name())
# Add to collection for _create_tpu_embedding_variables_and_ops
_record_variable_scope_and_name(self.get_embedding_var_name(),
'embedding_weights')
return tensor
class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
fc._SharedEmbeddingColumn):
"""Core Shared Embedding Column."""
def __new__(cls,
categorical_column,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
return fc._SharedEmbeddingColumn.__new__(
cls,
categorical_column,
dimension,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable)
def __init__(self,
categorical_column,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
self._key = None
def get_combiner(self):
return self.combiner
def get_embedding_table_size(self):
"""Returns num_ids and width."""
return (self.categorical_column._num_buckets, self.dimension)
def get_feature_key_name(self):
"""get_feature_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.categorical_column.name
return self.categorical_column.name
def get_weight_key_name(self):
"""get_weight_key_name."""
if self.is_categorical_column_weighted():
return self.categorical_column.weight_feature_key
return None
def get_embedding_var_name(self):
"""get_embedding_var_name."""
return self.shared_embedding_collection_name
def get_initializer(self):
return self.initializer
def is_categorical_column_weighted(self):
"""Check if the categorical column of the embedding column is weighted."""
if isinstance(
self.categorical_column,
(
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
fc_lib.WeightedCategoricalColumn)):
return True
return False
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
def host_computation():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
# TPU mode
# Get the embeddings from the LazyBuilder.
tensor = inputs.get(self.get_feature_key_name())
# Add to collection for _create_tpu_embedding_variables_and_ops
_record_variable_scope_and_name(
self.get_embedding_var_name(),
'embedding_weights',
is_shared_embedding=True)
return tensor
def _record_variable_scope_and_name(embedding_var_name,
embedding_var_name_in_fc,
is_shared_embedding=False):
"""Add embedding variable name and scope to collection."""
g = ops.get_default_graph()
collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
if not collection:
collection.append({})
var_def_dict = collection[0]
captured_scope = variable_scope.get_variable_scope()
captured_scope_name = captured_scope.name
if embedding_var_name in var_def_dict:
if (var_def_dict[embedding_var_name][0] != captured_scope_name
and not is_shared_embedding):
raise ValueError(
'For embedding var name {}, the variable scope name is different, '
'got {}; expected {}'.format(embedding_var_name,
captured_scope_name,
var_def_dict[embedding_var_name][0]))
if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
raise ValueError(
'For embedding var name {}, the embedding name is different, '
'got {}; expected {}'.format(embedding_var_name,
embedding_var_name_in_fc,
var_def_dict[embedding_var_name][1]))
else:
var_def_dict[embedding_var_name] = (captured_scope_name,
embedding_var_name_in_fc)
def _is_running_on_cpu():
"""Returns True if the current context is CPU model."""
return tpu_function.get_tpu_context().number_of_shards is None


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tests for contrib.tpu.python.tpu.feature_column."""
"""Tests for python.tpu.feature_column."""
from __future__ import absolute_import
from __future__ import division
@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc
from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
@ -31,6 +30,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import feature_column as tpu_fc
def _initialized_session():


@ -0,0 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.tpu.ops import tpu_ops
TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name


@ -0,0 +1,418 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
def _create_default_group_assignment():
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
logging.warning(
"cross_replica_sum should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
group_assignment = [list(range(num_shards))]
return group_assignment
def all_to_all(x,
concat_dimension,
split_dimension,
split_count,
group_assignment=None,
name=None):
"""Exchange data across TPU replicas.
Args:
x: The local tensor.
concat_dimension: The dimension number to concatenate.
split_dimension: The dimension number to split.
split_count: The number of splits; this number must equal the sub-group
size (group_assignment.get_shape()[1]).
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
A `Tensor` containing the concatenated data from the different replicas.
"""
if group_assignment is None:
group_assignment = _create_default_group_assignment()
return gen_tpu_ops.all_to_all(
x,
group_assignment,
concat_dimension=concat_dimension,
split_dimension=split_dimension,
split_count=split_count,
name=name)
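# Hypothetical sketch of the helper above: an all-to-all across four replicas
# that splits each replica's batch dimension and concatenates the pieces along
# the feature dimension. The explicit group assignment (one group of four
# replicas) is an illustrative assumption.
def _example_all_to_all(x):
  # x: the per-replica input; dimension 0 must be divisible by split_count.
  return all_to_all(
      x,
      concat_dimension=1,
      split_dimension=0,
      split_count=4,
      group_assignment=[[0, 1, 2, 3]])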
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
# The gradient of an all-to-all is also an all-to-all, but with the
# split_dimension and concat_dimension swapped.
# The gradient with respect to group_assignment is None.
return [
gen_tpu_ops.all_to_all(
grad,
op.inputs[1],
concat_dimension=op.get_attr("split_dimension"),
split_dimension=op.get_attr("concat_dimension"),
split_count=op.get_attr("split_count")), None
]
def cross_replica_sum(x, group_assignment=None, name=None):
"""Sum the input tensor across replicas according to group_assignment.
Args:
x: The local tensor to sum.
group_assignment: Optional 2d int32 lists with shape [num_groups,
num_replicas_per_group]. `group_assignment[i]` represents the replica
ids in the ith subgroup.
name: Optional op name.
Returns:
A `Tensor` which is summed across replicas.
"""
if group_assignment is None:
group_assignment = _create_default_group_assignment()
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
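# Hypothetical sketch: a cross-replica mean built on top of cross_replica_sum.
# It assumes the call happens inside a TPU shard context (or with an explicit
# group assignment) and that `num_replicas` matches the group size.
def _example_cross_replica_mean(grad, num_replicas):
  # Sum across all replicas, then scale locally on each replica.
  return cross_replica_sum(grad) / num_replicas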
def collective_permute(x, source_target_pairs, name=None):
"""Permute the input tensor across replicas given source_target_pairs.
For each source_target_pair <a, b>, we send replica a's input to replica b.
Each replica id must only appear once in the source column. Also it must
only appear once in the target column.
For the replica id not in the target column, this op returns a zero tensor
with the same shape and dtype of the input x.
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
`[0, A, B, C]`.
Args:
x: The local tensor to be permuted.
source_target_pairs: 2d int lists with shape [num_pairs, 2].
source_target_pairs[i][0] represents the source replica id and
source_target_pairs[i][1] represents the target replica id.
name: Optional op name.
Returns:
A `Tensor` which is permuted.
"""
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
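# Hypothetical sketch mirroring the docstring example above: with four replicas
# and pairs [[0, 1], [1, 2], [2, 3]], replica 0 receives zeros and each other
# replica i receives replica (i - 1)'s value.
def _example_collective_permute(x):
  return collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])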
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
# The gradient of a collective permute operation is also a collective
# permute, but with source/target pairs reversed. The gradient with respect
# to input argument `source_target_pairs` is `None`.
source_target_pairs = op.inputs[1][:, ::-1]
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
# The gradient with respect to group_assignment is None.
return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
# types are supported.
_SUPPORTED_INFEED_DTYPES = set([
dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
dtypes.complex64, dtypes.uint32
])
@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
"""Saves the gradient of embedding activations ops in a graph collection."""
g = ops.get_default_graph()
table_id = activations_op.get_attr("table_id")
lookup_id = activations_op.get_attr("lookup_id")
table_gradients = g.get_collection_ref(
"tpu_embedding_gradients_table_%d" % table_id)
if not table_gradients:
raise RuntimeError(
"Gradients for TPUEmbedding have been generated in non-training mode."
"This is not expected. Consider putting your Optimizer.minimize code "
"behind the training mode condition check. For Estimator, you can "
"do \n\n"
" if mode == tf.estimator.ModeKeys.TRAIN:\n"
" train_op = opt.minimize(loss)\n"
"\n")
table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
return [
# RegisterGradient requires that value be returned for all inputs. Since
# the first argument (tpu_gradient_variable_{table_name}) has shape [1],
# we will return zeros(shape=[1]). The actual gradient w.r.t. the
# embedding activations (grad_wrt_activations) has the same shape as the
# activations returned by embedding_activations.
array_ops.zeros(arg.shape, dtype=dtypes.float32)
for arg in activations_op.inputs
]
def infeed_dequeue(dtype, shape, name=None):
"""A placeholder op for a value that will be fed into the computation.
Args:
dtype: A `tf.DType`. The type of elements in the tensor.
shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `dtype`.
A tensor that will be provided using the infeed mechanism.
Raises:
TypeError: If 'dtype` is not a supported infeed type.
"""
if dtype not in _SUPPORTED_INFEED_DTYPES:
raise TypeError(
"{} is not a supported TPU infeed type. Supported types are: "
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
"""A placeholder op for values fed into the TPU simultaneously as a tuple.
Args:
dtypes: A list of `tf.DType`s that has length `>= 1`.
The element types of each element in `outputs`.
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
The shapes of each tensor in `outputs`.
name: A name for the operation (optional).
Returns:
A list of `Tensor` objects of type `dtypes`.
A list of tensors that will be provided using the infeed mechanism.
Raises:
TypeError: If a type in 'dtypes` is not a supported infeed type.
"""
for dtype in dtypes:
if dtype not in _SUPPORTED_INFEED_DTYPES:
raise TypeError(
"{} is not a supported TPU infeed type. Supported types are: "
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
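# Hypothetical sketch: dequeueing a (features, labels) tuple on the TPU side of
# the infeed. The dtypes and shapes below are illustrative assumptions and must
# match what the host-side enqueue op produced.
def _example_infeed_dequeue_tuple():
  features, labels = infeed_dequeue_tuple(
      dtypes=[dtypes.float32, dtypes.int32],
      shapes=[[128, 224, 224, 3], [128]])
  return features, labels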
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
config,
learning_rates=None,
name=None):
"""A placeholder op for feeding per-sample gradients to the embedding layer.
Args:
inputs: A TensorList of gradients with which to update embedding tables.
This argument has the same length and shapes as the return value of
RecvTPUEmbeddingActivations, but contains gradients of the model's
loss with respect to the embedding activations. The embedding tables
are updated from these gradients via the optimizers specified in the
TPU embedding configuration given to tpu.initialize_system.
config: Serialized TPUEmbeddingConfiguration proto.
learning_rates: A TensorList of float32 scalars, one for each dynamic
learning rate tag: see the comments in
//third_party/tensorflow/core/protobuf/tpu/
optimization_parameters.proto.
Multiple tables can share the same dynamic learning rate tag as
specified in the configuration. If the learning rates for all tables
are constant, this list should be empty.
name: A name for the operation (optional).
Returns:
A SendTPUEmbeddingGradients operation.
"""
if learning_rates is None:
learning_rates = []
return gen_tpu_ops.send_tpu_embedding_gradients(
inputs=inputs, learning_rates=learning_rates, config=config, name=name)
send_tpu_embedding_gradients.__doc__ = (
gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
device_ordinal,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
batch: A list of 1D tensors, one for each embedding table, containing the
indices into the tables.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingIntegerBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
batch=batch,
device_ordinal=device_ordinal,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_integer_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
embedding_indices,
aggregation_weights,
device_ordinal,
combiners=None,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 1 Tensors specifying the training example
and feature to which the corresponding embedding_indices and
aggregation_weights values belong. sample_indices[i] must equal b * nf +
f, where nf is the number of features from the corresponding table, f is
in [0, nf), and b is in [0, batch size).
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables.
aggregation_weights: A list of rank 1 Tensors containing per sample --
i.e. per (training example, feature) -- aggregation weights.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
is to use 'sum' for all tables (optional).
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingSparseBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
sample_indices=sample_indices,
embedding_indices=embedding_indices,
aggregation_weights=aggregation_weights,
device_ordinal=device_ordinal,
combiners=combiners,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_sparse_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
embedding_indices,
aggregation_weights,
table_ids,
device_ordinal,
combiners=None,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_indices: A list of rank 1 Tensors specifying the training example
to which the corresponding embedding_indices and aggregation_weights
values belong. It corresponds to sp_ids.indices[:,0] in
embedding_lookup_sparse().
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
aggregation_weights: A list of rank 1 Tensors containing per training
example aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse().
table_ids: A list of integers specifying the identifier of the embedding
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
lookup the corresponding input. The ith input is looked up using
table_ids[i]. The size of the table_ids list must be equal to that of
sample_indices, embedding_indices and aggregation_weights.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
is to use 'sum' for all tables (optional).
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingSparseTensorBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
sample_indices=sample_indices,
embedding_indices=embedding_indices,
aggregation_weights=aggregation_weights,
table_ids=table_ids,
device_ordinal=device_ordinal,
combiners=combiners,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
pass


@ -0,0 +1,20 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations to select TPU core to run."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


@ -0,0 +1,27 @@
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//tensorflow:__subpackages__",
],
)
py_library(
name = "profiler",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":tpu_profiler_analysis_pb2_grpc",
"//tensorflow/core/profiler:profiler_analysis_proto_py",
"//tensorflow/core/profiler:protos_all_py",
"//tensorflow/python:util",
],
)
py_library(
name = "tpu_profiler_analysis_pb2_grpc",
srcs = ["tpu_profiler_analysis_pb2_grpc.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
)


@ -0,0 +1,31 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Classes for TPU trace events."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.core.profiler.trace_events_pb2 import *
from tensorflow.core.profiler.profiler_analysis_pb2 import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent']
remove_undocumented(__name__, _allowed_symbols)


@ -0,0 +1,438 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Operations for handling session logging and shutdown notifications."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from google.protobuf import text_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
_WATCHDOG = None
class CoordinatorShutdownException(Exception):
"""Raised when the coordinator needs to shutdown."""
pass
def _clone_session(session, graph=None):
return session_lib.Session(
target=session.sess_str,
config=session._config, # pylint: disable=protected-access
graph=graph if graph else session.graph)
def _make_heartbeat_op(session, device, request_ph):
"""Return a heartbeat op or None if heartbeats are not supported by device."""
try:
# Test if we can connect in an isolated graph + session
with ops.Graph().as_default():
with _clone_session(session) as temp_session:
with ops.device(device):
heartbeat_op = tpu_ops.worker_heartbeat('')
options = config_pb2.RunOptions(timeout_in_ms=5000)
temp_session.run(heartbeat_op, options=options)
except errors.InvalidArgumentError as _:
logging.warning('Error running heartbeat on %s', device)
return None
except errors.DeadlineExceededError as _:
logging.warning('Timeout connecting to %s when testing heartbeat', device)
return None
# If we successfully connected and pinged the worker, go ahead and construct
# the operation.
with ops.device(device):
return tpu_ops.worker_heartbeat(request_ph)
class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers."""
def __init__(self, session, devices, heartbeat_ops, request_placeholder):
"""Construct a new WorkerHeartbeatManager.
(Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
Args:
session: `tf.Session`, session to use for heartbeat operations.
devices: `list[string]` Set of devices to connect to.
heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
the WorkerHeartbeatRequest protocol buffer.
"""
self._session = session
self._devices = devices
self._ops = heartbeat_ops
self._request_placeholder = request_placeholder
@staticmethod
def from_devices(session, devices):
"""Construct a heartbeat manager for the given devices."""
if not devices:
logging.error('Trying to create heartbeat manager with no devices?')
logging.info('Creating heartbeat manager for %s', devices)
request_placeholder = array_ops.placeholder(
name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = []
kept_devices = []
for device in devices:
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
if heartbeat_op is not None:
kept_devices.append(device)
heartbeat_ops.append(heartbeat_op)
else:
logging.warning('Heartbeat support not available for %s', device)
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
request_placeholder)
def num_workers(self):
return len(self._devices)
def configure(self, message):
"""Configure heartbeat manager for all devices.
Args:
message: `event_pb2.WorkerHeartbeatRequest`
Returns: `None`
"""
logging.info('Configuring worker heartbeat: %s',
text_format.MessageToString(message))
self._session.run(self._ops,
{self._request_placeholder: message.SerializeToString()})
def ping(self, request=None, timeout_in_ms=5000):
"""Ping all workers, returning the parsed status results."""
if request is None:
request = event_pb2.WorkerHeartbeatRequest()
options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
results = self._session.run(
self._ops,
feed_dict={self._request_placeholder: request.SerializeToString()},
options=options)
parsed_results = [
event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
for res_pb in results
]
logging.debug('Ping results: %s', parsed_results)
return parsed_results
def lame_workers(self):
"""Ping all workers, returning manager containing lame workers (or None)."""
ping_results = self.ping()
lame_workers = []
for ping_response, device, op in zip(ping_results, self._devices,
self._ops):
if ping_response.health_status != event_pb2.OK:
lame_workers.append((device, op))
if not lame_workers:
return None
bad_devices, bad_ops = zip(*lame_workers)
return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
self._request_placeholder)
def __repr__(self):
return 'HeartbeatManager(%s)' % ','.join(self._devices)
def shutdown(self, timeout_ms=10000):
"""Shutdown all workers after `shutdown_timeout_secs`."""
logging.info('Shutting down %s.', self)
req = event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
self.configure(req)
# Wait for workers to shutdown. This isn't strictly required
# but it avoids triggering multiple checkpoints with the same lame worker.
logging.info('Waiting %dms for worker shutdown.', timeout_ms)
time.sleep(timeout_ms / 1000)
def all_worker_devices(session):
"""Return a list of devices for each worker in the system."""
devices = session.list_devices()
return [
device.name
for device in devices
if ':CPU:' in device.name and 'coordinator' not in device.name
]
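# Hypothetical sketch: building a heartbeat manager from an existing session,
# pinging every worker, and shutting down only the ones that report an
# unhealthy status. `sess` is assumed to be an already-connected tf.Session.
def _example_check_workers(sess):
  manager = WorkerHeartbeatManager.from_devices(sess, all_worker_devices(sess))
  lame = manager.lame_workers()  # None when every worker reports OK.
  if lame is not None:
    lame.shutdown(timeout_ms=10000)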
class WatchdogManager(threading.Thread):
"""Configures worker watchdog timer and handles periodic pings.
Usage:
# Ping workers every minute, shutting down workers if they haven't received
# a ping after 1 hour.
watchdog_manager = WatchdogManager(
ping_interval=60, shutdown_timeout=3600
)
# Use as a context manager, resetting watchdog on context exit:
with watchdog_manager:
session.run(...)
# Or setup globally; watchdog will remain active until program exit.
watchdog_manager.configure_and_run()
"""
def __init__(self,
session,
devices=None,
ping_interval=60,
shutdown_timeout=3600):
"""Initialize a watchdog manager.
Args:
session: Session connected to worker devices. A cloned session and graph
will be created for managing worker pings.
devices: Set of devices to monitor. If none, all workers will be
monitored.
ping_interval: Time, in seconds, between watchdog pings.
shutdown_timeout: Time, in seconds, before watchdog timeout.
"""
threading.Thread.__init__(self)
self.ping_interval = ping_interval
self.shutdown_timeout = shutdown_timeout
self.daemon = True
self._config = session._config # pylint: disable=protected-access
self._target = session.sess_str
self._running = False
self._devices = devices
self._graph = None
self._session = None
self._worker_manager = None
def _reset_manager(self):
"""Reset the graph, session and worker manager."""
self._graph = ops.Graph()
self._session = session_lib.Session(
target=self._target,
graph=self._graph,
config=self._config,
)
if self._devices is None:
self._devices = all_worker_devices(self._session)
with self._graph.as_default():
self._worker_manager = WorkerHeartbeatManager.from_devices(
self._session, self._devices)
self._worker_manager.configure(
event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(
timeout_ms=self.shutdown_timeout * 1000,),
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
def configure_and_run(self):
logging.info(
'Enabling watchdog timer with %d second timeout '
'and %d second ping interval.', self.shutdown_timeout,
self.ping_interval)
self._reset_manager()
self._running = True
self.start()
def stop(self):
logging.info('Stopping worker watchdog.')
self._worker_manager.configure(
event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
shutdown_mode=event_pb2.NOT_CONFIGURED))
self._running = False
self.join()
def __enter__(self):
self.configure_and_run()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def run(self):
# Don't fetch logs or adjust timing: just ping the watchdog.
#
# If we hit an exception, reset our session as it is likely broken.
while self._running:
try:
self._worker_manager.ping(request=None)
time.sleep(self.ping_interval)
except errors.OpError as e:
# Catch any TF errors that occur so we don't stop sending heartbeats
logging.debug('Caught error while sending heartbeat: %s', e)
self._reset_manager()
def start_worker_watchdog(session,
devices=None,
ping_interval=60,
shutdown_timeout=3600):
"""Start global worker watchdog to shutdown workers on coordinator exit."""
global _WATCHDOG
if _WATCHDOG is None:
# Ensure we can send a few pings before we timeout!
ping_interval = min(shutdown_timeout / 10., ping_interval)
_WATCHDOG = WatchdogManager(session, devices, ping_interval,
shutdown_timeout)
_WATCHDOG.configure_and_run()
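# Hypothetical sketch: enabling the global watchdog right after the training
# session is created, so workers shut themselves down if the coordinator stops
# pinging them for an hour. `sess` is assumed to be a connected tf.Session.
def _example_enable_watchdog(sess):
  start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)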
class GracefulShutdownHook(session_run_hook.SessionRunHook):
"""Session hook that watches for shutdown events.
If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a
SystemShutdown exception is raised to terminate the main session. If `saver`
is None the `SAVERS` collection will be read to find a saver.
`on_shutdown_hooks` is an optional list of functions that should be called
after checkpointing. The function is called with (`run_context`,
`all_workers`, `lame_workers`).
If `heartbeat_group` is not specified, it will default to all CPU workers
in the system.
"""
def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
self._saver = saver
self._checkpoint_prefix = checkpoint_prefix
self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
# Worker heartbeats are managed independently of the main training graph.
self._graph = ops.Graph()
self._workers = None
self._session = None
self._heartbeat_supported = False
def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
# N.B. We have to pull the global step here to avoid it being unavailable
# at checkpoint time; the graph has been frozen at that point.
if training_util.get_global_step() is None and self.saver() is not None:
raise ValueError(
'Saver defined but no global step. Run `get_or_create_global_step()`'
' in your model definition to allow checkpointing.')
with self._graph.as_default():
logging.info('Installing graceful shutdown hook.')
self._session = _clone_session(training_session, self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported:
self._workers.configure(
event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
else:
logging.warn(
'No workers support heartbeats. Failure handling will be disabled.')
def saver(self):
if self._saver:
return self._saver
savers = ops.get_collection(ops.GraphKeys.SAVERS)
if not savers:
return None
if not isinstance(savers, list):
return savers
if len(savers) > 1:
logging.error(
'Multiple savers in the SAVERS collection. On-demand checkpointing '
'will be disabled. Pass an explicit `saver` to the constructor to '
'override this behavior.')
return None
return savers[0]
def after_run(self, run_context, run_values):
del run_values
if not self._heartbeat_supported:
return
lame_workers = self._workers.lame_workers()
if lame_workers:
logging.info('ShutdownHook: lame workers found: %s', lame_workers)
if self.saver():
logging.info('ShutdownHook: saving checkpoint to %s',
self._checkpoint_prefix)
self.saver().save(
run_context.session,
self._checkpoint_prefix,
global_step=training_util.get_global_step(),
write_state=True,
)
else:
logging.info('ShutdownHook: no Saver defined.')
for fn in self._on_shutdown_hooks:
fn(run_context, self._workers, lame_workers)
class RestartComputation(object):
"""Restart the entire computation.
This hook shuts down all workers and returns control to the top-level by
throwing a CoordinatorShutdownException.
"""
def __init__(self, timeout_ms=10000):
self.timeout_ms = timeout_ms
def __call__(self, run_context, all_workers, lame_workers):
del run_context, lame_workers
all_workers.shutdown(timeout_ms=self.timeout_ms)
logging.info('Terminating coordinator.')
raise CoordinatorShutdownException()
class ShutdownLameWorkers(object):
"""Shutdown lamed workers.
Processing will continue normally (typically by waiting for the down
workers to be restarted).
"""
def __init__(self, timeout_ms=10000):
self.timeout_in_ms = timeout_ms
def __call__(self, run_context, all_workers, lame_workers):
lame_workers.shutdown(timeout_ms=self.timeout_in_ms)

File diff suppressed because it is too large.


@ -0,0 +1,220 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Defines the `Topology` class, that describes a TPU fabric topology."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf.tpu import topology_pb2
def _tpu_device_name(job, task, device):
"""Returns the device name for the TPU `device` on `task` of `job`."""
if job is None:
return "/task:%d/device:TPU:%d" % (task, device)
else:
return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
def _tpu_host_device_name(job, task):
"""Returns the device name for the CPU device on `task` of `job`."""
if job is None:
return "/task:%d/device:CPU:0" % task
else:
return "/job:%s/task:%d/device:CPU:0" % (job, task)
class Topology(object):
"""Describes a set of TPU devices.
Represents both the shape of the physical mesh, and the mapping between
TensorFlow TPU devices to physical mesh coordinates.
"""
def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
"""Builds a Topology object.
If `serialized` is not `None`, the topology is parsed from `serialized` and
the other arguments are ignored. Otherwise, the topology is computed from
`mesh_shape` and `device_coordinates`.
Args:
serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
serialized proto is parsed to discover the topology.
mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
the shape of the TPU topology, in number of cores. Ignored if
`serialized` is not `None`.
device_coordinates: A rank 3 numpy array that describes the mapping from
TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
if `serialized` is not `None`.
Raises:
ValueError: If `serialized` does not describe a well-formed topology.
ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
of 3 positive integers.
ValueError: If `serialized` is `None` and `device_coordinates` is not a
rank 3 numpy int32 array that describes a valid coordinate mapping.
"""
self._serialized = serialized
if serialized:
self._parse_topology(serialized)
else:
self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
self._device_coordinates = np.asarray(device_coordinates, np.int32)
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
raise ValueError("`mesh_shape` must be a sequence of 3 positive "
"entries; got {}".format(self._mesh_shape))
if (len(self._device_coordinates.shape) != 3 or
self._device_coordinates.shape[2] != len(self._mesh_shape)):
raise ValueError("`device_coordinates` must be a rank 3 int32 array "
"with minor dimension equal to the mesh shape rank")
self._topology_tasks, self._topology_devices = self._invert_topology()
def _parse_topology(self, serialized):
"""Parses a serialized `TopologyProto` into `self`."""
proto = topology_pb2.TopologyProto()
proto.ParseFromString(serialized)
self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
"entries; got {}".format(self._mesh_shape))
if proto.num_tasks < 0:
raise ValueError("`num_tasks` must be >= 0; got {}".format(
proto.num_tasks))
if proto.num_tpu_devices_per_task < 0:
raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
proto.num_tpu_devices_per_task))
expected_coordinates_size = (
proto.num_tasks * proto.num_tpu_devices_per_task * len(
proto.mesh_shape))
if len(proto.device_coordinates) != expected_coordinates_size:
raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
"num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
"got shape {}".format(proto.num_tasks,
proto.num_tpu_devices_per_task,
proto.mesh_shape,
len(proto.device_coordinates)))
coords = np.array(proto.device_coordinates, dtype=np.int32)
if any(coords < 0):
raise ValueError("`device_coordinates` must be >= 0")
coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
len(proto.mesh_shape)))
self._device_coordinates = coords
def _invert_topology(self):
"""Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps."""
tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
for task in xrange(self.device_coordinates.shape[0]):
for device in xrange(self.device_coordinates.shape[1]):
x, y, z = self.device_coordinates[task, device, :]
tasks[x, y, z] = task
devices[x, y, z] = device
return tasks, devices
@property
def mesh_shape(self):
"""A rank 1 int32 array describing the shape of the TPU topology."""
return self._mesh_shape
@property
def mesh_rank(self):
"""Returns the number of dimensions in the mesh."""
return len(self._mesh_shape)
@property
def device_coordinates(self):
"""Describes the mapping from TPU devices to topology coordinates.
Returns:
A rank 3 int32 array with shape `[tasks, devices, axis]`.
`tasks` is the number of tasks in the TPU cluster, `devices` is the number
of TPU devices per task, and `axis` is the number of axes in the TPU
cluster topology. Each entry gives the `axis`-th coordinate in the
topology of a task/device pair. TPU topologies are 3-dimensional, with
dimensions `(x, y, core number)`.
"""
return self._device_coordinates
def task_ordinal_at_coordinates(self, device_coordinates):
"""Returns the TensorFlow task number attached to `device_coordinates`.
Args:
device_coordinates: An integer sequence describing a device's physical
coordinates in the TPU fabric.
Returns:
Returns the TensorFlow task number that contains the TPU device with those
physical coordinates.
"""
return self._topology_tasks[tuple(device_coordinates)]
def tpu_device_ordinal_at_coordinates(self, device_coordinates):
"""Returns the TensorFlow device number at `device_coordinates`.
Args:
device_coordinates: An integer sequence describing a device's physical
coordinates in the TPU fabric.
Returns:
Returns the TensorFlow device number within the task that is attached to
the device with those physical coordinates.
"""
return self._topology_devices[tuple(device_coordinates)]
def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
"""Returns the CPU device attached to a logical core."""
return _tpu_host_device_name(
job, self._topology_tasks[tuple(device_coordinates)])
def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
return _tpu_device_name(job,
self._topology_tasks[tuple(device_coordinates)],
self._topology_devices[tuple(device_coordinates)])
@property
def num_tasks(self):
"""Returns the number of TensorFlow tasks in the TPU slice."""
return self._device_coordinates.shape[0]
@property
def num_tpus_per_task(self):
"""Returns the number of TPU devices per task in the TPU slice."""
return self._device_coordinates.shape[1]
def serialized(self):
"""Returns the serialized form of the topology."""
if self._serialized is None:
proto = topology_pb2.TopologyProto()
proto.mesh_shape[:] = list(self._mesh_shape)
proto.num_tasks = self._device_coordinates.shape[0]
proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
self._serialized = proto.SerializeToString()
return self._serialized
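# Hypothetical sketch: a 2x2x2 topology for a single task with eight TPU
# devices. The coordinate layout below is a made-up illustration, not a real
# slice description.
def _example_topology():
  coords = np.array(
      [[[x, y, z]
        for x in range(2) for y in range(2) for z in range(2)]],
      dtype=np.int32)  # shape [1 task, 8 devices, 3 axes]
  topo = Topology(mesh_shape=[2, 2, 2], device_coordinates=coords)
  # Maps physical coordinates back to a TensorFlow device name, e.g.
  # "/job:tpu_worker/task:0/device:TPU:1" for coordinate (0, 0, 1).
  return topo.tpu_device_name_at_coordinates([0, 0, 1], job="tpu_worker")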


@ -19,9 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import topology
from tensorflow.python.platform import test
from tensorflow.python.tpu import topology
class TopologyTest(test.TestCase):

tensorflow/python/tpu/tpu.py (new file, 1576 lines)

File diff suppressed because it is too large.


@ -0,0 +1,276 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""A RunConfig subclass with TPU support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import util as util_lib
# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
# pylint: enable=protected-access
class InputPipelineConfig(object):
r"""Please see the definition of these values in TPUConfig."""
PER_SHARD_V1 = 1
PER_HOST_V1 = 2
PER_HOST_V2 = 3
BROADCAST = 4
class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
'input_partition_dims',
])):
r"""TPU related configuration required by `TPUEstimator`.
Args:
iterations_per_loop: This is the number of train steps running in the TPU
system before returning to the CPU host for each `Session.run`. This means
the global step is increased `iterations_per_loop` times in one
`Session.run`. It is recommended to set this to the number of global steps
between checkpoints.
num_shards: (Deprecated, ignored by TPUEstimator).
The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer which describes the number of TPU cores per model replica. This
is required for model parallelism, which partitions the model across
multiple cores. Currently num_cores_per_replica must be 1, 2, 4, 8, or 16.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
`input_fn` is invoked once on each host. With the per-core input pipeline
configuration, it is invoked once for each core.
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
the batch size for each shard is `train_batch_size` // #hosts in the
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
invoked once on host 0 and the tensors are broadcast to all other
replicas. The batch size equals `train_batch_size`. With the per-core
input pipeline configuration, the shard batch size is also
`train_batch_size` // #cores.
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
string.
initial_infeed_sleep_secs: The number of seconds the infeed thread should
wait before enqueueing the first batch. This helps avoid timeouts for
models that require a long compilation time.
input_partition_dims: A nested list to describe the partition dims
for all the tensors from input_fn(). The structure of
input_partition_dims must match the structure of `features` and
`labels` from input_fn(). The total number of partitions must match
`num_cores_per_replica`. For example, if input_fn() returns two tensors:
images with shape [N, H, W, C] and labels [N].
input_partition_dims = [[1, 2, 2, 1], None] will split the images into 4
pieces and feed them to 4 TPU cores. The labels tensor is broadcast directly
to all the TPU cores since its partition dims are `None`.
Current limitations: This feature is only supported with the PER_HOST_V2
input mode.
Raises:
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None,
input_partition_dims=None):
# Check iterations_per_loop.
util_lib.check_positive_integer(iterations_per_loop,
'TPUConfig iterations_per_loop')
# Check num_shards.
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
if input_partition_dims is not None:
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
raise ValueError(
'input_partition_dims must be a list/tuple with one or two'
' elements.')
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
raise ValueError(
'input_partition_dims is only supported in PER_HOST_V2 mode.')
if num_cores_per_replica is None:
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
# Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
# Map legacy values (True, False) to numeric values.
if per_host_input_for_training is False:
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
elif per_host_input_for_training is True:
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
'TPUConfig initial_infeed_sleep_secs')
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
input_partition_dims=input_partition_dims)
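# Hypothetical sketch: a model-parallel TPUConfig that matches the docstring
# example above, partitioning the image input across the four cores of each
# replica. The numbers are illustrative assumptions.
def _example_tpu_config():
  return TPUConfig(
      iterations_per_loop=100,
      num_cores_per_replica=4,
      per_host_input_for_training=InputPipelineConfig.PER_HOST_V2,
      input_partition_dims=[[1, 2, 2, 1], None])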
class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
def __init__(self,
tpu_config=None,
evaluation_master=None,
master=None,
cluster=None,
**kwargs):
"""Constructs a RunConfig.
Args:
tpu_config: the TPUConfig that specifies TPU-specific configuration.
evaluation_master: a string. The address of the master to use for eval.
Defaults to master if not set.
master: a string. The address of the master to use for training.
cluster: a ClusterResolver
**kwargs: keyword config parameters.
Raises:
ValueError: if cluster is not None and the provided session_config has a
cluster_def already.
"""
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TPUConfig()
self._cluster = cluster
# If user sets master and/or evaluation_master explicitly, including empty
# string '', take it. Otherwise, take the values set by parent class.
if master is not None:
if cluster is not None:
raise ValueError('Both master and cluster are set.')
self._master = master
else:
if cluster:
self._master = cluster.master()
if evaluation_master is not None:
self._evaluation_master = evaluation_master
elif (not self._evaluation_master and
self.task_type != run_config_lib.TaskType.EVALUATOR):
# If the task type is EVALUATOR, it means some cluster manager sets the
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
#
# Otherwise, it means user executes the code without external cluster
# manager. For that, we optimize the user experience by setting
# evaluation_master to master, unless user overwrites it.
self._evaluation_master = self._master
# Set the ClusterSpec to use
if cluster:
self._cluster_spec = cluster.cluster_spec()
# Merge the cluster_def into the ConfigProto.
if self._session_config is None: # pylint: disable=access-member-before-definition
self._session_config = config_pb2.ConfigProto(
allow_soft_placement=True, isolate_session_state=True)
if self._session_config.HasField('cluster_def'):
raise ValueError(
'You cannot provide a ClusterResolver and '
'session_config.cluster_def.')
if self._cluster_spec:
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
def _maybe_overwrite_session_config_for_distributed_training(self):
    # Overrides the parent class's session_config overwrite for between-graph
    # training. TPU runs with in-graph replication, which should not have a
    # device filter. Doing nothing ("pass") effectively disables it.
pass
@property
def evaluation_master(self):
return self._evaluation_master
@property
def master(self):
return self._master
@property
def tpu_config(self):
return self._tpu_config
@property
def cluster(self):
return self._cluster
def replace(self, **kwargs):
if 'tpu_config' not in kwargs:
return super(RunConfig, self).replace(**kwargs)
tpu_config = kwargs.pop('tpu_config')
new_instance = super(RunConfig, self).replace(**kwargs)
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
return new_instance
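# Illustrative sketch (not part of the original file): wiring a
# ClusterResolver and a TPUConfig into this RunConfig. `resolver` and the
# model_dir path are assumed/hypothetical.
#
#   run_config = RunConfig(
#       cluster=resolver,
#       model_dir='/tmp/tpu_model',
#       tpu_config=TPUConfig(iterations_per_loop=100, num_shards=8))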
def _get_tpu_job_name_from_tf_config():
"""Extracts the TPU job name from TF_CONFIG env variable."""
  # TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
  # cluster spec propagation.
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
if tpu_job_name:
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
return tpu_job_name
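# Illustrative sketch (assumption: TF_CONFIG carries the job name under the
# nested keys referenced above). Setting the variable and reading it back:
#
#   os.environ[_TF_CONFIG_ENV] = json.dumps(
#       {_SERVICE_KEY: {_TPU_WORKER_JOB_NAME: 'my_tpu_worker'}})
#   assert _get_tpu_job_name_from_tf_config() == 'my_tpu_worker'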

View File

@ -20,10 +20,10 @@ from __future__ import print_function
import json
from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_config as tpu_config_lib
def _set_tf_config_env_variable(tf_config):

View File

@ -0,0 +1,763 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from contextlib import contextmanager
import copy
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
_NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
8: [2, 2, 2],
16: [4, 2, 2],
}
class TPUContext(object):
"""A context that holds the current configuration of the TPU computation."""
def __init__(self,
internal_ctx,
input_device=None,
invocation_index=None,
call_from_input_fn=True):
self._internal_ctx = internal_ctx
self._input_device = input_device
self._invocation_index = invocation_index
self._call_from_input_fn = call_from_input_fn
def current_input_fn_deployment(self):
"""The configuration of the current input_fn invocation.
The configuration depends on `TPUConfig.per_host_input_for_training`. See
`TPUConfig` for details.
    This context is only available through the `params` dict passed to input_fn.
Returns:
A tuple of
1. Device spec string: String, is the current CPU host where the
input_fn is invoked.
2. Current invocation index: Int, 0-based index of the input_fn
invocation. See next item for details.
3. Total invocation count: Int, the total number of times to invoke the
input_fn on all CPU hosts. Each invocation will be passed with a new
`TPUContext` instance with current invocation index set properly.
4. Total number of replicas consumed by current_invocation: Int, the
number of replicas fed by the data returned by current input_fn. For
example, for per_core input pipeline deployment
and non-model-parallelism, total invocation count is equal to
the number of cores in the system and num replicas consumed by
current invocation is 1. For per-host v2 input pipeline deployment,
total invocation count is equal to the number of hosts in the system
and num replicas consumed by current invocation is equal to number of
cores per host.
Raises:
      RuntimeError: If this method is called from model_fn instead of input_fn.
"""
if not self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' model_fn.')
if self._internal_ctx.is_input_sharded_per_core():
total_invocation_count = (self._internal_ctx.num_hosts
* self._internal_ctx.num_of_replicas_per_host)
replicas_consumed = 1
elif self._internal_ctx.is_input_broadcast_with_iterators():
total_invocation_count = 1
replicas_consumed = self._internal_ctx.num_replicas
else:
total_invocation_count = self._internal_ctx.num_hosts
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
return (self._input_device, self._invocation_index,
total_invocation_count, replicas_consumed)
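  # Illustrative sketch (not part of the original file): how an input_fn might
  # consume this deployment information. It assumes the TPUContext is exposed
  # to input_fn as params['context'].
  #
  #   def input_fn(params):
  #     ctx = params['context']
  #     _, invocation_index, total_invocations, _ = (
  #         ctx.current_input_fn_deployment())
  #     dataset = build_full_dataset()  # hypothetical helper
  #     return dataset.shard(total_invocations, invocation_index)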
@property
def num_replicas(self):
"""The total number of replicas.
For non-model-parallelism, num_replicas should be the total num of TPU
cores in the system.
Returns:
The number of replicas.
"""
return self._internal_ctx.num_replicas
@property
def num_hosts(self):
"""The number of hosts for the TPU system."""
return self._internal_ctx.num_hosts
@property
def current_host(self):
"""The current host index for the TPU system."""
return self._invocation_index
@property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
raise ValueError(
'num_of_replicas_per_host is not supported for model_parallelism')
return self._internal_ctx.num_of_replicas_per_host
@property
def device_assignment(self):
"""Returns device_assignment object."""
if self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' input_fn.')
return self._internal_ctx.device_assignment
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replicate for non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
    # Note that for non-model-parallelism, the mapping could be a random
    # permutation. The order should not matter in most cases, as long as the
    # model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function.
The place function takes host_id as the input and returns the TF device
for the correspoding host.
"""
def _placement_function(host_id):
"""Return the host device given host_id."""
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
return _placement_function
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
This immutable object holds TPUEstimator config, train/eval batch size, and
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
provides utility functions, based on the current state, to determine other
information commonly required by TPU computation, such as TPU device names,
TPU hosts, shard batch size, etc.
if eval_on_tpu is False, then execution of eval on TPU is disabled.
if eval_on_tpu is True, but use_tpu is False, a warning is issued,
and TPU execution is disabled for all modes.
N.B. As `mode` is not immutable state in Estimator, but essential to
distinguish between TPU training and evaluation, a common usage for
_InternalTPUContext with `mode` is as follows:
```
with _ctx.with_mode(mode) as ctx:
if ctx.is_running_on_cpu():
...
```
"""
def __init__(self,
config,
train_batch_size,
eval_batch_size,
predict_batch_size,
use_tpu,
eval_on_tpu=True,
embedding_config_spec=None):
self._config = config
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._predict_batch_size = predict_batch_size
self._use_tpu = use_tpu
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
if not use_tpu and eval_on_tpu:
logging.warning('eval_on_tpu ignored because use_tpu is False.')
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
num_cores_per_replica = config.tpu_config.num_cores_per_replica
if self._model_parallelism_enabled:
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
num_cores_per_replica]
else:
self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
self._embedding_config_spec = embedding_config_spec
self._lazy_embedding_config_dict = {} # key by master address
def _assert_mode(self):
if self._mode is None:
raise RuntimeError(
'`mode` needs to be set via contextmanager `with_mode`.')
return self._mode
@contextmanager
def with_mode(self, mode):
    # NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries,
    # such as _lazy_tpu_system_metadata_dict, between the new copy and the
    # original one. Note that all lazy states stored in properties _lazy_foo
    # are effectively immutable, as they should stay the same for the process
    # lifetime.
new_ctx = copy.copy(self)
new_ctx._mode = mode # pylint: disable=protected-access
yield new_ctx
@property
def mode(self):
return self._assert_mode()
def _get_master_address(self):
mode = self._assert_mode()
config = self._config
master = (
config.master
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
return master
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
cluster_def = None
if (self._config.session_config and
self._config.session_config.cluster_def.job):
cluster_def = self._config.session_config.cluster_def
# pylint: disable=protected-access
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
query_topology=self.model_parallelism_enabled))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_device_assignment(self):
"""Gets the (maybe cached) TPU device assignment."""
master = self._get_master_address()
device_assignment = self._lazy_device_assignment_dict.get(master)
if device_assignment is not None:
return device_assignment
tpu_system_metadata = self._get_tpu_system_metadata()
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
logging.info('num_cores_per_replica: %s',
str(self._config.tpu_config.num_cores_per_replica))
logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
logging.info('device_assignment.core_assignment: %s',
str(device_assignment.core_assignment))
self._lazy_device_assignment_dict[master] = device_assignment
return device_assignment
@property
def embedding_config(self):
"""Returns the embedding config based on current mode."""
master = self._get_master_address()
if master in self._lazy_embedding_config_dict:
embedding_config = self._lazy_embedding_config_dict[master]
else:
embedding_config = None
if self._use_tpu and self._embedding_config_spec:
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
self._embedding_config_spec, self._train_batch_size,
self._eval_batch_size, self.num_hosts, self.num_cores, master)
if not embedding_config.has_embedding_tables():
embedding_config = None
self._lazy_embedding_config_dict[master] = embedding_config
if embedding_config is not None:
mode = self._assert_mode()
# Dynamically attach tpu_embedding based on mode. With
# this, we could keep embedding_config immutable but call site always
# accesses the unified API '.tpu_embedding'.
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
return embedding_config
@property
def model_parallelism_enabled(self):
return self._model_parallelism_enabled
@property
def input_partition_dims(self):
return self._config.tpu_config.input_partition_dims
@property
def device_assignment(self):
return (self._get_device_assignment()
if self._model_parallelism_enabled else None)
@property
def num_of_cores_per_host(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_of_cores_per_host
@property
def num_cores(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_cores
@property
def num_of_replicas_per_host(self):
"""Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
return self.num_of_cores_per_host
@property
def num_replicas(self):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the total num of '
'TPU cores in the system. num_cores_per_replica: {}, num cores '
'in the system: {}'.format(num_cores_per_replica,
num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
else:
return num_cores_in_system
@property
def num_hosts(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_hosts
@property
def config(self):
return self._config
def is_input_sharded_per_core(self):
"""Return true if input_fn is invoked per-core (other than per-host)."""
mode = self._assert_mode()
return (mode == model_fn_lib.ModeKeys.TRAIN and
(self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1))
def is_input_per_host_with_iterators(self):
"""Return true if input_fn should be run in the per-host v2 config."""
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_HOST_V2)
def is_input_broadcast_with_iterators(self):
"""Return true if input_fn should be run in the full_replicae config."""
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.BROADCAST)
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
This API also validates user provided configuration, such as batch size,
according the lazy initialized TPU system metadata.
Args:
      is_export_mode: Indicates whether the current mode is for exporting the
        model when mode == PREDICT. Only with this bool can we tell whether the
        user is calling Estimator.predict or Estimator.export_savedmodel, which
        run on TPU and CPU respectively. The parent class Estimator does not
        distinguish these two.
Returns:
bool, whether current input_fn or model_fn should be running on CPU.
Raises:
      ValueError: if any configuration is invalid.
"""
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
if not is_running_on_cpu:
self._validate_tpu_configuration()
return is_running_on_cpu
def _is_running_on_cpu(self, is_export_mode):
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
mode = self._assert_mode()
if not self._use_tpu:
return True
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
return True
if is_export_mode:
return True
return False
@property
def global_batch_size(self):
mode = self._assert_mode()
if mode == model_fn_lib.ModeKeys.TRAIN:
return self._train_batch_size
elif mode == model_fn_lib.ModeKeys.EVAL:
return self._eval_batch_size
elif mode == model_fn_lib.ModeKeys.PREDICT:
return self._predict_batch_size
else:
return None
@property
def batch_size_for_input_fn(self):
"""Returns the shard batch size for `input_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU
if self.is_input_sharded_per_core() or (
self.is_input_per_host_with_iterators()):
return global_batch_size // self.num_replicas
else:
return global_batch_size // self.num_hosts
@property
def batch_size_for_model_fn(self):
"""Returns the shard batch size for `model_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
    # On TPU, the batch is always sharded per replica.
return global_batch_size // self.num_replicas
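  # Worked example (illustrative): with a global train batch size of 1024,
  # 8 replicas and 2 hosts under PER_HOST_V1, batch_size_for_input_fn is
  # 1024 // 2 == 512 per host, while batch_size_for_model_fn is
  # 1024 // 8 == 128 per replica.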
@property
def master_job(self):
"""Returns the job name to use to place TPU computations on.
Returns:
A string containing the job name, or None if no job should be specified.
Raises:
ValueError: If the user needs to specify a tpu_job_name, because we are
unable to infer the job name automatically, or if the user-specified job
names are inappropriate.
"""
run_config = self._config
# If the user specifies the tpu_job_name, use that.
if run_config.tpu_config.tpu_job_name:
return run_config.tpu_config.tpu_job_name
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
mode = self._assert_mode()
master = (
run_config.evaluation_master
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
if master in _LOCAL_MASTERS:
return None
if (not run_config.session_config or
not run_config.session_config.cluster_def.job):
return _DEFAULT_JOB_NAME
cluster_def = run_config.session_config.cluster_def
job_names = set([job.name for job in cluster_def.job])
if _DEFAULT_JOB_NAME in job_names:
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
raise ValueError('Currently, tpu_worker is not an allowed job name.')
if len(job_names) == 1:
return cluster_def.job[0].name
if len(job_names) == 2:
if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
return job_names.pop()
# TODO(b/67716447): Include more sophisticated heuristics.
raise ValueError(
'Could not infer TPU job name. Please specify a tpu_job_name as part '
'of your TPUConfig.')
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
master = self.master_job
    def _placement_function(_sentinel=None, replica_id=None, host_id=None):  # pylint: disable=invalid-name
      """Return the host device given replica_id or host_id."""
      assert _sentinel is None
if replica_id is not None and host_id is not None:
raise RuntimeError(
'replica_id and host_id can have only one non-None value.')
if master is None:
return '/replica:0/task:0/device:CPU:0'
else:
if replica_id is not None:
if self.model_parallelism_enabled:
return self.device_assignment.host_device(
replica=replica_id, job=master)
else:
            host_id = replica_id // self.num_of_cores_per_host
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
return _placement_function
@property
def tpu_device_placement_function(self):
"""Returns a TPU device placement Fn."""
master = self.master_job
job_device = '' if master is None else ('/job:%s' % master)
def _placement_function(i):
if self.model_parallelism_enabled:
return self.device_assignment.tpu_device(replica=i, job=master)
else:
num_of_cores_per_host = self.num_of_cores_per_host
        host_id = i // num_of_cores_per_host
ordinal_id = i % num_of_cores_per_host
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
return _placement_function
def tpu_ordinal_function(self, host_id):
"""Returns the TPU ordinal fn."""
def _tpu_ordinal_function(shard_index_in_host):
"""Return the TPU ordinal associated with a shard.
Required because the enqueue ops are placed on CPU.
Args:
shard_index_in_host: the shard index
Returns:
The ordinal of the TPU device the shard's infeed should be placed on.
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
replica = self.device_assignment.lookup_replicas(host_id,
0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
return _tpu_ordinal_function
def _validate_tpu_configuration(self):
"""Validates the configuration based on the TPU system metadata."""
mode = self._assert_mode()
if self._lazy_validation_dict.get(mode):
return
# All following information is obtained from TPU system metadata.
num_cores = self.num_cores
num_replicas = self.num_replicas
num_hosts = self.num_hosts
if not num_cores:
tpu_system_metadata = self._get_tpu_system_metadata()
raise RuntimeError(
'Cannot find any TPU cores in the system. Please double check '
'Tensorflow master address and TPU worker(s). Available devices '
'are {}.'.format(tpu_system_metadata.devices))
if self._config.tpu_config.num_shards:
user_provided_num_replicas = self._config.tpu_config.num_shards
if user_provided_num_replicas != num_replicas:
message = (
'TPUConfig.num_shards is not set correctly. According to TPU '
'system metadata for Tensorflow master ({}): num_replicas should '
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
raise ValueError(message)
if self._config.tpu_config.num_cores_per_replica:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
if num_cores_per_replica > num_cores_per_host:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the '
'num_cores_per_host. num_cores_per_replica: {}, '
'num_cores_per_host: {}'.format(num_cores_per_replica,
num_cores_per_host))
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
elif mode == model_fn_lib.ModeKeys.EVAL:
if self._eval_batch_size is None:
        raise ValueError(
            'eval_batch_size in TPUEstimator constructor cannot be `None` '
            'if .evaluate is running on TPU.')
if (self._eval_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.evaluate should be running on a single TPU'
            ' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
if (self._predict_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.predict should be running on a single TPU worker; '
            'got {} hosts.'.format(num_hosts))
# Record the state "validated" into lazy dictionary.
self._lazy_validation_dict[mode] = True
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replicate for non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
master = self.master_job
if self.model_parallelism_enabled:
return (self.device_assignment.host_device(
replica=replica_id, job=master),
self.device_assignment.tpu_ordinal(replica=replica_id))
job_device = '' if master is None else ('/job:%s' % master)
num_of_replicas_per_host = self.num_of_replicas_per_host
    host_id = replica_id // num_of_replicas_per_host
ordinal_id = replica_id % num_of_replicas_per_host
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
return (host_device, ordinal_id)
class _OneCoreTPUContext(_InternalTPUContext):
"""Special _InternalTPUContext for one core usage."""
def __init__(self, config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu):
super(_OneCoreTPUContext, self).__init__(
config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
tpu_system_metadata = (
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
num_cores=1,
num_hosts=1,
num_of_cores_per_host=1,
topology=None,
devices=[]))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_tpu_context(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec):
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
config.tpu_config.num_cores_per_replica is None):
if embedding_config_spec is not None:
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
'when embedding_config_spec is not None.')
    logging.warning(
        'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
        'Please fix it as soon as possible (e.g., by leaving num_shards as None).')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec)

File diff suppressed because it is too large

View File

@ -0,0 +1,153 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops
def get_gradients_through_compute_gradients(optimizer, loss, activations):
"""Compute gradients to send to TPU embedding.
Args:
optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
Used to call compute_gradients().
loss: a Tensor to call optimizer.compute_gradients() on.
activations: an OrderedDict mapping feature_name to Tensors of activations.
Returns:
An OrderedDict mapping from feature name Strings to Tensors of gradients of
the loss wrt the activations of the features.
"""
activation_list = activations.values()
grads_and_vars = optimizer.compute_gradients(loss, activation_list)
grads = [grad for grad, _ in grads_and_vars]
feature_to_gradient_dict = collections.OrderedDict(
zip(activations.keys(), grads))
return feature_to_gradient_dict
def create_dummy_table_variables(tpu_embedding):
"""Create dummy embedding table variables.
  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt activations can be captured
  and later sent to TPU embedding.
Args:
tpu_embedding: TPUEmbedding, dummy table variables will be created for use
with tpu_embedding.
Returns:
A tuple of dummy variables and their initializer.
Raises:
RuntimeError: if collection to store gradients already exists and is not
empty.
"""
dummy_table_variables = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
dummy_table_variables[table] = (
# Explicitly specifying collections prevents this variable from
# being added to the GLOBAL_VARIABLES collection, so that Saver()
# ignores it.
        # But the TensorFlow optimizer creates slot variables for these dummy
        # variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # which will be in the GLOBAL_VARIABLES collection.
variable_scope.get_variable(
'tpu_embedding_dummy_table_variable_{}'.format(table),
dtype=dtypes.float32,
shape=[1],
use_resource=True,
trainable=True,
collections=['tpu_embedding_dummy_table_variables']))
g = ops.get_default_graph()
table_gradients = g.get_collection_ref(
'tpu_embedding_gradients_table_{}'.format(table_id))
if table_gradients:
raise RuntimeError(
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
table_gradients.extend(
[None] * len(tpu_embedding.table_to_features_dict[table]))
return (dummy_table_variables,
variables.variables_initializer(
dummy_table_variables.values(),
name='tpu_embedding_dummy_table_variables_init'))
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
dummy_table_variables):
"""Have activations depend on dummy table variables for gradient intercept.
Args:
tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
tpu_embedding.
activations: An OrderedDict of feature name String to activation tensors.
dummy_table_variables: An OrderedDict of table name String to dummy table
variables.
Returns:
An OrderedDict of feature name String to activation tensors, which can be
used just as the activations input.
"""
new_activations = collections.OrderedDict()
for feature in activations:
table = tpu_embedding.feature_to_table_dict[feature]
new_activations[feature] = tpu_ops.tpu_embedding_activations(
dummy_table_variables[table],
activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict.keys()).index(table),
lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
return new_activations
def get_gradients_through_dummy_table_variables(tpu_embedding):
"""Get gradients wrt the activations of each feature.
Args:
tpu_embedding: TPUEmbedding, create dummy table variable to be used with
tpu_embedding.
Returns:
An OrderedDict mapping feature name to gradient.
Raises:
ValueError: if some gradients are not defined.
"""
g = ops.get_default_graph()
feature_to_gradient_dict = collections.OrderedDict()
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
table_gradients = g.get_collection(
'tpu_embedding_gradients_table_{}'.format(table_id))
if any(gradient is None for gradient in table_gradients):
raise ValueError(
'Table {} with id {} has undefined gradients: this is probably '
'because the model asked TPUEmbedding to compute activations that '
'were not used.'.format(table, table_id))
for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
table_gradients):
feature_to_gradient_dict[feature] = gradient
return feature_to_gradient_dict
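# Illustrative sketch (not part of the original file): how the helpers above
# are typically combined. `tpu_embedding`, `raw_activations` and `loss_fn` are
# assumed to be created elsewhere.
#
#   dummy_tables, dummy_init = create_dummy_table_variables(tpu_embedding)
#   hooked = hook_dummy_table_variables_to_activations(
#       tpu_embedding, raw_activations, dummy_tables)
#   loss = loss_fn(hooked)
#   # Once the per-table gradient collections have been populated:
#   grads = get_gradients_through_dummy_table_variables(tpu_embedding)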

File diff suppressed because it is too large

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_estimator
def make_input_fn(num_samples):

View File

@ -0,0 +1,919 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Helper library for handling infeed between hosts and TPUs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_sharding
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
def partition_or_replicate_on_host(tensor, dims):
"""Partitions or replicates the input tensor.
The ops inside this function are placed on the host side.
Args:
    tensor: The input tensor which will be partitioned or replicated.
    dims: A list of integers describing how to partition the input tensor.
  Returns:
    An iterator of `Tensor`s or a list of partitioned tensors.
"""
if dims is None:
return itertools.repeat(tensor)
dims = np.array(dims)
output = [tensor]
shape_list = np.array(tensor.shape.as_list())
quotients, remainders = np.divmod(shape_list, dims)
for axis, (quotient, remainder, dim, original_size) in enumerate(
zip(quotients, remainders, dims, shape_list)):
if dim <= 1:
continue
if remainder > 0:
# For each dimension, when it cannot be evenly partitioned, XLA assumes
# tensors are partitioned in a greedy manner by using
# ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
# are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
# [[(3, 4), (3, 4), (2, 4), (2, 2)],
# [(2, 4), (2, 4), (2, 4), (2, 2)]]
ceil_ratio = quotient + 1
num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
if len(num_or_size_splits) < dim:
num_or_size_splits += [0] * (dim - len(num_or_size_splits))
new_output = []
for x in output:
new_output.append(
array_ops.split(
x, num_or_size_splits=num_or_size_splits, axis=axis))
output = new_output
else:
output = [array_ops.split(x, dim, axis=axis) for x in output]
output = nest.flatten(output)
return output
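# Illustrative sketch (not part of the original file): partitioning a host-side
# batch of shape [8, 224, 224, 3] into two pieces along the height dimension.
#
#   images = array_ops.zeros([8, 224, 224, 3])
#   parts = partition_or_replicate_on_host(images, dims=[1, 2, 1, 1])
#   # len(parts) == 2; each part has shape [8, 112, 224, 3].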
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
"""Tags appropriate XLA sharding attribute to the dequeued tensor.
Args:
tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.
Returns:
The same tensor with the xla_sharding attribute.
"""
if dims is None:
return xla_sharding.replicate(tensor)
elif np.prod(dims) == 1:
return xla_sharding.assign_device(tensor, 0)
else:
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
"""Tags appropriate XLA sharding attribute to the dequeued tensors.
Args:
dequeues: A list of dequeued tensors on TPU.
    dims: A list of integers describing how the tensor is partitioned.
Returns:
The same dequeues with appropriate xla_sharding attribute.
"""
nest.assert_shallow_structure(dequeues, dims)
return nest.map_structure_up_to(
dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
class InfeedQueue(object):
"""A helper object to build a device infeed queue.
The InfeedQueue builds the host-side and device-side Ops to enqueue and
dequeue elements, respectively, and ensures that their types and
shapes match.
"""
def __init__(self,
number_of_tuple_elements=None,
tuple_types=None,
tuple_shapes=None,
shard_dimensions=None,
name=None):
"""Creates a new InfeedQueue with the given configuration.
The configuration need not be fully specified at creation since it
can be modified subsequently by methods that set the values
explicitly or infer them from the shapes of inputs.
Args:
number_of_tuple_elements: the number of Tensors fed atomically through the
queue, must be present unless it can be inferred from other arguments.
tuple_types: if not None, a list of types of the elements of the queue.
tuple_shapes: if not None, a list of shapes of the elements of the queue.
shard_dimensions: if not None, a list of dimensions on which the
elements of the queue should be sharded during automatic
parallelization.
name: the name of the queue.
Raises:
ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
shard_dimensions are all None; or the length of tuple_types,
tuple_shapes, or shard_dimensions is not equal to
number_of_tuple_elements; or any element of shard_dimensions
can't be converted to a Dimension.
TypeError: if any element of tuple_types or tuple_shapes can't
be converted to a dtype or TensorShape, respectively.
"""
self._frozen = False
self._generated_enqueue_ops = False
self._generated_dequeue_op = False
self._name = "InfeedQueue" if name is None else name
if number_of_tuple_elements is None:
if tuple_types is not None:
number_of_tuple_elements = len(tuple_types)
elif tuple_shapes is not None:
number_of_tuple_elements = len(tuple_shapes)
elif shard_dimensions is not None:
number_of_tuple_elements = len(shard_dimensions)
else:
raise ValueError(
"number of tuple elements cannot be inferred from InfeedQueue "
"constructor")
if number_of_tuple_elements <= 0:
raise ValueError("number_of_tuple_elements %d must be > 0" %
number_of_tuple_elements)
# Make an empty sharding policy for each tuple element.
self._sharding_policies = [
tpu_sharding.ShardingPolicy()
for _ in xrange(number_of_tuple_elements)
]
if tuple_types is not None:
self.set_tuple_types(tuple_types)
else:
self._tuple_types = None
if tuple_shapes is not None:
self.set_tuple_shapes(tuple_shapes)
else:
self._tuple_shapes = None
if shard_dimensions is not None:
self.set_shard_dimensions(shard_dimensions)
self._validate()
def _validate(self):
"""Checks that the configuration is self-consistent.
Raises:
ValueError: if the shapes and sharding policies don't match.
"""
if self.tuple_shapes is not None:
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
# Raise an error if the policy is incompatible with the shape.
_ = policy.get_sharded_shape(shape)
@property
def number_of_tuple_elements(self):
"""Returns the number of InfeedQueue tuple elements."""
return len(self._sharding_policies)
@property
def tuple_types(self):
"""Returns the types of the InfeedQueue tuple elements."""
return self._tuple_types
def set_tuple_types(self, tuple_types):
"""Sets the type of each element of the queue.
tuple_types must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a dtype.
Args:
tuple_types: the types of each queue element.
Raises:
ValueError: if tuple_types is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_types cannot be converted to a
dtype.
"""
if len(tuple_types) != self.number_of_tuple_elements:
raise ValueError("tuple_types is %s, but must be a list of length %d" %
(str(tuple_types), self.number_of_tuple_elements))
if self._frozen:
for (frozen, updated) in zip(self._tuple_types, tuple_types):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible type. Frozen types are %s, updated types are %s" % (
str(self._tuple_types), str(tuple_types)))
else:
try:
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
except (TypeError) as e:
raise TypeError(
"tuple_types is %s, but must be a list of elements each "
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
@property
def tuple_shapes(self):
"""Returns the shapes of the InfeedQueue tuple elements."""
return self._tuple_shapes
def set_tuple_shapes(self, tuple_shapes):
"""Sets the shape of each element of the queue.
tuple_shapes must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a TensorShape.
Args:
tuple_shapes: the shapes of each queue element.
Raises:
ValueError: if tuple_shapes is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_shapes cannot be converted to
a TensorShape.
"""
if len(tuple_shapes) != self.number_of_tuple_elements:
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
(str(tuple_shapes), self.number_of_tuple_elements))
try:
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
except (ValueError, TypeError) as e:
raise TypeError(
"tuple_shapes is %s, but must be a list of elements each "
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
str(e)))
if self._frozen:
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
% (str(self._tuple_shapes), str(tuple_shapes)))
else:
self._tuple_shapes = tuple_shapes
self._validate()
@property
def sharding_policies(self):
"""Returns the sharding policies of the InfeedQueue tuple elements."""
return self._sharding_policies
@property
def shard_dimensions(self):
"""Gets the shard dimension of each tuple element.
Returns:
A list of length number_of_tuple_elements, where each list entry
is the shard dimension of that tuple element or None if the
shard dimension has not been set.
"""
# The number of shards is always the same for all the policies.
return [policy.shard_dimension for policy in self._sharding_policies]
def set_shard_dimensions(self, shard_dimensions):
"""Sets the shard_dimension of each element of the queue.
shard_dimensions must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a Dimension compatible with self.tuple_shapes.
Args:
shard_dimensions: the dimensions of each queue element.
Raises:
ValueError: if shard_dimensions is not of length
self.number_of_tuple_elements; or an element of
shard_dimensions cannot be converted to a Dimension; or an
element of shard_dimensions is a Dimension that is out of
range for the corresponding tuple element shape.
"""
if len(shard_dimensions) != self.number_of_tuple_elements:
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
% (str(shard_dimensions),
self.number_of_tuple_elements))
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
policy.set_shard_dimension(dimension)
self._validate()
@property
def number_of_shards(self):
"""Gets the number of shards to use for the InfeedQueue.
Returns:
Number of shards or None if the number of shards has not been set.
"""
# The number of shards is always the same for all the policies.
return self._sharding_policies[0].number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards to use for the InfeedQueue.
Args:
number_of_shards: number of ways to shard the InfeedQueue.
Raises:
ValueError: if number_of_shards is not > 0; or the policies have
been frozen and number_of_shards was already set to something
else.
"""
for policy in self._sharding_policies:
policy.set_number_of_shards(number_of_shards)
self._validate()
def set_configuration_from_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of Tensors whose types and shapes are used
to set the queue configuration.
Args:
input_tensors: list of Tensors of the same types and shapes as
the desired queue Tuple.
Raises:
ValueError: if input_tensors is not a list of length
self.number_of_tuple_elements
"""
if len(input_tensors) != self.number_of_tuple_elements:
raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
% (str(input_tensors), self.number_of_tuple_elements))
self.set_tuple_shapes([t.shape for t in input_tensors])
self.set_tuple_types([t.dtype for t in input_tensors])
def set_configuration_from_sharded_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of lists of Tensors whose types and shapes are used
to set the queue configuration. The length of the outer list is the number
of shards required, and each inner list is the tuple of Tensors to use to
determine the types and shapes of the corresponding shard. This method
depends on the shard dimension, and calling it freezes the shard policy.
Args:
input_tensors: list of lists of Tensors. The outer list length corresponds
to the desired number of shards, and each inner list is the size
and shape of the desired configuration of the corresponding shard.
Raises:
ValueError: if any inner list is not a list of length
self.number_of_tuple_elements; or the inner lists do not combine to
form a consistent unsharded shape.
TypeError: if the types of the Tensors in the inner lists do not match.
"""
if not self._frozen:
# Unset the tuple shapes in case the configuration becomes
# transiently inconsistent.
self._tuple_shapes = None
number_of_shards = len(input_tensors)
self.set_number_of_shards(number_of_shards)
for t in input_tensors:
if len(t) != self.number_of_tuple_elements:
raise ValueError(
"input_tensors is %s but must be a list of lists, where each inner"
" list has length number_of_tuple_elements=%d" % (
str(input_tensors), self.number_of_tuple_elements))
# Transpose the inputs to make a list of shard shapes for each tuple
# element.
sharded_shapes = [[t[i].shape for t in input_tensors]
for i in xrange(self.number_of_tuple_elements)]
# For each tuple, get the unsharded shape using that tuple's policy.
unsharded_shapes = [
policy.get_unsharded_shape(s)
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
]
self.set_tuple_shapes(unsharded_shapes)
for i in xrange(1, self.number_of_shards):
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
if t1.dtype != t2.dtype:
raise TypeError(
"types of the tuple elements of input_tensors %s are not "
"consistent" % str(input_tensors))
self.set_tuple_types([t.dtype for t in input_tensors[0]])
def freeze(self):
"""Freezes the InfeedQueue so it can no longer be modified.
The configuration is implicitly frozen before any host-side or
device-side Ops are generated. The configuration cannot be frozen
until the types and shapes of the tuple elements have been set.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set.
"""
self._frozen = True
if self._tuple_types is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple types.")
if self._tuple_shapes is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for shape in self._tuple_shapes:
if shape.dims is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for policy in self._sharding_policies:
policy.freeze()
self._validate()
def generate_dequeue_op(self, tpu_device=0):
"""Generates the device-side Op to dequeue a tuple from the queue.
Implicitly freezes the queue configuration if it is not already
frozen, which will raise errors if the shapes and types have not
been fully specified.
Args:
tpu_device: The TPU device ordinal where the infeed instruction should be
placed. If None, no explicit placement will be performed, and it is up
to the user to call this API from within a proper TPU device scope.
The XLA code will fail if the TPU dequeue instruction is not bound to
any device.
Returns:
A list of Outputs corresponding to a shard of infeed dequeued
into XLA, suitable for use within a replicated block.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set; or if a dequeue op has already been generated.
"""
self.freeze()
if self._generated_dequeue_op:
raise ValueError("Can't generate two dequeue Ops from the same queue")
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
if tpu_device is not None:
with ops.device(tpu.core(tpu_device)):
return tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
else:
return tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
def _generate_enqueue_op(self,
inputs,
name_prefix,
index,
device=None,
tpu_ordinal=-1):
"""Generate a host-side Op to enqueue a tuple to the queue.
If device is None the inputs are all required to have the same
device specification, and the enqueue Op is colocated with
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
Args:
inputs: a list of Tensors with the types and shapes of the tuple elements.
name_prefix: the base name for the Op.
index: the shard index, used to uniquify the Op name.
device: device to place the Op on, or None if it should be
colocated with the inputs.
tpu_ordinal: ordinal of the TPU device on the host to use for
infeed if device is a CPU device. Should be set to -1 if device
is a TPU device.
Returns:
An Op corresponding to a shard of infeed enqueued at the host,
suitable for use within a replicated block.
Raises:
ValueError: if device is None and inputs do not all have the
same device specification.
"""
full_name = "%s/%d" % (name_prefix, index)
shapes = [t.shape for t in inputs]
if device is None:
devices = [t.device for t in inputs]
for i in xrange(1, self.number_of_tuple_elements):
if devices[0] != devices[i]:
raise ValueError(
"input devices for shard %d are %s, but should all be the same" %
(index, str(devices)))
with ops.colocate_with(inputs[0]):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
else:
with ops.device(device):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
def generate_enqueue_ops(self,
sharded_inputs,
tpu_ordinal_function=None,
placement_function=None):
"""Generates the host-side Ops to enqueue the shards of a tuple.
sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
for shard i.
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of sharded_inputs, an error
will be raised.
Args:
sharded_inputs: a list of lists of Tensors. The length of the outer list
determines the number of shards. Each inner list indicates the types
and shapes of the tuples in the corresponding shard.
tpu_ordinal_function: if not None, a function that takes the
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. tpu_ordinal_function must be
set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue op should be
        placed.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the shapes of the elements of sharded_inputs
don't form a consistent unsharded tuple; or if the elements of a tuple
have different device constraints.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the types of the elements of sharded_inputs
don't form a consistent unsharded tuple.
"""
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
if tpu_ordinal_function is None:
tpu_ordinal_function = lambda index: -1
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(
shard,
name_prefix,
index,
tpu_ordinal=tpu_ordinal_function(index),
device=placement_function(index) if placement_function else None)
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
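  # Illustrative sketch (not part of the original file): feeding a two-shard
  # queue from host-side tensors and dequeuing inside the replicated
  # computation. `features_0` and `features_1` are assumed tensors of
  # identical dtype and shape.
  #
  #   queue = InfeedQueue(number_of_tuple_elements=1)
  #   enqueue_ops = queue.generate_enqueue_ops(
  #       [[features_0], [features_1]],
  #       tpu_ordinal_function=lambda index: index)
  #   # Inside the TPU computation:
  #   (features,) = queue.generate_dequeue_op()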
# TODO(misard) Generalize this to the case of systems that don't
# have 8 devices per host, and figure out what to do with
# model-parallelism.
def _default_placement_function(self, index):
return "/task:%d/device:CPU:0" % (index / 8)
def _default_ordinal_function(self, index):
return index % 8
# TODO(b/36470756) remove this from tutorials once we have a better story
# for automatic placement of input pipelines.
def split_inputs_and_generate_enqueue_ops(self,
inputs,
device_assignment=None,
placement_function=None,
tpu_ordinal_function=None):
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
Generates the host-side Ops to enqueue a tuple.
This method performs poorly because it takes an entire input on a single
host, splits it, and distributes it to all of the cores. It is present only
to simplify tutorial examples.
inputs is a list of Tensors to use to feed the queue. Each input is split
into self.number_of_shards shards. Returns an Op for each shard to enqueue
the shard. The Op for shard i is placed on device placement_function(i).
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of inputs, an error
will be raised.
Args:
inputs: a list of Tensors which indicates the types and shapes of the
queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
placement_function: if not None, a function that takes the shard
index as input and returns a device string indicating which
device the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
tpu_ordinal_function: if not None, a function that takes the
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of inputs are not compatible with the frozen
configuration.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of inputs are not compatible with the frozen
configuration.
"""
if device_assignment is None:
if placement_function is None:
placement_function = self._default_placement_function
if tpu_ordinal_function is None:
tpu_ordinal_function = self._default_ordinal_function
else:
def _placement_function_from_map(index):
return device_assignment.host_device(replica=index)
def _ordinal_function_from_map(index):
return device_assignment.tpu_ordinal(replica=index)
if placement_function is None:
placement_function = _placement_function_from_map
if tpu_ordinal_function is None:
tpu_ordinal_function = _ordinal_function_from_map
self.set_configuration_from_input_tensors(inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
split_name_prefix = "%s/split" % self._name
if self.number_of_shards == 1:
transposed_sharded_inputs = [[inp] for inp in inputs]
else:
def split_fn(inp, num_shards, axis, name):
with ops.colocate_with(inp):
return array_ops.split(inp, num_shards, axis=axis, name=name)
transposed_sharded_inputs = [
split_fn(
inp,
self.number_of_shards,
axis=policy.shard_dimension,
name="%s/%d" % (split_name_prefix, index))
for (inp, policy, index) in zip(inputs, self._sharding_policies,
xrange(self.number_of_tuple_elements))
]
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
for i in xrange(self.number_of_shards)]
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(
shard,
name_prefix,
index,
device=placement_function(index),
tpu_ordinal=tpu_ordinal_function(index))
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
class _PartitionedInfeedQueue(InfeedQueue):
"""A helper object to build a device infeed queue with input partition.
Args:
number_of_tuple_elements: the number of Tensors fed atomically through the
queue, must be present unless it can be inferred from other arguments.
device_assignment: A TPU `DeviceAssignment` which is used to place all the
partitions to different TPU infeed queues.
host_id: The id of the host machine.
input_partition_dims: A nested list/tuple of integers. Each inner
list/tuple describes how to partition the corresponding input tensor.
tuple_types: If not None, a list of types of the elements of the queue.
tuple_shapes: If not None, a list of shapes of the elements of the queue.
name: The name of the queue.
"""
def __init__(self,
number_of_tuple_elements,
device_assignment,
host_id,
input_partition_dims=None,
tuple_types=None,
tuple_shapes=None,
name=None):
super(_PartitionedInfeedQueue, self).__init__(
number_of_tuple_elements=number_of_tuple_elements,
tuple_types=tuple_types,
tuple_shapes=None,
shard_dimensions=None,
name="PartitionedInfeedQueue" if name is None else name)
self._input_partition_dims = input_partition_dims
self._host_id = host_id
self._device_assignment = device_assignment
def generate_dequeue_op(self, tpu_device=0):
"""Generate TPU dequeue ops.
Args:
tpu_device: The TPU device ordinal where the infeed instruction should be
placed.
Returns:
A list of Outputs corresponding to a partition of infeed dequeued
into XLA, suitable for use within a replicated block.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set; or if a dequeue op has already been generated.
"""
self.freeze()
if self._generated_dequeue_op:
raise ValueError("Can't generate two dequeue Ops from the same queue")
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
with ops.device(tpu.core(tpu_device)):
values = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
return tag_sharding_attribute_for_dequeued_tensors(
values, self._input_partition_dims)
def generate_enqueue_ops(self, per_host_sharded_inputs):
"""Generates the host-side Ops to enqueue the partitioned inputs.
per_host_sharded_inputs is a list, one for each replica, of lists of
Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
replica i.
per_host_sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].
For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
[[A, B, C, D],
 [E, F, G, H]]
and self._input_partition_dims[j] is [2, 4], then
per_host_sharded_inputs[i][j] will be partitioned and flattened into:
[A, B, C, D, E, F, G, H] and fed into the logical core ids:
[0, 1, 2, 3, 4, 5, 6, 7] respectively.
Args:
per_host_sharded_inputs: a list of lists of Tensors. The length of the
outer list determines the number of shards. Each inner list indicates
the types and shapes of the tuples in the corresponding shard.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of per_host_sharded_inputs are not compatible with the
frozen configuration; or if the shapes of the elements of per_host_sharded_inputs
don't form a consistent unsharded tuple; or if the elements of a tuple
have different device constraints; or if the partition dims are invalid.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of per_host_sharded_inputs are not compatible with the
frozen configuration; or if the types of the elements of per_host_sharded_inputs
don't form a consistent unsharded tuple.
"""
self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
number_of_replicas_per_host = len(per_host_sharded_inputs)
number_of_tuple_elements = len(per_host_sharded_inputs[0])
assert len(self._input_partition_dims) == number_of_tuple_elements
per_host_enqueue_ops = []
for replica_index in range(number_of_replicas_per_host):
flattened_inputs = per_host_sharded_inputs[replica_index]
inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
self._input_partition_dims)
inputs_parted_iters = [
iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
for x, dims in zip(per_host_sharded_inputs[replica_index],
inputs_part_dims_flat)
]
for logical_core in xrange(self._device_assignment.num_cores_per_replica):
# Places different partitions on different logical cores.
replica_id = self._device_assignment.lookup_replicas(
self._host_id, logical_core)[replica_index]
ordinal = self._device_assignment.tpu_ordinal(
replica=replica_id, logical_core=logical_core)
infeed_inputs = []
for it in inputs_parted_iters:
input_for_device = next(it, None)
if input_for_device is not None:
infeed_inputs.append(input_for_device)
if infeed_inputs:
per_host_enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
inputs=infeed_inputs,
shapes=[x.shape for x in infeed_inputs],
name="enqueue/replica_{0}/input_{1}".format(
replica_index, logical_core),
device_ordinal=ordinal))
return per_host_enqueue_ops
def _check_input_partition_dims(self, tensor, dims):
"""Checks that input partition dims are valid for the `Tensor`.
Args:
tensor: Input tensor for partitioning.
dims: A list of integers describing how to partition the input tensor.
Raises:
ValueError: If the tensor can't be partitioned by dims or the
num_cores_per_replica doesn't match the number of
partitions (dims.prod()).
"""
# No partitioning specified, so don't perform further checks.
if dims is None:
return
dims = np.array(dims)
if (dims < 1).any():
raise ValueError("All input partition dims must be >= 1.")
# No partitioning, so don't perform further checks.
if dims.prod() == 1:
return
if dims.prod() != self._device_assignment.num_cores_per_replica:
raise ValueError(
"The product of each input parition dim should equal to "
"num_cores_per_replica. (dim = {}, num_cores_per_replica "
"= {})".format(dims, self._device_assignment.num_cores_per_replica))
if dims.shape[0] != tensor.shape.ndims:
raise ValueError(
"Input partition dims must have the same number of dimensions "
"as the `Tensor` to be partitioned. (tensor shape = {}, input "
"partition dims = {}).".format(tensor.shape.as_list(), dims))
tensor.shape.assert_is_fully_defined()
def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
"""Checks dims and partitions or replicates the input tensor.
The ops inside this function are placed on the host side.
Args:
tensor: The input tensor which will be partitioned or replicated.
dims: A list of integers describing how to partition the input tensor.
Returns:
An iterator of `Tensor`s or a list of partitioned tensors.
"""
self._check_input_partition_dims(tensor, dims)
return partition_or_replicate_on_host(tensor, dims)
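The partition-and-flatten order described in generate_enqueue_ops above can be checked with plain NumPy. This is an illustrative sketch added for this review, not part of the moved file; it only mimics what partition_or_replicate_on_host is expected to produce for input_partition_dims = [2, 4] on a 2x4 tensor.

import numpy as np

# The 2x4 example tensor from the generate_enqueue_ops docstring.
x = np.array([["A", "B", "C", "D"],
              ["E", "F", "G", "H"]])

# Partition dims [2, 4]: split dimension 0 into 2 pieces, then each piece
# into 4 pieces along dimension 1, giving 2 * 4 = 8 partitions, one per
# logical core 0..7.
partitions = [piece
              for block in np.split(x, 2, axis=0)
              for piece in np.split(block, 4, axis=1)]

print([p.ravel().tolist() for p in partitions])
# [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G'], ['H']]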

View File

@ -0,0 +1,66 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for functions used during TPU compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
class TpuContext(object):
"""A context object holding state about the TPU computation being built."""
def __init__(self):
"""Creates a new TpuContext."""
self._number_of_shards = None
@property
def number_of_shards(self):
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
self._number_of_shards = number_of_shards
# The Tpu context holds the number of shards when a sharded computation is
# being built, or None if no computation is being built.
_current_tpu_context = TpuContext()
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
if _current_tpu_context.number_of_shards is not None:
raise NotImplementedError("tpu_shard_context cannot be nested.")
try:
_current_tpu_context.set_number_of_shards(number_of_shards)
yield
finally:
_current_tpu_context.set_number_of_shards(None)
def get_tpu_context():
return _current_tpu_context
# Decorator for a TPU computation function that is passed to tpu.rewrite().
# If there is an embedded training loop in the function, trace tools will
# generate step markers for each iteration.
def on_device_training_loop(func):
# Value for this attribute is from xla.DebugOptions.StepMarkerLocation.
setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")
return func
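A minimal sketch of how the context manager above behaves; callers normally reach it indirectly through tpu.shard()/tpu.replicate(), so entering it by hand here is purely for illustration, and it assumes the module lands at tensorflow.python.tpu.tpu_function as in this change.

from tensorflow.python.tpu import tpu_function

# No sharded computation is being built yet.
assert tpu_function.get_tpu_context().number_of_shards is None

with tpu_function.tpu_shard_context(8):
  # Inside the context the shard count is visible to helpers such as
  # CrossShardOptimizer.compute_gradients().
  assert tpu_function.get_tpu_context().number_of_shards == 8

# The shard count is cleared again on exit; nesting the context raises
# NotImplementedError.
assert tpu_function.get_tpu_context().number_of_shards is None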

View File

@ -19,11 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_feed
class InfeedTest(test.TestCase):

View File

@ -0,0 +1,203 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Optimizer that implements cross-shard gradient reduction for TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import optimizer
class CrossShardOptimizer(optimizer.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
def __init__(self,
opt,
reduction=losses.Reduction.MEAN,
name="CrossShardOptimizer",
group_assignment=None):
"""Construct a new cross-shard optimizer.
Args:
opt: An existing `Optimizer` to encapsulate.
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
group_assignment: Optional 2d int32 lists with shape
[num_groups, num_replicas_per_group] which describes how to apply the
optimizer to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
"""
if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
raise ValueError("Unsupported reduction: %s." % reduction)
super(CrossShardOptimizer, self).__init__(False, name)
self._opt = opt
self._reduction = reduction
self._group_assignment = group_assignment
def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
"""Verify group_assignment and get the subgroup size".
Args:
group_assignment: list of group ids for applying the optimizer
to subgroups.
num_shards: The number of TPU shards.
Returns:
The size of one subgroup in group_assignment.
Raises:
ValueError: If group_assignment is invalid.
"""
if not group_assignment:
return None
if not (isinstance(group_assignment, list) and
all(isinstance(i, list) for i in group_assignment)):
raise ValueError("group_assignment must be a list of list. Got {}".format(
group_assignment))
replica_ids = set()
for g in group_assignment:
for i in g:
replica_ids.add(i)
if set(range(num_shards)) != replica_ids:
raise ValueError("group_assignment must be a permutation of range({0})."
" Got group_assignment={1}".format(
num_shards, group_assignment))
subgroup_size_list = [len(group) for group in group_assignment]
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
return subgroup_size_list[0]
else:
raise ValueError("The size of each subgroup in group_assignment must "
"be equal. Got group_assignment={}".format(
self._group_assignment))
def compute_gradients(self, loss, var_list=None, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
This simply wraps compute_gradients() from the real optimizer. The gradients
are aggregated later, in apply_gradients(), so that the user can still modify
the per-replica gradients, for example by clipping with a per-replica global
norm. Computing a global norm over aggregated gradients can behave badly,
because one replica's unusually large gradients can distort the gradients
from the other replicas.
Args:
loss: A Tensor containing the value to minimize.
var_list: Optional list or tuple of `tf.Variable` to update to minimize
`loss`. Defaults to the list of variables collected in the graph
under the key `GraphKey.TRAINABLE_VARIABLES`.
**kwargs: Keyword arguments for compute_gradients().
Returns:
A list of (gradient, variable) pairs.
Raises:
ValueError: If not within a tpu_shard_context or group_assignment is
invalid.
"""
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
logging.warning(
"CrossShardOptimizer should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
num_shards)
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
if self._group_assignment:
scale = 1.0 / subgroup_size
else:
scale = 1.0 / num_shards
loss *= scale
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
replicas, and then applies the real optimizer.
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
compute_gradients().
global_step: Optional Variable to increment by one after the
variables have been updated.
name: Optional name for the returned operation. Defaults to the
name passed to the Optimizer constructor.
Returns:
An `Operation` that applies the gradients. If `global_step` was not None,
that operation also increments `global_step`.
Raises:
ValueError: If the grads_and_vars is malformed.
"""
summed_grads_and_vars = []
for (grad, var) in grads_and_vars:
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
with ops.colocate_with(grad):
summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
grad, self._group_assignment), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
"""Return a slot named "name" created for "var" by the Optimizer.
This simply wraps the get_slot() from the actual optimizer.
Args:
*args: Arguments for get_slot().
**kwargs: Keyword arguments for get_slot().
Returns:
The `Variable` for the slot if it was created, `None` otherwise.
"""
return self._opt.get_slot(*args, **kwargs)
def get_slot_names(self, *args, **kwargs):
"""Return a list of the names of slots created by the `Optimizer`.
This simply wraps the get_slot_names() from the actual optimizer.
Args:
*args: Arguments for get_slot_names().
**kwargs: Keyword arguments for get_slot_names().
Returns:
A list of strings.
"""
return self._opt.get_slot_names(*args, **kwargs)
def variables(self):
"""Forwarding the variables from the underlying optimizer."""
return self._opt.variables()
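A hedged usage sketch for CrossShardOptimizer; the wrapped GradientDescentOptimizer, the learning rate, and the surrounding tpu.shard() call mentioned in the comment are illustrative assumptions, not part of the moved file.

from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import gradient_descent

def tpu_train_step(loss):
  # compute_gradients() rescales the loss for Reduction.MEAN, and
  # apply_gradients() inserts cross_replica_sum() before delegating to the
  # wrapped optimizer.
  opt = tpu_optimizer.CrossShardOptimizer(
      gradient_descent.GradientDescentOptimizer(learning_rate=0.01))
  return opt.minimize(loss)

# Typically invoked inside a sharded computation so that
# tpu_function.get_tpu_context().number_of_shards is set, for example:
#   train_ops = tpu.shard(lambda: tpu_train_step(build_loss()), num_shards=8)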

View File

@ -0,0 +1,253 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import tensor_shape
_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0
# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
"""An object use to hold the sharding policy for a Tensor.
"""
def __init__(self):
self._number_of_shards = None
self._shard_dimension = None
self._frozen = False
def __str__(self):
if self.number_of_shards is None or self.shard_dimension is None:
return "ShardingPolicy(unset)"
else:
return ("ShardingPolicy(%d shards dimension %d)" %
(self.number_of_shards, self.shard_dimension))
def _fill_default_values(self):
if self._number_of_shards is None:
self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
if self._shard_dimension is None:
self._shard_dimension = tensor_shape.as_dimension(
_DEFAULT_SHARD_DIMENSION)
def freeze(self):
"""Prevents further modification to the sharding policy.
Any values that have not been set when freeze is called are set to
defaults. If the ShardingPolicy is already frozen, this is a NoOp.
"""
if not self._frozen:
self._fill_default_values()
self._frozen = True
@property
def number_of_shards(self):
"""Returns the number of shards in the policy or None if unspecified."""
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards for the current policy.
If the policy has been frozen then number_of_shards must match the
existing setting.
Args:
number_of_shards: The number of shards to use in the policy.
Raises:
ValueError: If the policy has been frozen and number_of_shards
differs from the frozen value; or number_of_shards <= 0.
"""
if self._frozen:
if self._number_of_shards != number_of_shards:
raise ValueError(
"Can't set sharding policy to use %d shards since it has been "
"frozen to use %d." % (number_of_shards, self._number_of_shards))
else:
if number_of_shards > 0:
self._number_of_shards = number_of_shards
else:
raise ValueError(
"Can't set sharding policy to use %s shards; value must be >0",
str(number_of_shards))
@property
def shard_dimension(self):
"""Returns the shard dimension of the policy or None if unspecified."""
return self._shard_dimension
def set_shard_dimension(self, shard_dimension):
"""Sets the shard dimension for the current policy.
If the policy has been frozen then shard_dimension must match the
existing setting.
Args:
shard_dimension: The shard dimension to use in the policy.
Raises:
ValueError: If the policy has been frozen and shard_dimension
differs from the frozen value, or shard_dimension can't be
interpreted as a Dimension.
"""
if self._frozen:
if self._shard_dimension != shard_dimension:
raise ValueError(
"Can't set shard dimension to %d since it has been frozen to "
"use %d." % (shard_dimension, self._shard_dimension))
else:
self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
def merge(self, other):
"""Merges the policy of another policy into the current policy.
Args:
other: The policy to merge into this one.
Raises:
ValueError: If this policy has been frozen and the merge conflicts with
the frozen policy.
"""
if other.number_of_shards is not None:
self.set_number_of_shards(other.number_of_shards)
if other.shard_dimension is not None:
self.set_shard_dimension(other.shard_dimension)
def get_sharded_shape(self, shape, shard_index=None):
"""Returns the shape of a shard of a full Tensor.
When given the shape of a 'full-size' Tensor, returns the shape of
the sub-Tensor after it has been sharded. Freezes the policy if it
has not yet been frozen.
Args:
shape: The shape of the full-size Tensor to be sharded.
shard_index: The index of the shard whose shape should be returned.
shard_index can be None for sharding policies that use the same
shape for every shard.
Returns:
The shape of the sharded version of the Tensor.
Raises:
ValueError: If shard_index is None when shards are of different
shapes; or shard_index is not None and
!(0<=shard_index<number_of_shards); or shape does not have at
least self.shard_dimension+1 dimensions; or the value of
shape's shard dimension is not a multiple of
self.number_of_shards
"""
if self._shard_dimension is None or self._number_of_shards is None:
# Don't raise an error if the config is unset.
return None
if shard_index is not None:
if shard_index < 0 or shard_index >= self.number_of_shards:
raise ValueError("shard_index %d, but must be in [0,%d)." %
(shard_index, self._number_of_shards))
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
if dims[self._shard_dimension] is None:
raise ValueError("shape %s must have a fixed size for dimension %d "
"that is known at graph construction time." %
(shape.as_list(), self._shard_dimension))
if (dims[self._shard_dimension] % self._number_of_shards) != 0:
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
(shape.as_list(), self._number_of_shards,
self._shard_dimension))
dims[self._shard_dimension] //= self._number_of_shards
return tensor_shape.as_shape(dims)
def _unshard_shape(self, shape):
"""Return the unsharded shape that would generate a given sharded shape.
Args:
shape: the sharded shape to unshard
Returns:
The unsharded shape.
Raises:
ValueError: if shape is unknown or does not contain
self.shard_dimension
TypeError: if shape is not convertible to a TensorShape
"""
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
dims[self._shard_dimension] *= self._number_of_shards
return tensor_shape.as_shape(dims)
def get_unsharded_shape(self, shapes):
"""Returns the shape of an unsharded Tensor given a list of shards.
When given a list of shapes of shards, returns the shape of the
unsharded Tensor that would generate the shards. Sets defaults for the
policy if number_of_shards or shard_dimension is None.
Args:
shapes: The shapes of the Tensor shards to be combined.
Returns:
The shape of the unsharded version of the Tensor.
Raises:
ValueError: if shapes is not a list of length
self.number_of_shards; or any element of shapes is not a valid
shape consistent with the sharding policy; or the list of
shapes is not a valid sharding of a full shape.
TypeError: if an element of shapes is not convertible to a
TensorShape
"""
self._fill_default_values()
if len(shapes) != self.number_of_shards:
raise ValueError(
"shapes is %s but must be a list of length number_of_shards=%d" % (
str(shapes), self.number_of_shards))
unsharded_shapes = [self._unshard_shape(s) for s in shapes]
for i in xrange(self.number_of_shards - 1):
if not unsharded_shapes[i].is_compatible_with(
unsharded_shapes[self.number_of_shards - 1]):
raise ValueError(
"sharded shapes %s are not consistent shards of a full shape "
"sharded %d ways along dimension %d" % (
str(shapes), self.number_of_shards, self.shard_dimension))
return unsharded_shapes[0]
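A small sketch of ShardingPolicy round-tripping shapes; it assumes the module lands at tensorflow.python.tpu.tpu_sharding as in this change.

from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(4)
policy.set_shard_dimension(0)
policy.freeze()

# A [32, 128] tensor sharded 4 ways along dimension 0 yields [8, 128] shards.
shard_shape = policy.get_sharded_shape([32, 128])
print(shard_shape.as_list())  # [8, 128]

# Combining four such shards recovers the unsharded shape.
print(policy.get_unsharded_shape([shard_shape] * 4).as_list())  # [32, 128]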

View File

@ -19,10 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding
class ShardingTest(test.TestCase):

View File

@ -0,0 +1,156 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu
_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min
_RETRY_TIMES = 120
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
'num_cores',
'num_hosts',
'num_of_cores_per_host',
'topology',
'devices',
])
def _query_tpu_system_metadata(master_address, cluster_def=None,
query_topology=False):
"""Automatically detects the TPU system metadata in the system."""
tpu_core_count = 0
devices = []
device_dict = collections.defaultdict(list)
# TODO(b/120564445): Replace with standard library for retries.
retry_count = 1
while True:
logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
master_address)
try:
with ops.Graph().as_default():
with session_lib.Session(
master_address,
config=get_session_config_with_timeout(
_PINGING_MASTER_TIMEOUT_IN_MS,
cluster_def)) as sess:
devices = sess.list_devices()
for device in devices:
match = _TPU_DEVICE_REG.match(device.name)
if match:
host_id = match.group(1)
core_id = match.group(2)
device_dict[host_id].append(core_id)
tpu_core_count += 1
break
except errors.DeadlineExceededError:
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
'not be ready (still scheduling) or the Tensorflow master address '
'is incorrect: got (%s).' %
(master_address))
# TODO(xiejw): For local or grpc master we might not need retry logic
# here.
if retry_count <= _RETRY_TIMES:
logging.warning('%s', msg)
logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
retry_count += 1
else:
raise ValueError(msg)
num_of_cores_per_host = 0
if tpu_core_count:
num_cores_per_host_set = set(
[len(core_ids) for core_ids in device_dict.values()])
if len(num_cores_per_host_set) != 1:
raise RuntimeError(
'The number of TPU cores on each host is not the same. This should not happen. '
'devices: {}'.format(devices))
num_of_cores_per_host = num_cores_per_host_set.pop()
topology = None
if query_topology:
if not tpu_core_count:
raise RuntimeError(
'Cannot find any TPU cores in the system (master address {}). '
'This usually means the master address is incorrect or the '
'TPU worker has some problems. Available devices: {}'.format(
master_address, devices))
topology = _obtain_topology(master_address, cluster_def)
metadata = _TPUSystemMetadata(
num_cores=tpu_core_count,
num_hosts=len(device_dict),
num_of_cores_per_host=num_of_cores_per_host,
topology=topology,
devices=devices)
if tpu_core_count:
logging.info('Found TPU system:')
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
logging.info('*** Num TPU Cores Per Worker: %d',
metadata.num_of_cores_per_host)
for device in metadata.devices:
logging.info('*** Available Device: %s', device)
else:
logging.info('Failed to find TPU: %s', metadata)
return metadata
def _obtain_topology(master_address, cluster_def):
"""Obtains TPU fabric topology."""
try:
logging.info('Initializing TPU system (master: %s) to fetch topology '
'for model parallelism. This might take a while.',
master_address)
with ops.Graph().as_default():
session_config = get_session_config_with_timeout(
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
with session_lib.Session(
master_address, config=session_config) as sess:
topology = sess.run(tpu.initialize_system())
return topology
except errors.DeadlineExceededError:
raise ValueError(
'Failed to initialize the TPU system with master (%s). '
'Please double-check that the TPU system is functional.' % (
master_address))
def get_session_config_with_timeout(timeout_in_secs, cluster_def):
"""Returns a session given a timeout and a cluster configuration."""
config = config_pb2.ConfigProto(
operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
return config
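A standalone sketch of the device-name parsing that _query_tpu_system_metadata relies on; the device name below is an assumed example of what Session.list_devices() might return for a TPU worker.

import re

_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')

device_name = '/job:tpu_worker/replica:0/task:3/device:TPU:5'
match = _TPU_DEVICE_REG.match(device_name)
if match:
  host_id, core_id = match.group(1), match.group(2)
  print(host_id, core_id)  # 3 5 -> core 5 on host (task) 3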

View File

@ -19,18 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop
class TPUContextTest(test.TestCase):

View File

@ -0,0 +1,222 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Library for constructing a training loop, suitable for TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import xla
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
`condition` and `body` take the current value of the loop-carried
tensors. `body` additionally takes a tuple of infeed from
infeed_queue if infeed_queue is not None. `condition` must return a
single boolean value that determines whether iteration
continues. `body` must return an updated list of values for the
loop-carried tensors.
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
TypeError: if body or condition has the wrong signature.
"""
del name
# Converts inputs to Tensors.
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
x in inputs]
input_types = [x.dtype for x in inputs]
input_arity = len(inputs)
body_arg_error = xla.check_function_argument_count(
body, input_arity, infeed_queue)
if body_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s, but the loop body needs %s" % (
input_arity, str([i.name for i in inputs]), body_arg_error))
else:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s and %d additional inputs from "
"infeed, but the computation needs %s" % (input_arity, str(
[i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
body_arg_error))
condition_arg_error = xla.check_function_argument_count(
condition, input_arity, None)
if condition_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s" % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
else:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s. Note that infeed is not passed to the loop "
"condition." % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
def condition_wrapper(*inputs):
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
return condition(*inputs)
def body_wrapper(*inputs):
"""Wrapper around `body` that handles infeed queues and control deps."""
inputs = list(inputs)
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
# Runs `body` with the dequeue_ops appended.
if infeed_queue:
number_of_shards = tpu_function.get_tpu_context().number_of_shards
if number_of_shards is None:
raise ValueError("Can't build training loop with infeed when there is "
"no tpu_shard_context. Are you building a loop or "
"graph directly rather than from inside tpu.rewrite, "
"tpu.batch_parallel, tpu.shard, or tpu.replicate?")
infeed_queue.set_number_of_shards(number_of_shards)
dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
else:
dequeue_ops = []
outputs = body(*(inputs + dequeue_ops))
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs
if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
"TPU training loop body must return zero or more Tensor values "
"followed by zero or more Operations.")
output_types = [op.dtype for op in output_tensors]
if input_types != output_types:
raise TypeError(
"Mismatch between input types and output types for training loop "
"body: {} vs {}".format(input_types, output_types))
# Add the dequeue operations to output_operations to ensure they are run
# by the loop, even if the programmer's loop body does not use them.
output_operations += dequeue_ops
# Add a dummy output, if needed.
if not output_tensors:
output_tensors = array_ops.constant(0)
if output_operations:
# TODO(phawkins): in principle this is too restrictive since it serializes
# the training loop steps. In practice it does not matter since this loop
# will be compiled by XLA.
output_tensors = control_flow_ops.tuple(output_tensors,
control_inputs=output_operations)
if tensor_tracer.TensorTracer.is_enabled():
num_replicas = tpu_function.get_tpu_context().number_of_shards
if num_replicas is None:
num_replicas = 1
tt = tensor_tracer.TensorTracer()
output_tensors = tt.trace_tpu(ops.get_default_graph(),
output_tensors, None,
num_replicas)
return output_tensors
# If the body has arity 0, add a dummy loop-carried value to which we can add
# control dependencies from any side-effecting operations.
if input_arity == 0:
inputs = [array_ops.constant(0)]
return control_flow_ops.while_loop(
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop that executes a fixed number of iterations.
The set of loop-carried tensors corresponds to `inputs`.
`body` must be a function that takes and returns the values of the
loop-carried tensors.
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
TypeError: if `body` has the wrong signature.
"""
def _convert_to_list(xs):
if not isinstance(xs, (list, tuple)):
return [xs]
else:
return list(xs)
def cond(i, *args):
del args
return i < n
def body_wrapper(i, *args):
return [i + 1] + _convert_to_list(body(*args))
inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
outputs = while_loop(
cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
outputs = _convert_to_list(outputs)
if len(outputs) == 1:
# Returns the Op rather than an empty list.
return outputs[0].op
else:
return outputs[1:]
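A minimal sketch of building a fixed-length training loop with repeat(); it only constructs the graph, and running the resulting op would require a real TPU worker. tpu.rewrite is used to provide the tpu_shard_context that infeed-using loops need, although this infeed-free example would also build without it.

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import training_loop

def computation(x):
  # The body takes and returns the loop-carried values.
  def body(v):
    return math_ops.add(v, 1.0)
  # Ten iterations of v -> v + 1.
  return training_loop.repeat(10, body, inputs=[x])

train_op = tpu.rewrite(computation, [constant_op.constant(0.0)])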

View File

@ -0,0 +1,51 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Utilities for the functionalities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import six
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
def check_positive_integer(value, name):
"""Checks whether `value` is a positive integer."""
if not isinstance(value, six.integer_types):
raise TypeError('{} must be int, got {}'.format(name, type(value)))
if value <= 0:
raise ValueError('{} must be positive, got {}'.format(name, value))
# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
# python/estimator/util.py.
class MultiHostDatasetInitializerHook(training.SessionRunHook):
"""Creates a SessionRunHook that initializes all passed iterators."""
def __init__(self, dataset_initializers):
self._initializers = dataset_initializers
def after_create_session(self, session, coord):
del coord
start = time.time()
session.run(self._initializers)
logging.info('Initialized dataset iterators in %d seconds',
time.time() - start)
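A brief sketch of the helpers above; the module path tensorflow.python.tpu.util, the argument name, and the `iterator` mentioned in the comment are assumptions for illustration.

from tensorflow.python.tpu import util

util.check_positive_integer(8, 'iterations_per_loop')  # passes silently

try:
  util.check_positive_integer(0, 'iterations_per_loop')
except ValueError as e:
  print(e)  # iterations_per_loop must be positive, got 0

# The hook runs the given initializer ops once the session is created, e.g.:
#   hook = util.MultiHostDatasetInitializerHook([iterator.initializer])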

View File

@ -0,0 +1,106 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""XLA utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.util import tf_inspect
def is_flat(outputs):
"""Checks if outputs is a flat structure.
The following structures and values are considered flat:
1) None
2) A single object
3) A list or tuple of Tensors/Operations
The only structures that this function understands are sequences and
dictionaries. E.g. this means that if outputs contains a single
user-defined Object, it is considered to be flat. Errors are raised later on
if that Object cannot be converted to a Tensor.
Args:
outputs: Output from `computation` inside `xla.compile`.
Returns:
A boolean indicating whether `outputs` is flat.
"""
# If outputs is a list or tuple, check if it has any nested structure. If
# there is, then outputs is non-flat.
if isinstance(outputs, collections.Sequence):
for o in outputs:
if isinstance(o, collections.Sequence) or isinstance(o, dict):
return False
# If outputs is a dict, it is non-flat.
if isinstance(outputs, dict):
return False
# Getting here means either outputs itself is a single non-structured value
# or it is a flat list of single non-structured values.
return True
def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to an XLA function.
Args:
func: the Python function that will be called to generate the body of an XLA
computation graph.
input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
Returns:
None if function can be called with the supplied number of
arguments, or an error string if it cannot.
"""
def format_error(complaint, quantity):
return '%s %d argument%s' % (complaint, quantity, ''
if quantity == 1 else 's')
num_args_supplied = input_arity
if infeed_queue is not None:
num_args_supplied += infeed_queue.number_of_tuple_elements
arg_spec = tf_inspect.getargspec(func)
num_func_args = len(arg_spec.args)
if arg_spec.defaults is None:
num_func_defaults = 0
else:
num_func_defaults = len(arg_spec.defaults)
min_func_args = num_func_args - num_func_defaults
if num_args_supplied < min_func_args:
# The required number of arguments is not enough to call the function.
if num_func_defaults == 0 and arg_spec.varargs is None:
return format_error('exactly', num_func_args)
else:
return format_error('at least', min_func_args)
if arg_spec.varargs is None and num_args_supplied > num_func_args:
# The required number of arguments is too many to call the function.
if num_func_defaults == 0:
return format_error('exactly', num_func_args)
else:
return format_error('at most', num_func_args)
# Reaching here means either
# 1) There are varargs, func can accept any number of arguments greater than
# the minimum.
# 2) Number of supplied arguments falls in range of acceptable argument count
# of func.
return None
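A short sketch exercising the two helpers above; the sample `computation` signature is an arbitrary assumption.

from tensorflow.python.tpu import xla

def computation(x, y, z=None):
  return x

# Two required arguments plus one default: two supplied args are accepted.
print(xla.check_function_argument_count(computation, 2, None))  # None
# One argument is too few; the error string names the minimum.
print(xla.check_function_argument_count(computation, 1, None))  # at least 2 arguments

# is_flat() only treats sequences and dicts as structure.
print(xla.is_flat([1, 2, 3]))   # True
print(xla.is_flat({'a': 1}))    # False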