Move TPU python files to TF core.
PiperOrigin-RevId: 234851893
parent 2825b62b27
commit d23fc2b7ff
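Each file left behind under contrib/tpu is reduced to a thin stub that re-exports the implementation from its new home under tensorflow/python/tpu, so existing `tf.contrib.tpu` imports keep working. A minimal sketch of that pattern, using bfloat16.py as an assumed example (the exact re-export target follows the analogous stubs shown in the diff below):

# tensorflow/contrib/tpu/python/tpu/bfloat16.py (backwards-compatibility stub)
"""Stub file to maintain backwards compatibility."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.bfloat16 import *
# pylint: enable=wildcard-import,unused-import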
Changed files:

  tensorflow/contrib/tpu/
    BUILD
    profiler/
    python/ops/
    python/profiler/
    python/tpu/
      __init__.py, _tpu_estimator_embedding.py, async_checkpoint.py,
      bfloat16.py, datasets.py, device_assignment.py, error_handling.py,
      feature_column.py, functional.py, session_support.py, tensor_tracer.py,
      topology.py, tpu.py, tpu_config.py, tpu_context.py, tpu_embedding.py,
      tpu_embedding_gradient.py, tpu_estimator.py, tpu_feed.py,
      tpu_function.py, tpu_optimizer.py, tpu_sharding.py,
      tpu_system_metadata.py, training_loop.py, util.py
  tensorflow/python/
    BUILD
    tpu/
      BUILD, __init__.py, _tpu_estimator_embedding.py, async_checkpoint.py,
      bfloat16.py, bfloat16_test.py, datasets.py, datasets_test.py,
      device_assignment.py, error_handling.py, feature_column.py,
      feature_column_test.py, functional.py
      ops/
      profiler/
      session_support.py, tensor_tracer.py, topology.py, topology_test.py,
      tpu.py, tpu_config.py, tpu_config_test.py, tpu_context.py,
      tpu_embedding.py, tpu_embedding_gradient.py, tpu_estimator.py,
      tpu_estimator_signals_test.py, tpu_feed.py, tpu_function.py,
      tpu_infeed_test.py, tpu_optimizer.py, tpu_sharding.py,
      tpu_sharding_test.py, tpu_system_metadata.py, tpu_test.py,
      training_loop.py, util.py, xla.py

@@ -28,8 +28,7 @@ py_library(
|
||||
srcs = ["python/ops/tpu_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python/tpu:tpu_py",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -38,19 +37,7 @@ py_library(
|
||||
srcs = ["python/tpu/async_checkpoint.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/tpu:async_checkpoint",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -72,24 +59,7 @@ py_library(
|
||||
":tpu_embedding",
|
||||
":tpu_lib",
|
||||
"//tensorflow/contrib/training:training_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/estimator:util",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/python/tpu:tpu_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -101,7 +71,7 @@ py_library(
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python/tpu:functional",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -110,10 +80,7 @@ py_library(
|
||||
srcs = ["python/profiler/__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc",
|
||||
"//tensorflow/core/profiler:profiler_analysis_proto_py",
|
||||
"//tensorflow/core/profiler:protos_all_py",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/tpu/profiler",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -130,6 +97,7 @@ py_library(
|
||||
":tpu_embedding",
|
||||
":tpu_estimator",
|
||||
":tpu_lib",
|
||||
"//tensorflow/python/tpu",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -196,29 +164,9 @@ py_library(
|
||||
":functional",
|
||||
":profiler",
|
||||
":tpu_py",
|
||||
"//tensorflow/compiler/xla/experimental/xla_sharding",
|
||||
"//tensorflow/compiler/xla/python_api:xla_shape",
|
||||
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
|
||||
"//tensorflow/contrib/compiler:xla",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/tpu:tpu_lib",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -229,106 +177,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:batching",
|
||||
"//tensorflow/contrib/data/python/ops:interleave_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "datasets_test",
|
||||
size = "medium",
|
||||
srcs = ["python/tpu/datasets_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
":datasets",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
shard_count = 4,
|
||||
tags = ["no_oss"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:layers",
|
||||
],
|
||||
tags = ["no_windows"], # TODO: needs investigation on Windows
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_sharding_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_sharding_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "bfloat16_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/bfloat16_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_infeed_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_infeed_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_config_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_config_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_estimator",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_estimator_signals_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_estimator_signals_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_estimator",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "topology_test",
|
||||
size = "medium",
|
||||
srcs = ["python/tpu/topology_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python/tpu:datasets",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -341,16 +190,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_lib",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:partitioned_variables",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/python/tpu:tpu_embedding",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -359,31 +199,6 @@ py_library(
|
||||
srcs = ["python/tpu/feature_column.py"],
|
||||
deps = [
|
||||
":tpu_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/feature_column:feature_column_py",
|
||||
"//tensorflow/python/tpu:feature_column",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "feature_column_test",
|
||||
srcs = [
|
||||
"python/tpu/feature_column_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":feature_column",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/feature_column:feature_column_py",
|
||||
],
|
||||
main = "python/tpu/feature_column_test.py",
|
||||
)
|
||||
|
@@ -63,11 +63,3 @@ tf_cc_test(
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_profiler_analysis_pb2_grpc",
|
||||
srcs = ["tpu_profiler_analysis_pb2_grpc.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
|
||||
)
|
||||
|
@@ -1,418 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Operations for TPUs."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
||||
from tensorflow.python.ops import gen_tpu_ops
|
||||
from tensorflow.python.ops.gen_tpu_ops import *
|
||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
||||
|
||||
def _create_default_group_assignment():
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
logging.warning(
|
||||
"cross_replica_sum should be used within a tpu_shard_context, but "
|
||||
"got unset number_of_shards. Assuming 1.")
|
||||
num_shards = 1
|
||||
group_assignment = [list(range(num_shards))]
|
||||
return group_assignment
|
||||
|
||||
def all_to_all(x,
|
||||
concat_dimension,
|
||||
split_dimension,
|
||||
split_count,
|
||||
group_assignment=None,
|
||||
name=None):
|
||||
"""Exchange data across TPU replicas.
|
||||
|
||||
Args:
|
||||
x: The local tensor.
|
||||
concat_dimension: The dimension number to concatenate.
|
||||
split_dimension: The dimension number to split.
|
||||
split_count: The number of splits; this number must equal the subgroup
size (group_assignment.get_shape()[1]).
|
||||
group_assignment: Optional 2d int32 lists with shape [num_groups,
|
||||
num_replicas_per_group]. `group_assignment[i]` represents the replica
|
||||
ids in the ith subgroup.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is concatenated by data from different replicas.
|
||||
"""
|
||||
if group_assignment is None:
|
||||
group_assignment = _create_default_group_assignment()
|
||||
return gen_tpu_ops.all_to_all(
|
||||
x,
|
||||
group_assignment,
|
||||
concat_dimension=concat_dimension,
|
||||
split_dimension=split_dimension,
|
||||
split_count=split_count,
|
||||
name=name)
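# --- Illustrative usage sketch (editor addition, not part of the original file) ---
# Inside a replicated TPU computation, all_to_all splits each replica's tensor
# along `split_dimension`, exchanges the pieces across the replicas named in
# `group_assignment`, and concatenates the received pieces along
# `concat_dimension`. The 4-replica group and dimension choices below are
# assumptions for illustration only.
def _all_to_all_example(per_replica_tensor):
  """Hypothetical helper; call inside tpu.replicate()/tpu.shard()."""
  return all_to_all(
      per_replica_tensor,
      concat_dimension=0,
      split_dimension=0,
      split_count=4,
      group_assignment=[[0, 1, 2, 3]])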
|
||||
|
||||
@ops.RegisterGradient("AllToAll")
|
||||
def _all_to_all_grad(op, grad):
|
||||
# The gradient of an all-to-all is also an all-to-all, but with the
# split_dimension and concat_dimension swapped.
# The gradient with respect to group_assignment is None.
|
||||
return [
|
||||
gen_tpu_ops.all_to_all(
|
||||
grad,
|
||||
op.inputs[1],
|
||||
concat_dimension=op.get_attr("split_dimension"),
|
||||
split_dimension=op.get_attr("concat_dimension"),
|
||||
split_count=op.get_attr("split_count")), None
|
||||
]
|
||||
|
||||
def cross_replica_sum(x, group_assignment=None, name=None):
|
||||
"""Sum the input tensor across replicas according to group_assignment.
|
||||
|
||||
Args:
|
||||
x: The local tensor to sum.
|
||||
group_assignment: Optional 2d int32 lists with shape [num_groups,
|
||||
num_replicas_per_group]. `group_assignment[i]` represents the replica
|
||||
ids in the ith subgroup.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is summed across replicas.
|
||||
"""
|
||||
if group_assignment is None:
|
||||
group_assignment = _create_default_group_assignment()
|
||||
|
||||
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
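# --- Illustrative usage sketch (editor addition, not part of the original file) ---
# A common use of cross_replica_sum is averaging a per-replica value (e.g. a
# gradient or a loss) across all replicas. The replica count below is an
# assumption for illustration only.
def _average_across_replicas(per_replica_value, num_replicas=8):
  """Hypothetical helper; call inside a replicated TPU computation."""
  return cross_replica_sum(per_replica_value) / num_replicas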
|
||||
|
||||
def collective_permute(x, source_target_pairs, name=None):
|
||||
"""Permute the input tensor across replicas given source_target_pairs.
|
||||
|
||||
For each source_target_pair <a, b>, we send replica a's input to replica b.
|
||||
Each replica id must only appear once in the source column. Also it must
|
||||
only appear once in the target column.
|
||||
For any replica id not in the target column, this op returns a zero tensor
with the same shape and dtype as the input x.
|
||||
|
||||
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
|
||||
source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
|
||||
`[0, A, B, C]`.
|
||||
|
||||
Args:
|
||||
x: The local tensor to be permuted.
|
||||
source_target_pairs: 2d int lists with shape [num_pairs, 2].
|
||||
source_target_pairs[i][0] represents the source replica id and
|
||||
source_target_pairs[i][1] represents the target replica id.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is permuted.
|
||||
"""
|
||||
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
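# --- Illustrative usage sketch (editor addition, not part of the original file) ---
# With four replicas, shifting values one replica "to the right" matches the
# docstring example above: targets are replicas 1..3, and replica 0 (never a
# target) receives zeros. The pair list is an assumption for illustration only.
def _shift_right_one_replica(per_replica_value):
  """Hypothetical helper; call inside a replicated TPU computation."""
  return collective_permute(
      per_replica_value, source_target_pairs=[[0, 1], [1, 2], [2, 3]])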
|
||||
|
||||
@ops.RegisterGradient("CollectivePermute")
|
||||
def _collective_permute_grad(op, grad):
|
||||
# The gradient of a collective permute operation is also a collective
|
||||
# permute, but with source/target pairs reversed. The gradient with respect
|
||||
# to input argument `source_target_pairs` is `None`.
|
||||
source_target_pairs = op.inputs[1][:, ::-1]
|
||||
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
|
||||
|
||||
@ops.RegisterGradient("CrossReplicaSum")
|
||||
def _cross_replica_sum_grad(op, grad):
|
||||
# The gradient of a cross replica sum is also a cross-replica sum.
|
||||
# The gradient with respect to group_assignment is None.
|
||||
return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
|
||||
|
||||
# This extra type checking exists to give a more helpful error message in
|
||||
# the common case that uint8 and int64 values are infed. Remove when both
|
||||
# types are supported.
|
||||
|
||||
_SUPPORTED_INFEED_DTYPES = set([
|
||||
dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
|
||||
dtypes.complex64, dtypes.uint32
|
||||
])
|
||||
|
||||
@ops.RegisterGradient("TPUEmbeddingActivations")
|
||||
def _embedding_activations_grad(activations_op, grad_wrt_activations):
|
||||
"""Saves the gradient of embedding activations ops in a graph collection."""
|
||||
g = ops.get_default_graph()
|
||||
table_id = activations_op.get_attr("table_id")
|
||||
lookup_id = activations_op.get_attr("lookup_id")
|
||||
table_gradients = g.get_collection_ref(
|
||||
"tpu_embedding_gradients_table_%d" % table_id)
|
||||
|
||||
if not table_gradients:
|
||||
raise RuntimeError(
|
||||
"Gradients for TPUEmbedding have been generated in non-training mode."
|
||||
"This is not expected. Consider putting your Optimizer.minimize code "
|
||||
"behind the training mode condition check. For Estimator, you can "
|
||||
"do \n\n"
|
||||
" if mode == tf.estimator.ModeKeys.TRAIN:\n"
|
||||
" train_op = opt.minimize(loss)\n"
|
||||
"\n")
|
||||
|
||||
table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
|
||||
return [
|
||||
# RegisterGradient requires that value be returned for all inputs. Since
|
||||
# the first argument (tpu_gradient_variable_{table_name}) has shape [1],
|
||||
# we will return zeros(shape=[1]). The actual gradient w.r.t. the
|
||||
# embedding activations (grad_wrt_activations) has the same shape as the
|
||||
# activations returned by embedding_activations.
|
||||
array_ops.zeros(arg.shape, dtype=dtypes.float32)
|
||||
for arg in activations_op.inputs
|
||||
]
|
||||
|
||||
def infeed_dequeue(dtype, shape, name=None):
|
||||
"""A placeholder op for a value that will be fed into the computation.
|
||||
|
||||
Args:
|
||||
dtype: A `tf.DType`. The type of elements in the tensor.
|
||||
shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of type `dtype`.
|
||||
A tensor that will be provided using the infeed mechanism.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dtype` is not a supported infeed type.
|
||||
"""
|
||||
if dtype not in _SUPPORTED_INFEED_DTYPES:
|
||||
raise TypeError(
|
||||
"{} is not a supported TPU infeed type. Supported types are: "
|
||||
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
|
||||
|
||||
return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def infeed_dequeue_tuple(dtypes, shapes, name=None):
|
||||
"""A placeholder op for values fed into the TPU simultaneously as a tuple.
|
||||
|
||||
Args:
|
||||
dtypes: A list of `tf.DType`s that has length `>= 1`.
|
||||
The element types of each element in `outputs`.
|
||||
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
|
||||
The shapes of each tensor in `outputs`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` objects of type `dtypes`.
|
||||
A list of tensors that will be provided using the infeed mechanism.
|
||||
|
||||
Raises:
|
||||
TypeError: If a type in `dtypes` is not a supported infeed type.
|
||||
"""
|
||||
for dtype in dtypes:
|
||||
if dtype not in _SUPPORTED_INFEED_DTYPES:
|
||||
raise TypeError(
|
||||
"{} is not a supported TPU infeed type. Supported types are: "
|
||||
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
|
||||
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
|
||||
# pylint: enable=redefined-outer-name
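# --- Illustrative usage sketch (editor addition, not part of the original file) ---
# On the device side of an infeed pipeline, infeed_dequeue_tuple returns the
# tensors the host enqueued; the dtypes and shapes must match the enqueue side
# exactly. `dtypes` here is the framework module imported at the top of this
# file, and the image/label shapes are assumptions for illustration only.
def _dequeue_image_label_batch():
  """Hypothetical helper; call inside the TPU step function."""
  images, labels = infeed_dequeue_tuple(
      dtypes=[dtypes.float32, dtypes.int32],
      shapes=[[128, 224, 224, 3], [128]])
  return images, labels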
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def send_tpu_embedding_gradients(inputs,
|
||||
config,
|
||||
learning_rates=None,
|
||||
name=None):
|
||||
"""A placeholder op for feeding per-sample gradients to the embedding layer.
|
||||
|
||||
Args:
|
||||
inputs: A TensorList of gradients with which to update embedding tables.
|
||||
This argument has the same length and shapes as the return value of
|
||||
RecvTPUEmbeddingActivations, but contains gradients of the model's
|
||||
loss with respect to the embedding activations. The embedding tables
|
||||
are updated from these gradients via the optimizers specified in the
|
||||
TPU embedding configuration given to tpu.initialize_system.
|
||||
config: Serialized TPUEmbeddingConfiguration proto.
|
||||
learning_rates: A TensorList of float32 scalars, one for each dynamic
|
||||
learning rate tag: see the comments in
|
||||
//third_party/tensorflow/core/protobuf/tpu/
|
||||
optimization_parameters.proto.
|
||||
Multiple tables can share the same dynamic learning rate tag as
|
||||
specified in the configuration. If the learning rates for all tables
|
||||
are constant, this list should be empty.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A SendTPUEmbeddingGradients operation.
|
||||
"""
|
||||
if learning_rates is None:
|
||||
learning_rates = []
|
||||
return gen_tpu_ops.send_tpu_embedding_gradients(
|
||||
inputs=inputs, learning_rates=learning_rates, config=config, name=name)
|
||||
|
||||
send_tpu_embedding_gradients.__doc__ = (
|
||||
gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_integer_batch(batch,
|
||||
device_ordinal,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
batch: A list of 1D tensors, one for each embedding table, containing the
|
||||
indices into the tables.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingIntegerBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
|
||||
batch=batch,
|
||||
device_ordinal=device_ordinal,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_integer_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_sparse_batch(sample_indices,
|
||||
embedding_indices,
|
||||
aggregation_weights,
|
||||
device_ordinal,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||
and feature to which the corresponding embedding_indices and
|
||||
aggregation_weights values belong. sample_indices[i] must equal b * nf +
|
||||
f, where nf is the number of features from the corresponding table, f is
|
||||
in [0, nf), and b is in [0, batch size).
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables.
|
||||
aggregation_weights: A list of rank 1 Tensors containing per sample --
|
||||
i.e. per (training example, feature) -- aggregation weights.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
|
||||
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
|
||||
is to use 'sum' for all tables (optional).
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingSparseBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
|
||||
sample_indices=sample_indices,
|
||||
embedding_indices=embedding_indices,
|
||||
aggregation_weights=aggregation_weights,
|
||||
device_ordinal=device_ordinal,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_sparse_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
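# --- Illustrative sketch of the sample_indices convention (editor addition) ---
# For a table contributing nf features to a batch of b_total examples, the
# docstring above requires sample_indices[i] == b * nf + f. With nf = 3 and
# b_total = 2 (assumed values for illustration), the (example, feature) slots
# map to indices 0..5:
_example_sample_indices = [b * 3 + f for b in range(2) for f in range(3)]
# == [0, 1, 2, 3, 4, 5]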
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
embedding_indices,
|
||||
aggregation_weights,
|
||||
table_ids,
|
||||
device_ordinal,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||
to which the corresponding embedding_indices and aggregation_weights
|
||||
values belong. It corresponds to sp_ids.indices[:,0] in
|
||||
embedding_lookup_sparse().
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
|
||||
aggregation_weights: A list of rank 1 Tensors containing per training
|
||||
example aggregation weights. It corresponds to sp_weights.values in
|
||||
embedding_lookup_sparse().
|
||||
table_ids: A list of integers specifying the identifier of the embedding
|
||||
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
|
||||
lookup the corresponding input. The ith input is looked up using
|
||||
table_ids[i]. The size of the table_ids list must be equal to that of
|
||||
sample_indices, embedding_indices and aggregation_weights.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
|
||||
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
|
||||
is to use 'sum' for all tables (optional).
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingSparseTensorBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
|
||||
sample_indices=sample_indices,
|
||||
embedding_indices=embedding_indices,
|
||||
aggregation_weights=aggregation_weights,
|
||||
table_ids=table_ids,
|
||||
device_ordinal=device_ordinal,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
|
||||
|
||||
else:
|
||||
# We have already built the appropriate libraries into the binary via CMake
|
||||
# if we have built contrib, so we don't need this
|
||||
pass
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.ops.tpu_ops import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@@ -1,35 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Operations to select TPU core to run."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
||||
from tensorflow.python.ops.gen_tpu_ops import tpu_ordinal_selector
|
||||
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.platform import resource_loader
|
||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
||||
|
||||
else:
|
||||
# We have already built the appropriate libraries into the binary via CMake
|
||||
# if we have built contrib, so we don't need this
|
||||
pass
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.ops.tpu_ordinal_selector_op import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@@ -1,31 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Classes for TPU trace events."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.core.profiler.trace_events_pb2 import *
|
||||
from tensorflow.core.profiler.profiler_analysis_pb2 import *
|
||||
from tensorflow.python.tpu.profiler import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent']
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@@ -1,20 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Ops related to Tensor Processing Units."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@@ -1,334 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Tooling for support TPU embedding in TPUEstimator."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_embedding
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.feature_column import feature_column as core_fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
|
||||
tpu_fc._TPUSharedEmbeddingColumn)
|
||||
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
|
||||
core_fc_lib.EmbeddingColumn,
|
||||
core_fc._SharedEmbeddingColumn)
|
||||
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_TABLE_NAME_PREFIX = 'tbl_'
|
||||
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
|
||||
|
||||
|
||||
def _get_table_name_from_embedding_var_name(embedding_var_name):
|
||||
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
|
||||
|
||||
|
||||
def _get_embedding_var_name_from_table_name(table_name):
|
||||
return table_name[_LEN_TABLE_NAME_PREFIX:]
|
||||
|
||||
|
||||
def _get_embedding_variable_name(scope_name, var_name):
|
||||
return '{}/{}'.format(scope_name, var_name)
|
||||
|
||||
|
||||
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
|
||||
"""Return embedding variable names which are consistent with CPU runs."""
|
||||
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
|
||||
return tpu_embedding.AdagradSlotVariableName(
|
||||
'{}/{}/Adagrad'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
|
||||
return tpu_embedding.AdamSlotVariableNames(
|
||||
'{}/{}/Adam/m'.format(scope_name, var_name),
|
||||
'{}/{}/Adam/v'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters,
|
||||
tpu_embedding.StochasticGradientDescentParameters):
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Support for inferring the full variable name '
'for optimization_parameter {} has not been added.'
.format(optimization_parameters))
|
||||
|
||||
|
||||
def get_full_variable_names(
|
||||
graph, table_to_config_dict, optimization_parameters):
|
||||
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
|
||||
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
if not collection:
|
||||
raise RuntimeError(
|
||||
'Embedding feature column did not capture anything. Make sure the '
'feature columns passed to the TPUEstimator constructor are properly '
'used in model_fn.')
|
||||
|
||||
embedding_variable_name_by_table = {}
|
||||
slot_variable_names_by_table = {}
|
||||
for table_name in table_to_config_dict:
|
||||
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
|
||||
(scope_name, var_name) = collection[0][embedding_var_name]
|
||||
embedding_variable_name_by_table[table_name] = (
|
||||
_get_embedding_variable_name(scope_name, var_name))
|
||||
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
|
||||
scope_name, var_name, optimization_parameters)
|
||||
|
||||
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
return embedding_variable_name_by_table, slot_variable_names_by_table
|
||||
|
||||
|
||||
def get_tpu_embedding_config_from_feature_columns(feature_columns):
|
||||
"""Create configs for TPUEmbedding from a list of feature columns.
|
||||
|
||||
This function will place one embedding tensor per table, and the return value
is intended to be used as input to TPUEmbedding.
|
||||
|
||||
Args:
|
||||
feature_columns: a list of supported feature columns.
|
||||
|
||||
Returns:
|
||||
A pair of dicts, the first maps tables to their config, the second maps
|
||||
features to tables.
|
||||
"""
|
||||
|
||||
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, allowed):
|
||||
raise TypeError(
|
||||
'Unsupported feature column {}. Supported types are {}.'.format(
|
||||
type(column), allowed))
|
||||
|
||||
table_to_config = {}
|
||||
feature_to_table = {}
|
||||
for column in feature_columns:
|
||||
feature_name = column.get_feature_key_name()
|
||||
table_name = _get_table_name_from_embedding_var_name(
|
||||
column.get_embedding_var_name())
|
||||
if feature_name in feature_to_table:
|
||||
raise ValueError(
|
||||
'Feature column {} is used with multiple embeddings and this is '
|
||||
'not supported.'.format(feature_name))
|
||||
feature_to_table[feature_name] = table_name
|
||||
vocabulary_size, dimension = column.get_embedding_table_size()
|
||||
table_to_config[table_name] = tpu_embedding.TableConfig(
|
||||
vocabulary_size=vocabulary_size,
|
||||
dimension=dimension,
|
||||
initializer=column.get_initializer(),
|
||||
combiner=column.get_combiner())
|
||||
|
||||
return table_to_config, feature_to_table
|
||||
|
||||
|
||||
def _get_tpu_embedding_optimization_parameters(embedding_config_spec):
|
||||
"""Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec."""
|
||||
if embedding_config_spec.optimizer_type == 'adagrad':
|
||||
return tpu_embedding.AdagradParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.adagrad_initial_accumulator,
|
||||
embedding_config_spec.use_gradient_accumulation)
|
||||
elif embedding_config_spec.optimizer_type == 'sgd':
|
||||
return tpu_embedding.StochasticGradientDescentParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.use_gradient_accumulation)
|
||||
elif embedding_config_spec.optimizer_type == 'adam':
|
||||
return tpu_embedding.AdamParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.adam_parameters.beta1,
|
||||
embedding_config_spec.adam_parameters.beta2,
|
||||
embedding_config_spec.adam_parameters.epsilon,
|
||||
use_gradient_accumulation=embedding_config_spec
|
||||
.use_gradient_accumulation)
|
||||
else:
|
||||
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
|
||||
|
||||
|
||||
AdamParameters = collections.namedtuple('AdamParameters',
|
||||
['beta1', 'beta2', 'epsilon'])
|
||||
|
||||
|
||||
# TODO(shizhiw): Improve the API to support more optimizer parameters.
|
||||
class EmbeddingConfigSpec(
|
||||
collections.namedtuple('EmbeddingConfigSpec', [
|
||||
'feature_columns', 'learning_rate', 'optimizer_type',
|
||||
'adagrad_initial_accumulator', 'clipping_limit',
|
||||
'use_gradient_accumulation', 'adam_parameters'
|
||||
])):
|
||||
"""Class to keep track of embedding config specification."""
|
||||
|
||||
def __new__(cls,
|
||||
feature_columns,
|
||||
learning_rate,
|
||||
optimizer_type='adagrad',
|
||||
adagrad_initial_accumulator=None,
|
||||
clipping_limit=None,
|
||||
use_gradient_accumulation=False,
|
||||
adam_parameters=None):
|
||||
"""Creates an EmbeddingConfigSpec instance.
|
||||
|
||||
Args:
|
||||
feature_columns: All `FeatureColumn`s used by model.
|
||||
learning_rate: embedding optimizer learning rate.
|
||||
optimizer_type: (String) Name of the optimizer for embedding gradient
|
||||
updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default
|
||||
value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam'
|
||||
(`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will
|
||||
be applied to all embedding variables specified by `feature_columns`.
|
||||
adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when
|
||||
optimizer_type is 'adagrad'. Default is `0.1`.
|
||||
clipping_limit: (Optional) Clipping limit (absolute value).
|
||||
use_gradient_accumulation: (Experimental) Whether to accumulate the
|
||||
gradients across TPU embedding mini-batches. Gradient accumulation does
|
||||
not affect SGD and therefore this is applicable only for Adagrad.
|
||||
adam_parameters: AdamParameters. Used when optimizer_type is 'adam'.
|
||||
Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon.
|
||||
|
||||
Returns:
|
||||
An EmbeddingConfigSpec instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the feature_columns are not specified.
|
||||
TypeError: If the feature columns are not of the correct type (one of
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or
|
||||
_EMBEDDING_COLUMN_CLASSES).
|
||||
ValueError: If use_gradient_accumulation is True for SGD.
|
||||
ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or
|
||||
"adam".
|
||||
"""
|
||||
if not feature_columns:
|
||||
raise ValueError('`feature_columns` cannot be `None` or empty.')
|
||||
|
||||
# It is unknown at this moment whether the TPUEstimator is running in CPU
# or TPU mode, so allow non-TPU embedding columns as well.
|
||||
supported_classes = tuple(
|
||||
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
|
||||
list(_EMBEDDING_COLUMN_CLASSES))
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, supported_classes):
|
||||
raise TypeError(
|
||||
'All feature columns must be supported types in {}. Got {}'.format(
|
||||
supported_classes, type(column)))
|
||||
|
||||
if optimizer_type == 'adagrad':
|
||||
if adagrad_initial_accumulator is None:
|
||||
adagrad_initial_accumulator = 0.1
|
||||
if adagrad_initial_accumulator <= 0:
|
||||
raise ValueError('Adagrad initial_accumulator must be positive')
|
||||
elif optimizer_type == 'sgd':
|
||||
if use_gradient_accumulation:
|
||||
raise ValueError('Gradient accumulation makes sense for Adagrad only.')
|
||||
elif optimizer_type == 'adam':
|
||||
if adam_parameters is None:
|
||||
adam_parameters = AdamParameters(0.9, 0.999, 1e-8)
|
||||
if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.:
|
||||
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(
|
||||
adam_parameters.beta1))
|
||||
if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.:
|
||||
raise ValueError('beta2 must be between 0. and 1; got {}.'.format(
|
||||
adam_parameters.beta2))
|
||||
if adam_parameters.epsilon <= 0.:
|
||||
raise ValueError('epsilon must be positive; got {}.'.format(
|
||||
adam_parameters.epsilon))
|
||||
else:
|
||||
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
|
||||
|
||||
return super(EmbeddingConfigSpec, cls).__new__(
|
||||
cls,
|
||||
feature_columns=feature_columns,
|
||||
learning_rate=learning_rate,
|
||||
optimizer_type=optimizer_type,
|
||||
adagrad_initial_accumulator=adagrad_initial_accumulator,
|
||||
clipping_limit=clipping_limit,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
adam_parameters=adam_parameters)
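# --- Illustrative usage sketch (editor addition, not part of the original file) ---
# A typical spec pairs TPU embedding feature columns with a learning rate and
# an optimizer; the hyperparameter values below are assumptions for
# illustration only, and the resulting spec is typically handed to
# TPUEstimator (via its embedding_config_spec argument in contrib at this
# point).
def _example_embedding_config_spec(tpu_embedding_feature_columns):
  """Hypothetical helper building an adagrad-based EmbeddingConfigSpec."""
  return EmbeddingConfigSpec(
      feature_columns=tpu_embedding_feature_columns,
      learning_rate=0.05,
      optimizer_type='adagrad',
      adagrad_initial_accumulator=0.1)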
|
||||
|
||||
|
||||
class EmbeddingConfig(object):
|
||||
"""This is the internal immutable object for embedding config.
|
||||
|
||||
`_EmbeddingConfig` is responsible for translating the user-provided
`EmbeddingConfigSpec` into internal data structures, mostly constructor
arguments of `TPUEmbedding`.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
|
||||
num_hosts, num_cores, master):
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._num_hosts = num_hosts
|
||||
self._num_cores = num_cores
|
||||
self._master = master
|
||||
|
||||
self._table_to_config_dict, self._feature_to_table_dict = (
|
||||
get_tpu_embedding_config_from_feature_columns(
|
||||
embedding_config_spec.feature_columns))
|
||||
self._optimization_parameters = _get_tpu_embedding_optimization_parameters(
|
||||
self._embedding_config_spec)
|
||||
self._mode_to_tpu_embedding_dict = {}
|
||||
self.dummy_table_variables = None
|
||||
|
||||
def has_embedding_tables(self):
|
||||
return bool(self._table_to_config_dict)
|
||||
|
||||
def _create_tpu_embedding(self, mode):
|
||||
"""Create tpu_embedding.TPUEmbedding based on mode."""
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
batch_size = self._train_batch_size
|
||||
else:
|
||||
batch_size = self._eval_batch_size
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
tpu_embedding_mode = tpu_embedding.TRAINING
|
||||
elif (mode == model_fn_lib.ModeKeys.EVAL or
|
||||
mode == model_fn_lib.ModeKeys.PREDICT):
|
||||
tpu_embedding_mode = tpu_embedding.INFERENCE
|
||||
else:
|
||||
raise ValueError('Mode {} is not supported.'.format(mode))
|
||||
|
||||
tpu_embedding_ = tpu_embedding.TPUEmbedding(
|
||||
self._table_to_config_dict,
|
||||
self._feature_to_table_dict,
|
||||
batch_size,
|
||||
tpu_embedding_mode,
|
||||
self._master,
|
||||
self._optimization_parameters,
|
||||
)
|
||||
return tpu_embedding_
|
||||
|
||||
def get_tpu_embedding(self, mode):
|
||||
if mode not in self._mode_to_tpu_embedding_dict:
|
||||
self._mode_to_tpu_embedding_dict[mode] = (
|
||||
self._create_tpu_embedding(mode))
|
||||
return self._mode_to_tpu_embedding_dict[mode]
|
||||
|
||||
|
||||
def split_inputs(ctx, features, labels):
|
||||
"""Splits the dense and sparse tensors inside the features and labels."""
|
||||
sparse_features = collections.OrderedDict()
|
||||
if ctx.embedding_config:
|
||||
tpu_embedding_ = ctx.embedding_config.tpu_embedding
|
||||
for feature_key in tpu_embedding_.feature_to_table_dict:
|
||||
sparse_features[feature_key] = features.pop(feature_key)
|
||||
|
||||
return features, labels, sparse_features
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu._tpu_estimator_embedding import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@@ -1,212 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Hook for asynchronous checkpointing.
|
||||
|
||||
This hook dispatches checkpoint writing operations in a separate thread to
|
||||
allow execution to continue on the main thread.
|
||||
"""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tensorflow.core.util.event_pb2 import SessionLog
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
||||
from tensorflow.python.training.summary_io import SummaryWriterCache
|
||||
|
||||
|
||||
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
"""Saves checkpoints every N steps or seconds."""
|
||||
|
||||
def __init__(self,
|
||||
checkpoint_dir,
|
||||
save_secs=None,
|
||||
save_steps=None,
|
||||
saver=None,
|
||||
checkpoint_basename="model.ckpt",
|
||||
scaffold=None,
|
||||
listeners=None):
|
||||
"""Initializes a `CheckpointSaverHook`.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: `str`, base directory for the checkpoint files.
|
||||
save_secs: `int`, save every N secs.
|
||||
save_steps: `int`, save every N steps.
|
||||
saver: `Saver` object, used for saving.
|
||||
checkpoint_basename: `str`, base name for the checkpoint files.
|
||||
scaffold: `Scaffold`, use to get saver object.
|
||||
listeners: List of `CheckpointSaverListener` subclass instances. Used for
|
||||
callbacks that run immediately before or after this hook saves the
|
||||
checkpoint.
|
||||
|
||||
Raises:
|
||||
ValueError: One of `save_steps` or `save_secs` should be set.
|
||||
ValueError: At most one of `saver` or `scaffold` should be set.
|
||||
"""
|
||||
logging.info("Create AsyncCheckpointSaverHook.")
|
||||
if saver is not None and scaffold is not None:
|
||||
raise ValueError("You cannot provide both saver and scaffold.")
|
||||
self._saver = saver
|
||||
self._save_thread = None
|
||||
self._write_graph_thread = None
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
|
||||
self._scaffold = scaffold
|
||||
self._timer = basic_session_run_hooks.SecondOrStepTimer(
|
||||
every_secs=save_secs, every_steps=save_steps)
|
||||
self._listeners = listeners or []
|
||||
self._steps_per_run = 1
|
||||
self._summary_writer = None
|
||||
self._global_step_tensor = None
|
||||
|
||||
self._last_checkpoint_step = None
|
||||
|
||||
def _set_steps_per_run(self, steps_per_run):
|
||||
self._steps_per_run = steps_per_run
|
||||
|
||||
def begin(self):
|
||||
self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
|
||||
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
|
||||
if self._global_step_tensor is None:
|
||||
raise RuntimeError(
|
||||
"Global step should be created to use CheckpointSaverHook.")
|
||||
for l in self._listeners:
|
||||
l.begin()
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
global_step = session.run(self._global_step_tensor)
|
||||
|
||||
# We write the graph and saver_def on the first call of before_run.
# We cannot do this in begin, since we let other hooks change the graph and
# add variables in begin. The graph is finalized after all begin calls.
|
||||
def _write_graph_fn(self):
|
||||
training_util.write_graph(
|
||||
ops.get_default_graph().as_graph_def(add_shapes=True),
|
||||
self._checkpoint_dir, "graph.pbtxt")
|
||||
self._write_graph_thread = threading.Thread(target=_write_graph_fn,
|
||||
args=[self])
|
||||
self._write_graph_thread.start()
|
||||
|
||||
saver_def = self._get_saver().saver_def if self._get_saver() else None
|
||||
graph = ops.get_default_graph()
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
|
||||
self._summary_writer.add_graph(graph)
|
||||
self._summary_writer.add_meta_graph(meta_graph_def)
|
||||
# The checkpoint saved here is the state at step "global_step".
|
||||
self._save(session, global_step)
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
|
||||
def before_run(self, run_context): # pylint: disable=unused-argument
|
||||
return SessionRunArgs(self._global_step_tensor)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
if self._timer.should_trigger_for_step(global_step):
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
logging.info("Triggering checkpoint. %s", global_step)
|
||||
if self._save(run_context.session, global_step):
|
||||
run_context.request_stop()
|
||||
|
||||
def end(self, session):
|
||||
if self._save_thread:
|
||||
logging.info("Waiting for any pending checkpoints to finish.")
|
||||
self._save_thread.join()
|
||||
if self._write_graph_thread:
|
||||
logging.info("Waiting for any pending write_graph to finish.")
|
||||
self._write_graph_thread.join()
|
||||
|
||||
last_step = session.run(self._global_step_tensor)
|
||||
|
||||
if self._last_checkpoint_step != last_step:
|
||||
self._save(session, last_step, asynchronous=False)
|
||||
|
||||
for l in self._listeners:
|
||||
l.end(session, last_step)
|
||||
|
||||
def _save(self, session, step, asynchronous=True):
|
||||
"""Saves the latest checkpoint, returns should_stop."""
|
||||
|
||||
# Skip saving on step 0
|
||||
if step == 0:
|
||||
return
|
||||
|
||||
def _save_fn():
|
||||
"""Run the saver process."""
|
||||
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
|
||||
|
||||
start_time = time.time()
|
||||
for l in self._listeners:
|
||||
l.before_save(session, step)
|
||||
|
||||
self._get_saver().save(session, self._save_path, global_step=step)
|
||||
self._summary_writer.add_session_log(
|
||||
SessionLog(
|
||||
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
|
||||
step)
|
||||
|
||||
for l in self._listeners:
|
||||
l.after_save(session, step)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info("Checkpoint actual writing time: (%.3f sec)",
|
||||
end_time - start_time)
|
||||
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
|
||||
|
||||
if not asynchronous:
|
||||
self._last_checkpoint_step = step
|
||||
_save_fn()
|
||||
return
|
||||
|
||||
if self._save_thread is not None:
|
||||
self._save_thread.join(timeout=0.1)
|
||||
if self._save_thread.is_alive():
|
||||
logging.info("Saver thread still in progress, skipping checkpoint.")
|
||||
return
|
||||
|
||||
self._last_checkpoint_step = step
|
||||
self._save_thread = threading.Thread(target=_save_fn)
|
||||
self._save_thread.start()
|
||||
|
||||
def _get_saver(self):
|
||||
if self._saver is not None:
|
||||
return self._saver
|
||||
elif self._scaffold is not None:
|
||||
return self._scaffold.saver
|
||||
|
||||
# Get saver from the SAVERS collection if present.
|
||||
collection_key = ops.GraphKeys.SAVERS
|
||||
savers = ops.get_collection(collection_key)
|
||||
if not savers:
|
||||
raise RuntimeError(
|
||||
"No items in collection {}. Please add a saver to the collection "
|
||||
"or provide a saver or scaffold.".format(collection_key))
|
||||
elif len(savers) > 1:
|
||||
raise RuntimeError(
|
||||
"More than one item in collection {}. "
|
||||
"Please indicate which one to use by passing it to the constructor."
|
||||
.format(collection_key))
|
||||
|
||||
self._saver = savers[0]
|
||||
return savers[0]
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.async_checkpoint import *
|
||||
# pylint: enable=wildcard-import,unused-import
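Note: the stub above keeps the old contrib import path working by re-exporting the implementation that now lives under tensorflow/python/tpu. A minimal, illustrative sketch of a call site that keeps working unchanged (assuming the hook class exported by that module is AsyncCheckpointSaverHook, whose methods are shown above; the directory and step count are made up):

from tensorflow.contrib.tpu.python.tpu import async_checkpoint

# Saves checkpoints on a background thread so the training loop is not blocked.
hook = async_checkpoint.AsyncCheckpointSaverHook(
    checkpoint_dir='/tmp/model_dir',  # illustrative path
    save_steps=1000)
# e.g. estimator.train(input_fn, hooks=[hook])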
@ -1,77 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper context for running models with bfloat16."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
def _get_custom_getter():
|
||||
"""Returns a custom getter that this class's methods must be called under.
|
||||
|
||||
All methods of this class must be called under a variable scope that was
|
||||
passed this custom getter. Example:
|
||||
|
||||
```python
|
||||
network = ConvNetBuilder(...)
|
||||
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
|
||||
network.conv(...)
|
||||
# Call more methods of network here
|
||||
```
|
||||
|
||||
Currently, this custom getter only does anything if self.use_tf_layers is
|
||||
True. In that case, it causes variables to be stored as dtype
|
||||
self.variable_type, then cast to the requested dtype, instead of directly
|
||||
storing the variable as the requested dtype.
|
||||
"""
|
||||
|
||||
def inner_custom_getter(getter, *args, **kwargs):
|
||||
"""Custom getter that forces variables to have type self.variable_type."""
|
||||
cast_to_bfloat16 = False
|
||||
requested_dtype = kwargs['dtype']
|
||||
if requested_dtype == dtypes.bfloat16:
|
||||
# Only change the variable dtype if doing so does not decrease variable
|
||||
# precision.
|
||||
kwargs['dtype'] = dtypes.float32
|
||||
cast_to_bfloat16 = True
|
||||
var = getter(*args, **kwargs)
|
||||
# This if statement is needed to guard the cast, because batch norm
|
||||
# assigns directly to the return value of this custom getter. The cast
|
||||
# makes the return value not a variable so it cannot be assigned. Batch
|
||||
# norm variables are always in fp32 so this if statement is never
|
||||
# triggered for them.
|
||||
if cast_to_bfloat16:
|
||||
var = math_ops.cast(var, dtypes.bfloat16)
|
||||
return var
|
||||
|
||||
return inner_custom_getter
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def bfloat16_scope():
|
||||
"""Scope class for bfloat16 variables so that the model uses custom getter.
|
||||
|
||||
This enables variables to be read as bfloat16 type when using get_variable.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
'', custom_getter=_get_custom_getter()) as varscope:
|
||||
yield varscope
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.bfloat16 import *
|
||||
# pylint: enable=wildcard-import,unused-import
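A minimal usage sketch of bfloat16_scope as defined above (the variable name and shape are illustrative). Inside the scope, a variable requested as bfloat16 is stored as float32 and cast on read, so stored precision is not reduced:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu.bfloat16 import bfloat16_scope

with bfloat16_scope():
  # The custom getter stores this variable as float32 and returns a
  # bfloat16 cast of it.
  w = variable_scope.get_variable('w', shape=[128, 128], dtype=dtypes.bfloat16)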
@ -1,191 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Library of Cloud TPU helper functions for data loading."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
|
||||
|
||||
def _TextLineDataset(filename):
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
|
||||
|
||||
def _TFRecordDataset(filename):
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
|
||||
|
||||
_FILETYPE_MAP = {
|
||||
'tfrecord': _TFRecordDataset,
|
||||
'textline': _TextLineDataset,
|
||||
'text': _TextLineDataset,
|
||||
}
|
||||
|
||||
|
||||
def StreamingFilesDataset(files,
|
||||
filetype=None,
|
||||
file_reader_job=None,
|
||||
worker_job=None,
|
||||
num_epochs=None,
|
||||
filename_shuffle_buffer_size=None,
|
||||
num_parallel_reads=None,
|
||||
batch_transfer_size=None,
|
||||
sloppy=None):
|
||||
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
|
||||
|
||||
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
|
||||
files local to your GCE VM. In order to train using files stored on your local
|
||||
VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
|
||||
helper to generate a dataset to feed your Cloud TPU with files from your GCE
|
||||
VM.
|
||||
|
||||
The resulting dataset may return an OutOfRangeError if there are no files
|
||||
found as a result of the fileglob expansion.
|
||||
|
||||
Note: StreamingFilesDataset assumes that the session is using a
|
||||
TPUClusterResolver and therefore has a worker and a coordinator job. File
|
||||
loading will be done on the coordinator job.
|
||||
|
||||
Args:
|
||||
files: A string glob to match files, or a `tf.data.Dataset` generating file
|
||||
names.
|
||||
filetype: A string (one of 'tfrecord', or 'textline') or a single-argument
|
||||
TensorFlow function that when given a filename returns a dataset.
|
||||
file_reader_job: An optional string that corresponds to the job that should
|
||||
perform the file reads.
|
||||
worker_job: An optional string that corresponds to the job that should
|
||||
process the tensors (i.e. your GPU or TPU worker).
|
||||
num_epochs: The number of epochs through the training set that should be
|
||||
generated. By default, it will repeat infinitely.
|
||||
filename_shuffle_buffer_size: An optional integer whose value controls the
|
||||
shuffling of the file names. If you would like to read from the files in
|
||||
the same order, set to 0 or False.
|
||||
num_parallel_reads: An optional integer controlling the number of files to
|
||||
read from concurrently. (Set to 1 for no parallelism.)
|
||||
batch_transfer_size: An optional integer controlling the batching used to
|
||||
amortize the remote function invocation overhead. Set to a very large
|
||||
number to increase throughput. Set to a very small number to reduce memory
|
||||
consumption. Set to False to skip batching.
|
||||
sloppy: (Optional.) If `False`, read input data while maintaining a
|
||||
deterministic order. (This may have significant performance impacts.)
|
||||
Defaults to `True`.
|
||||
Returns:
|
||||
A `tf.data.Dataset` with an infinite stream of elements generated by a
|
||||
parallel interleaving of the set of files matched (or generated) by `files`
|
||||
whose element type matches that of the dataset specified by `filetype`.
|
||||
|
||||
Raises:
|
||||
ValueError: if any argument is not of the expected type.
|
||||
"""
|
||||
if filetype is None:
|
||||
filetype = 'tfrecord'
|
||||
|
||||
if isinstance(filetype, str):
|
||||
if filetype not in _FILETYPE_MAP:
|
||||
raise ValueError('Unexpected filetype: %s' % filetype)
|
||||
reader_fn = _FILETYPE_MAP[filetype]
|
||||
elif callable(filetype):
|
||||
reader_fn = filetype
|
||||
else:
|
||||
raise ValueError('filetype should be a string or a callable')
|
||||
|
||||
file_reader_job = file_reader_job or 'coordinator'
|
||||
|
||||
worker_job = worker_job or 'worker'
|
||||
|
||||
if filename_shuffle_buffer_size is None:
|
||||
filename_shuffle_buffer_size = 4096
|
||||
|
||||
num_parallel_reads = num_parallel_reads or 8
|
||||
|
||||
if batch_transfer_size is None:
|
||||
batch_transfer_size = 256
|
||||
|
||||
if sloppy is None:
|
||||
sloppy = True
|
||||
|
||||
with ops.device('/job:%s' % file_reader_job):
|
||||
if isinstance(files, str):
|
||||
source_dataset = dataset_ops.Dataset.list_files(files)
|
||||
elif isinstance(files, dataset_ops.DatasetV2):
|
||||
source_dataset = files
|
||||
else:
|
||||
raise ValueError('files was not a string or a dataset: %s' % files)
|
||||
|
||||
if filename_shuffle_buffer_size:
|
||||
source_dataset = source_dataset.shuffle(
|
||||
buffer_size=filename_shuffle_buffer_size)
|
||||
|
||||
source_dataset = source_dataset.apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
|
||||
|
||||
source_dataset = source_dataset.repeat(num_epochs)
|
||||
|
||||
if batch_transfer_size:
|
||||
source_dataset = source_dataset.batch(batch_transfer_size)
|
||||
|
||||
source_dataset = source_dataset.prefetch(1)
|
||||
|
||||
source_iterator = dataset_ops.make_one_shot_iterator(source_dataset)
|
||||
source_handle = source_iterator.string_handle()
|
||||
|
||||
@function.Defun(dtypes.string)
|
||||
def LoadingFunc(h):
|
||||
remote_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
h, source_dataset.output_types, source_dataset.output_shapes)
|
||||
return remote_iterator.get_next()
|
||||
|
||||
def MapFn(unused_input):
|
||||
if isinstance(source_dataset.output_types, dtypes.DType):
|
||||
output_types = [source_dataset.output_types]
|
||||
elif isinstance(source_dataset.output_types, (list, tuple)):
|
||||
output_types = source_dataset.output_types
|
||||
else:
|
||||
raise ValueError('source dataset has invalid output types')
|
||||
remote_calls = functional_ops.remote_call(
|
||||
args=[source_handle],
|
||||
Tout=output_types,
|
||||
f=LoadingFunc,
|
||||
target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
|
||||
if len(remote_calls) == 1:
|
||||
return remote_calls[0]
|
||||
else:
|
||||
return remote_calls
|
||||
|
||||
with ops.device('/job:%s' % worker_job):
|
||||
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
|
||||
MapFn, num_parallel_calls=4 if sloppy else None)
|
||||
output_dataset = output_dataset.prefetch(1)
|
||||
|
||||
if batch_transfer_size:
|
||||
# Undo the batching used during the transfer.
|
||||
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
|
||||
|
||||
return output_dataset
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.datasets import *
|
||||
# pylint: enable=wildcard-import,unused-import
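A minimal sketch of wiring up StreamingFilesDataset as documented above (the file glob is illustrative; file_reader_job and worker_job fall back to the 'coordinator' and 'worker' defaults shown in the code):

from tensorflow.python.tpu.datasets import StreamingFilesDataset

# Stream TFRecord files that live on the coordinator VM out to the TPU workers.
dataset = StreamingFilesDataset(
    files='/data/train-*.tfrecord',  # illustrative glob, expanded on the coordinator
    filetype='tfrecord')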
@ -1,313 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Library of TPU helper functions."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu.topology import Topology
|
||||
|
||||
|
||||
SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]]
|
||||
|
||||
|
||||
def _compute_task_and_cores_to_replicas(core_assignment, topology):
|
||||
"""Computes a nested dict which maps task and logical core to replicas."""
|
||||
task_and_cores_to_replicas = {}
|
||||
for replica in xrange(core_assignment.shape[0]):
|
||||
for logical_core in xrange(core_assignment.shape[1]):
|
||||
coordinates = core_assignment[replica, logical_core, :]
|
||||
task_id = topology.task_ordinal_at_coordinates(coordinates)
|
||||
if task_id not in task_and_cores_to_replicas:
|
||||
task_and_cores_to_replicas[task_id] = {}
|
||||
if logical_core not in task_and_cores_to_replicas[task_id]:
|
||||
task_and_cores_to_replicas[task_id][logical_core] = set()
|
||||
|
||||
task_and_cores_to_replicas[task_id][logical_core].add(replica)
|
||||
|
||||
task_to_sorted_replica_id = {}
|
||||
|
||||
for task, core_to_replicas in task_and_cores_to_replicas.items():
|
||||
core_to_sorted_replicas = {}
|
||||
for core, replicas in core_to_replicas.items():
|
||||
core_to_sorted_replicas[core] = sorted(replicas)
|
||||
|
||||
task_to_sorted_replica_id[task] = core_to_sorted_replicas
|
||||
return task_to_sorted_replica_id
|
||||
|
||||
|
||||
class DeviceAssignment(object):
|
||||
"""Mapping from logical cores in a computation to the physical TPU topology.
|
||||
|
||||
Prefer to use the `device_assignment()` helper to construct a
|
||||
`DeviceAssignment`; it is easier, if less flexible, than constructing a
|
||||
`DeviceAssignment` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, topology, core_assignment):
|
||||
"""Constructs a `DeviceAssignment` object.
|
||||
|
||||
Args:
|
||||
topology: A `Topology` object that describes the physical TPU topology.
|
||||
core_assignment: A logical to physical core mapping, represented as a
|
||||
rank 3 numpy array. See the description of the `core_assignment`
|
||||
property for more details.
|
||||
|
||||
Raises:
|
||||
ValueError: If `topology` is not a `Topology` object.
|
||||
ValueError: If `core_assignment` is not a rank 3 numpy array.
|
||||
"""
|
||||
if not isinstance(topology, Topology):
|
||||
raise ValueError("topology must be a Topology object, got {}".format(
|
||||
type(topology)))
|
||||
core_assignment = np.asarray(core_assignment, dtype=np.int32)
|
||||
|
||||
self._topology = topology
|
||||
|
||||
if core_assignment.ndim != 3:
|
||||
raise ValueError("core_assignment must be a rank 3 numpy array, "
|
||||
"got shape {}".format(core_assignment.shape))
|
||||
|
||||
self._num_replicas = core_assignment.shape[0]
|
||||
self._num_cores_per_replica = core_assignment.shape[1]
|
||||
|
||||
if core_assignment.shape[-1] != topology.mesh_rank:
|
||||
raise ValueError(
|
||||
"minor dimension of core_assignment must have size equal to topology "
|
||||
"rank ({}), got shape {}".format(topology.mesh_rank,
|
||||
core_assignment.shape))
|
||||
|
||||
self._core_assignment = core_assignment
|
||||
self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
|
||||
self._core_assignment, topology)
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
"""A `Topology` that describes the TPU topology."""
|
||||
return self._topology
|
||||
|
||||
@property
|
||||
def num_cores_per_replica(self):
|
||||
"""The number of cores per replica."""
|
||||
return self._num_cores_per_replica
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
"""The number of replicas of the computation."""
|
||||
return self._num_replicas
|
||||
|
||||
@property
|
||||
def core_assignment(self):
|
||||
"""The logical to physical core mapping.
|
||||
|
||||
Returns:
|
||||
An integer numpy array of rank 3, with shape
|
||||
`[num_replicas, num_cores_per_replica, topology_rank]`. Maps
|
||||
(replica, logical core) pairs to physical topology coordinates.
|
||||
"""
|
||||
return self._core_assignment
|
||||
|
||||
def coordinates(self, replica, logical_core):
|
||||
"""Returns the physical topology coordinates of a logical core."""
|
||||
return tuple(self.core_assignment[replica, logical_core, :])
|
||||
|
||||
def lookup_replicas(self, task_id, logical_core):
|
||||
"""Lookup replica ids by task number and logical core.
|
||||
|
||||
Args:
|
||||
task_id: TensorFlow task number.
|
||||
logical_core: An integer, identifying a logical core.
|
||||
Returns:
|
||||
A sorted list of the replicas that are attached to that task and
|
||||
logical_core.
|
||||
Raises:
|
||||
ValueError: If no replica exists in the task which contains the logical
|
||||
core.
|
||||
"""
|
||||
try:
|
||||
return self._task_and_cores_to_replicas[task_id][logical_core]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Can not find any replica in task: {} contains logical_core: {} ".
|
||||
format(task_id, logical_core))
|
||||
|
||||
def tpu_ordinal(self, replica=0, logical_core=0):
|
||||
"""Returns the ordinal of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
|
||||
|
||||
def host_device(self, replica=0, logical_core=0, job=None):
|
||||
"""Returns the CPU device attached to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
def tpu_device(self, replica=0, logical_core=0, job=None):
|
||||
"""Returns the name of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
|
||||
def device_assignment(topology,
|
||||
computation_shape=None,
|
||||
computation_stride=None,
|
||||
num_replicas=1):
|
||||
"""Computes a device_assignment of a computation across a TPU topology.
|
||||
|
||||
Attempts to choose a compact grid of cores for locality.
|
||||
|
||||
Returns a `DeviceAssignment` that describes the cores in the topology assigned
|
||||
to each core of each replica.
|
||||
|
||||
`computation_shape` and `computation_stride` values should be powers of 2 for
|
||||
optimal packing.
|
||||
|
||||
Args:
|
||||
topology: A `Topology` object that describes the TPU cluster topology.
|
||||
To obtain a TPU topology, evaluate the `Tensor` returned by
|
||||
`initialize_system` using `Session.run`. Either a serialized
|
||||
`TopologyProto` or a `Topology` object may be passed. Note: you must
|
||||
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
|
||||
computation_shape: A rank 1 int32 numpy array with size equal to the
|
||||
topology rank, describing the shape of the computation's block of cores.
|
||||
If None, the `computation_shape` is `[1] * topology_rank`.
|
||||
computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
|
||||
describing the inter-core spacing of the `computation_shape` cores in the
|
||||
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
|
||||
num_replicas: The number of computation replicas to run. The replicas will
|
||||
be packed into the free spaces of the topology.
|
||||
|
||||
Returns:
|
||||
A DeviceAssignment object, which describes the mapping between the logical
|
||||
cores in each computation replica and the physical cores in the TPU
|
||||
topology.
|
||||
|
||||
Raises:
|
||||
ValueError: If `topology` is not a valid `Topology` object.
|
||||
ValueError: If `computation_shape` or `computation_stride` are not 1D int32
|
||||
numpy arrays with shape [3] where all values are positive.
|
||||
ValueError: If computation's replicas cannot fit into the TPU topology.
|
||||
"""
|
||||
# Deserialize the Topology proto, if it is a string.
|
||||
if isinstance(topology, bytes):
|
||||
topology = Topology(serialized=topology)
|
||||
|
||||
if not isinstance(topology, Topology):
|
||||
raise ValueError("`topology` is not a Topology object; got {}".format(
|
||||
type(topology)))
|
||||
|
||||
topology_rank = len(topology.mesh_shape)
|
||||
mesh_shape = topology.mesh_shape
|
||||
if computation_shape is None:
|
||||
computation_shape = np.array([1] * topology_rank, dtype=np.int32)
|
||||
else:
|
||||
computation_shape = np.asarray(computation_shape, dtype=np.int32)
|
||||
|
||||
if computation_stride is None:
|
||||
computation_stride = np.array([1] * topology_rank, dtype=np.int32)
|
||||
else:
|
||||
computation_stride = np.asarray(computation_stride, dtype=np.int32)
|
||||
|
||||
if computation_shape.shape != (topology_rank,):
|
||||
raise ValueError("computation_shape must have shape [{}]; got {}".format(
|
||||
topology_rank, computation_shape.shape))
|
||||
if computation_stride.shape != (topology_rank,):
|
||||
raise ValueError("computation_stride must have shape [{}]; got {}".format(
|
||||
topology_rank, computation_stride.shape))
|
||||
|
||||
if any(computation_shape < 1):
|
||||
raise ValueError(
|
||||
"computation_shape must be positive; got computation_shape={}".format(
|
||||
computation_shape))
|
||||
if any(computation_stride < 1):
|
||||
raise ValueError(
|
||||
"computation_stride must be positive; got computation_stride={}".format(
|
||||
computation_stride))
|
||||
|
||||
# Computes the physical size of one computation instance.
|
||||
computation_footprint = computation_shape * computation_stride
|
||||
if any(computation_footprint > mesh_shape):
|
||||
raise ValueError(
|
||||
"computation footprint {} does not fit in TPU topology shape {}".format(
|
||||
computation_footprint, mesh_shape))
|
||||
|
||||
# Computes how many copies of the computation footprint fit in the mesh.
|
||||
block_counts = mesh_shape // computation_footprint
|
||||
|
||||
replica_counts = block_counts * computation_stride
|
||||
max_replicas = np.prod(replica_counts)
|
||||
if num_replicas > max_replicas:
|
||||
raise ValueError(
|
||||
"requested {} replicas but only {} replicas with shape {} and "
|
||||
"computation_stride {} fit in a TPU mesh of shape {}".format(
|
||||
num_replicas, max_replicas, computation_shape, computation_stride,
|
||||
mesh_shape))
|
||||
|
||||
def ceil_of_ratio(n, m):
|
||||
return (n + m - 1) // m
|
||||
|
||||
replica_shape = [0] * topology_rank
|
||||
if num_replicas > 0:
|
||||
remaining_replicas = num_replicas
|
||||
remaining_dims = topology_rank
|
||||
|
||||
# Choose dimensions as close to an equal cube as possible, in order of
|
||||
# increasing dimension size. By visiting dimensions in increasing size, we
|
||||
# assign the most constrained dimension first, so we won't make infeasible
|
||||
# choices.
|
||||
#
|
||||
# As a secondary sort order, visit the dimensions in reverse order. This
|
||||
# means we try to use both cores on the same chip in preference to two cores
|
||||
# on different chips.
|
||||
for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
|
||||
i = -ni
|
||||
target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
|
||||
replica_shape[i] = min(target_size, x)
|
||||
remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
|
||||
remaining_dims -= 1
|
||||
|
||||
assert remaining_replicas == 1 and remaining_dims == 0
|
||||
|
||||
# Assigns an offset to each replica such that no two replicas overlap.
|
||||
replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
|
||||
for replica in xrange(num_replicas):
|
||||
# Chooses a replica number in each axis.
|
||||
t = replica
|
||||
pos = []
|
||||
for dim in replica_shape[::-1]:
|
||||
pos.append(t % dim)
|
||||
t //= dim
|
||||
replica_pos = np.array(pos[::-1], dtype=np.int32)
|
||||
|
||||
# Determines where that replica starts in each axis.
|
||||
outer = replica_pos // computation_stride
|
||||
inner = replica_pos % computation_stride
|
||||
replica_offsets[replica, :] = outer * computation_footprint + inner
|
||||
|
||||
# Computes a complete logical core -> physical core mapping for each replica.
|
||||
indices = [
|
||||
np.arange(0, computation_shape[i] * computation_stride[i],
|
||||
computation_stride[i]) for i in xrange(topology_rank)
|
||||
]
|
||||
indices = np.concatenate(
|
||||
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
|
||||
axis=-1)
|
||||
indices = indices.reshape((-1, topology_rank))
|
||||
assignment = indices + replica_offsets[:, np.newaxis, :]
|
||||
return DeviceAssignment(topology, core_assignment=assignment)
|
||||
# pylint: disable=wildcard-import,unused-import,redefined-builtin
|
||||
from tensorflow.python.tpu.device_assignment import *
|
||||
# pylint: enable=wildcard-import,unused-import,redefined-builtin
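A minimal sketch of building a DeviceAssignment per the docstring above. The TPU address is illustrative, and the sketch assumes initialize_system() from the moved tpu module and a rank-3 mesh; the serialized TopologyProto returned by Session.run is accepted directly by device_assignment():

from tensorflow.python.client import session as session_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.device_assignment import device_assignment

with session_lib.Session('grpc://10.0.0.2:8470') as sess:  # illustrative address
  serialized_topology = sess.run(tpu.initialize_system())
  da = device_assignment(
      serialized_topology,
      computation_shape=[1, 1, 2],  # both cores of one chip on a rank-3 mesh
      num_replicas=4)
  print(da.num_replicas, da.tpu_device(replica=0, logical_core=0))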
@ -1,132 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""ErrorRendezvous handler for collecting errors from multiple threads."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_UNINTERESTING_ERRORS = (errors.CancelledError,)
|
||||
|
||||
|
||||
class ErrorRendezvous(object):
|
||||
"""Resolve errors from multiple threads during TPU execution.
|
||||
|
||||
TPU errors can occur on the infeed or outfeed threads as well as the main
|
||||
training thread.
|
||||
|
||||
Depending on which thread "wins" and receives the session error first, we may
|
||||
end up showing users a confusing and non-actionable error message (session
|
||||
cancelled) instead of a root cause (e.g. a bad filename).
|
||||
|
||||
The rendezvous object provides a location to capture these errors until all
|
||||
threads terminate. At that point we can choose the most informative error
|
||||
to report.
|
||||
"""
|
||||
|
||||
def __init__(self, num_sources):
|
||||
# string -> (message, traceback)
|
||||
self._errors = {}
|
||||
self._num_sources = num_sources
|
||||
self._session_cancel_timer = None
|
||||
|
||||
def record_error(self, source, exc_info, session=None):
|
||||
"""Report an exception from the given source.
|
||||
|
||||
If a session is passed, a timer will be registered to close it after a few
|
||||
seconds. This is necessary to ensure the main training loop does not hang
|
||||
if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
|
||||
interesting error from another thread to propagate.
|
||||
|
||||
Args:
|
||||
source: string, source of the error
|
||||
exc_info: Output from `sys.exc_info` (type, value, traceback)
|
||||
session: Session to close after delay.
|
||||
"""
|
||||
_, value, _ = exc_info
|
||||
self._errors[source] = exc_info
|
||||
logging.info('Error recorded from %s: %s', source, value)
|
||||
|
||||
if session is not None and self._session_cancel_timer is None:
|
||||
|
||||
def _cancel_session():
|
||||
time.sleep(5)
|
||||
try:
|
||||
session.close()
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
|
||||
self._session_cancel_timer.daemon = True
|
||||
self._session_cancel_timer.start()
|
||||
|
||||
def record_done(self, source):
|
||||
"""Mark execution source `source` as done.
|
||||
|
||||
If an error was originally reported from `source` it is left intact.
|
||||
|
||||
Args:
|
||||
source: `str`, source being recorded
|
||||
"""
|
||||
logging.info('%s marked as finished', source)
|
||||
if source not in self._errors:
|
||||
self._errors[source] = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def catch_errors(self, source, session=None):
|
||||
"""Context manager to report any errors within a block."""
|
||||
try:
|
||||
yield
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.record_error(source, sys.exc_info(), session)
|
||||
|
||||
def raise_errors(self, timeout_sec=0):
|
||||
"""Wait for up to `timeout` seconds for all error sources to finish.
|
||||
|
||||
Preferentially raise "interesting" errors (errors not in the
_UNINTERESTING_ERRORS set).
|
||||
|
||||
Args:
|
||||
timeout_sec: Seconds to wait for other error sources.
|
||||
"""
|
||||
for _ in range(timeout_sec):
|
||||
if len(self._errors) == self._num_sources:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
|
||||
|
||||
# First check for any interesting errors, then fall back on the session
|
||||
# cancelled errors etc.
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
if isinstance(value, _UNINTERESTING_ERRORS):
|
||||
continue
|
||||
else:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
||||
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.error_handling import *
|
||||
# pylint: enable=wildcard-import,unused-import
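A minimal sketch of the rendezvous pattern described above, with an infeed thread and the main training thread reporting into a shared ErrorRendezvous (the thread bodies are placeholders):

import threading
from tensorflow.python.tpu.error_handling import ErrorRendezvous

rendezvous = ErrorRendezvous(num_sources=2)

def infeed_fn():
  with rendezvous.catch_errors(source='infeed'):
    pass  # infeed enqueue loop would run here
  rendezvous.record_done('infeed')

infeed_thread = threading.Thread(target=infeed_fn)
infeed_thread.start()

with rendezvous.catch_errors(source='training'):
  pass  # session.run(train_op) loop would run here
rendezvous.record_done('training')

infeed_thread.join()
rendezvous.raise_errors(timeout_sec=5)  # re-raises the most informative error, if any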
@ -1,435 +1,30 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""TPU Feature Column Library."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.feature_column import feature_column as fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as fc_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
|
||||
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
|
||||
fc._VocabularyFileCategoricalColumn,
|
||||
fc._VocabularyListCategoricalColumn,
|
||||
fc._WeightedCategoricalColumn,
|
||||
fc_lib.IdentityCategoricalColumn,
|
||||
fc_lib.VocabularyFileCategoricalColumn,
|
||||
fc_lib.VocabularyListCategoricalColumn,
|
||||
fc_lib.WeightedCategoricalColumn)
|
||||
|
||||
|
||||
def embedding_column(categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None):
|
||||
"""TPU embedding_column for `tf.feature_column.embedding_column`.
|
||||
|
||||
Note that the interface for TPU embedding_column is different from the non-TPU
|
||||
version. The following args available for the non-TPU version are NOT
|
||||
supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.
|
||||
|
||||
Args:
|
||||
categorical_column: A categorical_column returned from
|
||||
categorical_column_with_identity, weighted_categorical_column,
|
||||
categorical_column_with_vocabulary_list or
|
||||
categorical_column_with_vocabulary_file.
|
||||
dimension: An integer specifying dimension of the embedding, must be > 0.
|
||||
combiner: A string specifying how to reduce if there are multiple entries
|
||||
in a single row. For more information, see
|
||||
`tf.feature_column.embedding_column`.
|
||||
initializer: A variable initializer function to be used in embedding
|
||||
variable initialization. If not specified, defaults to
|
||||
`tf.truncated_normal_initializer` with mean `0.0` and standard deviation
|
||||
`1/sqrt(dimension)`.
|
||||
|
||||
Returns:
|
||||
A _TPUEmbeddingColumn.
|
||||
|
||||
Raises:
|
||||
ValueError: if `dimension` not > 0.
|
||||
ValueError: if `initializer` is specified but not callable.
|
||||
"""
|
||||
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
|
||||
raise TypeError(
|
||||
'categorical_column for tpu '
|
||||
'embedding_column must be type %s, got %s.' % (' or '.join([
|
||||
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
|
||||
]), type(categorical_column)))
|
||||
if (dimension is None) or (dimension < 1):
|
||||
raise ValueError('Invalid dimension {}.'.format(dimension))
|
||||
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified. '
|
||||
'Embedding of column_name: {}'.format(
|
||||
categorical_column.name))
|
||||
if initializer is None:
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||
|
||||
embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
|
||||
|
||||
def _creator(weight_collections, scope):
|
||||
embedding_column_layer = fc._EmbeddingColumnLayer(
|
||||
embedding_shape=embedding_shape,
|
||||
initializer=initializer,
|
||||
weight_collections=weight_collections,
|
||||
trainable=True,
|
||||
name='embedding_column_layer')
|
||||
return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
|
||||
|
||||
column = _TPUEmbeddingColumn(
|
||||
categorical_column=categorical_column,
|
||||
dimension=dimension,
|
||||
combiner=combiner,
|
||||
layer_creator=_creator,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
# For Embedding column, the initializer is hidden inside the creator Fn, which
|
||||
# is not accessible later. So, we attach it to a special field. Also note
|
||||
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
|
||||
# initializer differently. See shared_embedding_columns for details.
|
||||
column._tpu_initializer = initializer
|
||||
return column
|
||||
|
||||
|
||||
def shared_embedding_columns(categorical_columns,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None):
|
||||
"""List of dense columns that convert from sparse, categorical input."""
|
||||
for categorical_column in categorical_columns:
|
||||
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
|
||||
raise TypeError(
|
||||
'categorical_column for tpu '
|
||||
'shared_embedding_columns must be type %s, got %s.' % (' or '.join([
|
||||
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
|
||||
]), type(categorical_column)))
|
||||
columns = fc_lib.shared_embedding_columns(
|
||||
categorical_columns,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
|
||||
# Use the initializer and shared_embedding_collection_name to create TPU
|
||||
# version
|
||||
initializer = columns[0].initializer
|
||||
shared_embedding_collection_name = columns[0].shared_embedding_collection_name
|
||||
tpu_columns = []
|
||||
|
||||
# Create the state (_SharedEmbeddingColumnLayer) here.
|
||||
for categorical_column in categorical_columns:
|
||||
column = _TPUSharedEmbeddingColumn(
|
||||
categorical_column=categorical_column,
|
||||
dimension=dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
tpu_columns.append(column)
|
||||
|
||||
return tpu_columns
|
||||
|
||||
|
||||
class _TPUBaseEmbeddingColumn(object):
|
||||
"""Base class for TPU Embedding Column."""
|
||||
|
||||
def __init__(self, categorical_column):
|
||||
self._tpu_categorical_column = categorical_column
|
||||
|
||||
def get_combiner(self):
|
||||
"""Returns the embedding combiner."""
|
||||
raise NotImplementedError('not implemented')
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns the embedding table size, tuple of vocab size and dimension."""
|
||||
raise NotImplementedError('not implemented')
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""Returns the feature key name in the features dict."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""Return the key name for weights."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""Returns the embedding variable name.
|
||||
|
||||
Feature key name and embedding variable name usually map one-to-one.
But for shared embedding columns, the mapping is many-to-one.
|
||||
"""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_initializer(self):
|
||||
"""Returns the initializer."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
|
||||
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
|
||||
"""Core Embedding Column."""
|
||||
|
||||
def __new__(cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
layer_creator=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
|
||||
# are not supported on TPU. They are solely for matching the signature of
|
||||
# __new__ of parent class fc._EmbeddingColumn.
|
||||
return fc._EmbeddingColumn.__new__(
|
||||
cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
layer_creator=layer_creator,
|
||||
ckpt_to_load_from=ckpt_to_load_from,
|
||||
tensor_name_in_ckpt=tensor_name_in_ckpt,
|
||||
max_norm=max_norm,
|
||||
trainable=trainable)
|
||||
|
||||
def __init__(self,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
layer_creator=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
|
||||
self._key = None
|
||||
|
||||
def get_combiner(self):
|
||||
return self.combiner
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns num_ids and width."""
|
||||
return (self.categorical_column._num_buckets, self.dimension)
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""get_feature_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.categorical_column.name
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""get_weight_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.weight_feature_key
|
||||
return None
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""get_embedding_var_name."""
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_initializer(self):
|
||||
return self._tpu_initializer
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
if isinstance(
|
||||
self.categorical_column,
|
||||
(
|
||||
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
|
||||
fc_lib.WeightedCategoricalColumn)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
if tpu.under_tpu_inference_context():
|
||||
def host_computation():
|
||||
return fc._EmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
return tpu.outside_compilation(host_computation)
|
||||
|
||||
if _is_running_on_cpu():
|
||||
return fc._EmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
|
||||
# TPU mode
|
||||
# Get the embeddings from the LazyBuilder.
|
||||
tensor = inputs.get(self.get_feature_key_name())
|
||||
|
||||
# Add to collection for _create_tpu_embedding_variables_and_ops
|
||||
_record_variable_scope_and_name(self.get_embedding_var_name(),
|
||||
'embedding_weights')
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
|
||||
fc._SharedEmbeddingColumn):
|
||||
"""Core Shared Embedding Column."""
|
||||
|
||||
def __new__(cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
return fc._SharedEmbeddingColumn.__new__(
|
||||
cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=ckpt_to_load_from,
|
||||
tensor_name_in_ckpt=tensor_name_in_ckpt,
|
||||
max_norm=max_norm,
|
||||
trainable=trainable)
|
||||
|
||||
def __init__(self,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
|
||||
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
|
||||
self._key = None
|
||||
|
||||
def get_combiner(self):
|
||||
return self.combiner
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns num_ids and width."""
|
||||
return (self.categorical_column._num_buckets, self.dimension)
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""get_feature_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.categorical_column.name
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""get_weight_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.weight_feature_key
|
||||
return None
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""get_embedding_var_name."""
|
||||
return self.shared_embedding_collection_name
|
||||
|
||||
def get_initializer(self):
|
||||
return self.initializer
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
if isinstance(
|
||||
self.categorical_column,
|
||||
(
|
||||
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
|
||||
fc_lib.WeightedCategoricalColumn)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
if tpu.under_tpu_inference_context():
|
||||
def host_computation():
|
||||
return fc._SharedEmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
return tpu.outside_compilation(host_computation)
|
||||
|
||||
if _is_running_on_cpu():
|
||||
return fc._SharedEmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
|
||||
# TPU mode
|
||||
# Get the embeddings from the LazyBuilder.
|
||||
tensor = inputs.get(self.get_feature_key_name())
|
||||
|
||||
# Add to collection for _create_tpu_embedding_variables_and_ops
|
||||
_record_variable_scope_and_name(
|
||||
self.get_embedding_var_name(),
|
||||
'embedding_weights',
|
||||
is_shared_embedding=True)
|
||||
return tensor
|
||||
|
||||
|
||||
def _record_variable_scope_and_name(embedding_var_name,
|
||||
embedding_var_name_in_fc,
|
||||
is_shared_embedding=False):
|
||||
"""Add embedding variable name and scope to collection."""
|
||||
g = ops.get_default_graph()
|
||||
collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
|
||||
if not collection:
|
||||
collection.append({})
|
||||
|
||||
var_def_dict = collection[0]
|
||||
|
||||
captured_scope = variable_scope.get_variable_scope()
|
||||
captured_scope_name = captured_scope.name
|
||||
|
||||
if embedding_var_name in var_def_dict:
|
||||
if (var_def_dict[embedding_var_name][0] != captured_scope_name
|
||||
and not is_shared_embedding):
|
||||
raise ValueError(
|
||||
'For embedding var name {}, the variable scope name is different, '
|
||||
'got {}; expected {}'.format(embedding_var_name,
|
||||
captured_scope_name,
|
||||
var_def_dict[embedding_var_name][0]))
|
||||
if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
|
||||
raise ValueError(
|
||||
'For embedding var name {}, the embedding name is different, '
|
||||
'got {}; expected {}'.format(embedding_var_name,
|
||||
embedding_var_name_in_fc,
|
||||
var_def_dict[embedding_var_name][1]))
|
||||
else:
|
||||
var_def_dict[embedding_var_name] = (captured_scope_name,
|
||||
embedding_var_name_in_fc)
|
||||
|
||||
|
||||
def _is_running_on_cpu():
|
||||
"""Returns True if the current context is CPU model."""
|
||||
return tpu_function.get_tpu_context().number_of_shards is None
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.feature_column import *
|
||||
# used by tests
|
||||
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
|
||||
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
|
||||
from tensorflow.python.tpu.feature_column import _TPU_FC_TO_SCOPE
|
||||
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
|
||||
from tensorflow.python.tpu.feature_column import _TPUEmbeddingColumn
|
||||
from tensorflow.python.tpu.feature_column import _TPUSharedEmbeddingColumn
|
||||
# pylint: enable=wildcard-import,unused-import
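A minimal usage sketch of the TPU feature column wrappers above (the feature name, vocabulary size and embedding dimension are illustrative):

from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.tpu import feature_column as tpu_fc

video_id = fc_lib.categorical_column_with_identity('video_id', num_buckets=100000)
video_embedding = tpu_fc.embedding_column(video_id, dimension=32, combiner='mean')
# On TPU the lookup goes through the TPU embedding machinery; when
# _is_running_on_cpu() is true the column behaves like a regular embedding column.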
@ -1,23 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Functional operations."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
|
||||
TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.functional import *
|
||||
# pylint: enable=wildcard-import,unused-import
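A minimal sketch of invoking a function through TPUPartitionedCall; the wrapped computation is illustrative, and the keyword arguments follow the generated op wrapper (args, device_ordinal, Tout, f), which is an assumption here rather than something shown in this diff:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.tpu import functional as tpu_functional

@function.Defun(dtypes.float32)
def computation(x):
  return x * 2.0

y = tpu_functional.TPUPartitionedCall(
    args=[constant_op.constant(1.0)],
    device_ordinal=0,
    Tout=[dtypes.float32],
    f=computation)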
@ -1,438 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Operations for handling session logging and shutdown notifications."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import time
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
_WATCHDOG = None
|
||||
|
||||
|
||||
class CoordinatorShutdownException(Exception):
|
||||
"""Raised when the coordinator needs to shutdown."""
|
||||
pass
|
||||
|
||||
|
||||
def _clone_session(session, graph=None):
|
||||
return session_lib.Session(
|
||||
target=session.sess_str,
|
||||
config=session._config, # pylint: disable=protected-access
|
||||
graph=graph if graph else session.graph)
|
||||
|
||||
|
||||
def _make_heartbeat_op(session, device, request_ph):
|
||||
"""Return a heartbeat op or None if heartbeats are not supported by device."""
|
||||
try:
|
||||
# Test if we can connect in an isolated graph + session
|
||||
with ops.Graph().as_default():
|
||||
with _clone_session(session) as temp_session:
|
||||
with ops.device(device):
|
||||
heartbeat_op = tpu_ops.worker_heartbeat('')
|
||||
options = config_pb2.RunOptions(timeout_in_ms=5000)
|
||||
temp_session.run(heartbeat_op, options=options)
|
||||
except errors.InvalidArgumentError as _:
|
||||
logging.warning('Error running heartbeat on %s', device)
|
||||
return None
|
||||
except errors.DeadlineExceededError as _:
|
||||
logging.warning('Timeout connecting to %s when testing heartbeat', device)
|
||||
return None
|
||||
|
||||
# If we successfully connected and pinged the worker, go ahead and construct
|
||||
# the operation.
|
||||
with ops.device(device):
|
||||
return tpu_ops.worker_heartbeat(request_ph)
|
||||
|
||||
|
||||
class WorkerHeartbeatManager(object):
|
||||
"""Manages the status/heartbeat monitor for a set of workers."""
|
||||
|
||||
def __init__(self, session, devices, heartbeat_ops, request_placeholder):
|
||||
"""Construct a new WorkerHeartbeatManager.
|
||||
|
||||
(Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
|
||||
|
||||
Args:
|
||||
session: `tf.Session`, session to use for heartbeat operations.
|
||||
devices: `list[string]` Set of devices to connect to.
|
||||
heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
|
||||
request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
|
||||
the WorkerHeartbeatRequest protocol buffer.
|
||||
"""
|
||||
self._session = session
|
||||
self._devices = devices
|
||||
self._ops = heartbeat_ops
|
||||
self._request_placeholder = request_placeholder
|
||||
|
||||
@staticmethod
|
||||
def from_devices(session, devices):
|
||||
"""Construct a heartbeat manager for the given devices."""
|
||||
if not devices:
|
||||
logging.error('Trying to create heartbeat manager with no devices?')
|
||||
|
||||
logging.info('Creating heartbeat manager for %s', devices)
|
||||
request_placeholder = array_ops.placeholder(
|
||||
name='worker_heartbeat_request', dtype=dtypes.string)
|
||||
|
||||
heartbeat_ops = []
|
||||
kept_devices = []
|
||||
for device in devices:
|
||||
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
|
||||
if heartbeat_op is not None:
|
||||
kept_devices.append(device)
|
||||
heartbeat_ops.append(heartbeat_op)
|
||||
else:
|
||||
logging.warning('Heartbeat support not available for %s', device)
|
||||
|
||||
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
|
||||
request_placeholder)
|
||||
|
||||
def num_workers(self):
|
||||
return len(self._devices)
|
||||
|
||||
def configure(self, message):
|
||||
"""Configure heartbeat manager for all devices.
|
||||
|
||||
Args:
|
||||
message: `event_pb2.WorkerHeartbeatRequest`
|
||||
Returns: `None`
|
||||
"""
|
||||
logging.info('Configuring worker heartbeat: %s',
|
||||
text_format.MessageToString(message))
|
||||
self._session.run(self._ops,
|
||||
{self._request_placeholder: message.SerializeToString()})
|
||||
|
||||
def ping(self, request=None, timeout_in_ms=5000):
|
||||
"""Ping all workers, returning the parsed status results."""
|
||||
if request is None:
|
||||
request = event_pb2.WorkerHeartbeatRequest()
|
||||
|
||||
options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
|
||||
results = self._session.run(
|
||||
self._ops,
|
||||
feed_dict={self._request_placeholder: request.SerializeToString()},
|
||||
options=options)
|
||||
parsed_results = [
|
||||
event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
|
||||
for res_pb in results
|
||||
]
|
||||
logging.debug('Ping results: %s', parsed_results)
|
||||
return parsed_results
|
||||
|
||||
def lame_workers(self):
|
||||
"""Ping all workers, returning manager containing lame workers (or None)."""
|
||||
ping_results = self.ping()
|
||||
lame_workers = []
|
||||
|
||||
for ping_response, device, op in zip(ping_results, self._devices,
|
||||
self._ops):
|
||||
if ping_response.health_status != event_pb2.OK:
|
||||
lame_workers.append((device, op))
|
||||
|
||||
if not lame_workers:
|
||||
return None
|
||||
|
||||
bad_devices, bad_ops = zip(*lame_workers)
|
||||
return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
|
||||
self._request_placeholder)
|
||||
|
||||
def __repr__(self):
|
||||
return 'HeartbeatManager(%s)' % ','.join(self._devices)
|
||||
|
||||
def shutdown(self, timeout_ms=10000):
|
||||
"""Shutdown all workers after `shutdown_timeout_secs`."""
|
||||
logging.info('Shutting down %s.', self)
|
||||
req = event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
|
||||
self.configure(req)
|
||||
|
||||
# Wait for workers to shutdown. This isn't strictly required
|
||||
# but it avoids triggering multiple checkpoints with the same lame worker.
|
||||
logging.info('Waiting %dms for worker shutdown.', timeout_ms)
|
||||
time.sleep(timeout_ms / 1000)
|
||||
|
||||
|
||||
def all_worker_devices(session):
|
||||
"""Return a list of devices for each worker in the system."""
|
||||
devices = session.list_devices()
|
||||
return [
|
||||
device.name
|
||||
for device in devices
|
||||
if ':CPU:' in device.name and 'coordinator' not in device.name
|
||||
]
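
# Illustrative note (added for exposition; not part of the original file): on a
# hypothetical two-host slice the function above typically returns strings such
# as
#   '/job:tpu_worker/replica:0/task:0/device:CPU:0'
#   '/job:tpu_worker/replica:0/task:1/device:CPU:0'
# while any device whose name contains 'coordinator' is filtered out.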
|
||||
|
||||
|
||||
class WatchdogManager(threading.Thread):
|
||||
"""Configures worker watchdog timer and handles periodic pings.
|
||||
|
||||
Usage:
|
||||
# Ping workers every minute, shutting down workers if they haven't received
|
||||
# a ping after 1 hour.
|
||||
watchdog_manager = WatchdogManager(
|
||||
ping_interval=60, shutdown_timeout=3600
|
||||
)
|
||||
|
||||
# Use as a context manager, resetting watchdog on context exit:
|
||||
with watchdog_manager:
|
||||
session.run(...)
|
||||
|
||||
# Or setup globally; watchdog will remain active until program exit.
|
||||
watchdog_manager.configure_and_run()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session,
|
||||
devices=None,
|
||||
ping_interval=60,
|
||||
shutdown_timeout=3600):
|
||||
"""Initialize a watchdog manager.
|
||||
|
||||
Args:
|
||||
session: Session connected to worker devices. A cloned session and graph
|
||||
will be created for managing worker pings.
|
||||
devices: Set of devices to monitor. If none, all workers will be
|
||||
monitored.
|
||||
ping_interval: Time, in seconds, between watchdog pings.
|
||||
shutdown_timeout: Time, in seconds, before watchdog timeout.
|
||||
"""
|
||||
threading.Thread.__init__(self)
|
||||
self.ping_interval = ping_interval
|
||||
self.shutdown_timeout = shutdown_timeout
|
||||
self.daemon = True
|
||||
self._config = session._config # pylint: disable=protected-access
|
||||
self._target = session.sess_str
|
||||
self._running = False
|
||||
self._devices = devices
|
||||
|
||||
self._graph = None
|
||||
self._session = None
|
||||
self._worker_manager = None
|
||||
|
||||
def _reset_manager(self):
|
||||
"""Reset the graph, session and worker manager."""
|
||||
self._graph = ops.Graph()
|
||||
self._session = session_lib.Session(
|
||||
target=self._target,
|
||||
graph=self._graph,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
if self._devices is None:
|
||||
self._devices = all_worker_devices(self._session)
|
||||
|
||||
with self._graph.as_default():
|
||||
self._worker_manager = WorkerHeartbeatManager.from_devices(
|
||||
self._session, self._devices)
|
||||
|
||||
self._worker_manager.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(
|
||||
timeout_ms=self.shutdown_timeout * 1000,),
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
|
||||
def configure_and_run(self):
|
||||
logging.info(
|
||||
'Enabling watchdog timer with %d second timeout '
|
||||
'and %d second ping interval.', self.shutdown_timeout,
|
||||
self.ping_interval)
|
||||
self._reset_manager()
|
||||
self._running = True
|
||||
self.start()
|
||||
|
||||
def stop(self):
|
||||
logging.info('Stopping worker watchdog.')
|
||||
self._worker_manager.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
|
||||
shutdown_mode=event_pb2.NOT_CONFIGURED))
|
||||
self._running = False
|
||||
self.join()
|
||||
|
||||
def __enter__(self):
|
||||
self.configure_and_run()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def run(self):
|
||||
# Don't fetch logs or adjust timing: just ping the watchdog.
|
||||
#
|
||||
# If we hit an exception, reset our session as it is likely broken.
|
||||
while self._running:
|
||||
try:
|
||||
self._worker_manager.ping(request=None)
|
||||
time.sleep(self.ping_interval)
|
||||
except errors.OpError as e:
|
||||
# Catch any TF errors that occur so we don't stop sending heartbeats
|
||||
logging.debug('Caught error while sending heartbeat: %s', e)
|
||||
self._reset_manager()
|
||||
|
||||
|
||||
def start_worker_watchdog(session,
|
||||
devices=None,
|
||||
ping_interval=60,
|
||||
shutdown_timeout=3600):
|
||||
"""Start global worker watchdog to shutdown workers on coordinator exit."""
|
||||
global _WATCHDOG
|
||||
if _WATCHDOG is None:
|
||||
# Ensure we can send a few pings before we timeout!
|
||||
ping_interval = min(shutdown_timeout / 10., ping_interval)
|
||||
_WATCHDOG = WatchdogManager(session, devices, ping_interval,
|
||||
shutdown_timeout)
|
||||
_WATCHDOG.configure_and_run()
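
# --- Illustrative usage sketch (added for exposition; not part of the original
# file). The worker address below is a hypothetical placeholder.
def _example_start_watchdog():
  """Minimal sketch: enable the global worker watchdog from the coordinator."""
  sess = session_lib.Session('grpc://tpu-worker:8470')
  # Ping every minute; workers shut themselves down after an hour of silence.
  start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)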
|
||||
|
||||
|
||||
class GracefulShutdownHook(session_run_hook.SessionRunHook):
|
||||
"""Session hook that watches for shutdown events.
|
||||
|
||||
If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a
|
||||
SystemShutdown exception is raised to terminate the main session. If `saver`
|
||||
is None the `SAVERS` collection will be read to find a saver.
|
||||
|
||||
`on_shutdown_hooks` is an optional list of functions that should be called
|
||||
after checkpointing. The function is called with (`run_context`,
|
||||
`all_workers`, `lame_workers`).
|
||||
|
||||
If `heartbeat_group` is not specified, it will default to all CPU workers
|
||||
in the system.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
|
||||
self._saver = saver
|
||||
self._checkpoint_prefix = checkpoint_prefix
|
||||
self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
|
||||
|
||||
# Worker heartbeats are managed independently of the main training graph.
|
||||
self._graph = ops.Graph()
|
||||
self._workers = None
|
||||
self._session = None
|
||||
self._heartbeat_supported = False
|
||||
|
||||
def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
|
||||
# N.B. We have to pull the global step here to avoid it being unavailable
|
||||
# at checkpoint time; the graph has been frozen at that point.
|
||||
if training_util.get_global_step() is None and self.saver() is not None:
|
||||
raise ValueError(
|
||||
'Saver defined but no global step. Run `get_or_create_global_step()`'
|
||||
' in your model definition to allow checkpointing.')
|
||||
|
||||
with self._graph.as_default():
|
||||
logging.info('Installing graceful shutdown hook.')
|
||||
self._session = _clone_session(training_session, self._graph)
|
||||
self._workers = WorkerHeartbeatManager.from_devices(
|
||||
self._session, all_worker_devices(self._session))
|
||||
self._heartbeat_supported = self._workers.num_workers() > 0
|
||||
if self._heartbeat_supported:
|
||||
self._workers.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
else:
|
||||
logging.warn(
|
||||
'No workers support heartbeats. Failure handling will be disabled.')
|
||||
|
||||
def saver(self):
|
||||
if self._saver:
|
||||
return self._saver
|
||||
|
||||
savers = ops.get_collection(ops.GraphKeys.SAVERS)
|
||||
if not savers:
|
||||
return None
|
||||
|
||||
if not isinstance(savers, list):
|
||||
return savers
|
||||
|
||||
if len(savers) > 1:
|
||||
logging.error(
|
||||
'Multiple savers in the SAVERS collection. On-demand checkpointing '
|
||||
'will be disabled. Pass an explicit `saver` to the constructor to '
|
||||
'override this behavior.')
|
||||
return None
|
||||
|
||||
return savers[0]
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
del run_values
|
||||
|
||||
if not self._heartbeat_supported:
|
||||
return
|
||||
|
||||
lame_workers = self._workers.lame_workers()
|
||||
if lame_workers:
|
||||
logging.info('ShutdownHook: lame workers found: %s', lame_workers)
|
||||
|
||||
if self.saver():
|
||||
logging.info('ShutdownHook: saving checkpoint to %s',
|
||||
self._checkpoint_prefix)
|
||||
self.saver().save(
|
||||
run_context.session,
|
||||
self._checkpoint_prefix,
|
||||
global_step=training_util.get_global_step(),
|
||||
write_state=True,
|
||||
)
|
||||
else:
|
||||
logging.info('ShutdownHook: no Saver defined.')
|
||||
|
||||
for fn in self._on_shutdown_hooks:
|
||||
fn(run_context, self._workers, lame_workers)
|
||||
|
||||
|
||||
class RestartComputation(object):
|
||||
"""Restart the entire computation.
|
||||
|
||||
This hook shuts down all workers and returns control to the top-level by
|
||||
throwing a CoordinatorShutdownException.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout_ms=10000):
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
def __call__(self, run_context, all_workers, lame_workers):
|
||||
del run_context, lame_workers
|
||||
all_workers.shutdown(timeout_ms=self.timeout_ms)
|
||||
|
||||
logging.info('Terminating coordinator.')
|
||||
raise CoordinatorShutdownException()
|
||||
|
||||
|
||||
class ShutdownLameWorkers(object):
|
||||
"""Shutdown lamed workers.
|
||||
|
||||
Processing will continue normally (typically by waiting for the down
|
||||
workers to be restarted).
|
||||
"""
|
||||
|
||||
def __init__(self, timeout_ms=10000):
|
||||
self.timeout_in_ms = timeout_ms
|
||||
|
||||
def __call__(self, run_context, all_workers, lame_workers):
|
||||
lame_workers.shutdown(timeout_ms=self.timeout_in_ms)
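
# --- Illustrative usage sketch (added for exposition; not part of the original
# file). Shows one way the hook and handlers above could be combined; the
# checkpoint path is a hypothetical placeholder.
def _example_graceful_shutdown_hook(checkpoint_prefix='/tmp/model/model.ckpt'):
  """Minimal sketch: checkpoint on worker failure, then restart the job."""
  return GracefulShutdownHook(
      checkpoint_prefix=checkpoint_prefix,
      on_shutdown_hooks=[RestartComputation(timeout_ms=10000)])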
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.session_support import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
File diff suppressed because it is too large
@ -1,220 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Defines the `Topology` class, that describes a TPU fabric topology."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""

from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.protobuf.tpu import topology_pb2
|
||||
|
||||
|
||||
def _tpu_device_name(job, task, device):
|
||||
"""Returns the device name for the TPU `device` on `task` of `job`."""
|
||||
if job is None:
|
||||
return "/task:%d/device:TPU:%d" % (task, device)
|
||||
else:
|
||||
return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
|
||||
|
||||
|
||||
def _tpu_host_device_name(job, task):
|
||||
"""Returns the device name for the CPU device on `task` of `job`."""
|
||||
if job is None:
|
||||
return "/task:%d/device:CPU:0" % task
|
||||
else:
|
||||
return "/job:%s/task:%d/device:CPU:0" % (job, task)
|
||||
|
||||
|
||||
class Topology(object):
|
||||
"""Describes a set of TPU devices.
|
||||
|
||||
Represents both the shape of the physical mesh, and the mapping between
|
||||
TensorFlow TPU devices to physical mesh coordinates.
|
||||
"""
|
||||
|
||||
def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
|
||||
"""Builds a Topology object.
|
||||
|
||||
If `serialized` is not `None`, the topology is parsed from `serialized` and
|
||||
the other arguments are ignored. Otherwise, the topology is computed from
|
||||
`mesh_shape` and `device_coordinates`.
|
||||
|
||||
Args:
|
||||
serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
|
||||
serialized proto is parsed to discover the topology.
|
||||
mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
|
||||
the shape of the TPU topology, in number of cores. Ignored if
|
||||
`serialized` is not `None`.
|
||||
device_coordinates: A rank 3 numpy array that describes the mapping from
|
||||
TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
|
||||
if `serialized` is not `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `serialized` does not describe a well-formed topology.
|
||||
ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
|
||||
of 3 positive integers.
|
||||
ValueError: If `serialized` is `None` and `device_coordinates` is not a
|
||||
rank 3 numpy int32 array that describes a valid coordinate mapping.
|
||||
"""
|
||||
|
||||
self._serialized = serialized
|
||||
|
||||
if serialized:
|
||||
self._parse_topology(serialized)
|
||||
else:
|
||||
self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
|
||||
self._device_coordinates = np.asarray(device_coordinates, np.int32)
|
||||
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
|
||||
raise ValueError("`mesh_shape` must be a sequence of 3 positive "
|
||||
"entries; got {}".format(self._mesh_shape))
|
||||
|
||||
if (len(self._device_coordinates.shape) != 3 or
|
||||
self._device_coordinates.shape[2] != len(self._mesh_shape)):
|
||||
raise ValueError("`device_coordinates` must be a rank 3 int32 array "
|
||||
"with minor dimension equal to the mesh shape rank")
|
||||
|
||||
self._topology_tasks, self._topology_devices = self._invert_topology()
|
||||
|
||||
def _parse_topology(self, serialized):
|
||||
"""Parses a serialized `TopologyProto` into `self`."""
|
||||
proto = topology_pb2.TopologyProto()
|
||||
proto.ParseFromString(serialized)
|
||||
|
||||
self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
|
||||
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
|
||||
raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
|
||||
"entries; got {}".format(self._mesh_shape))
|
||||
|
||||
if proto.num_tasks < 0:
|
||||
raise ValueError("`num_tasks` must be >= 0; got {}".format(
|
||||
proto.num_tasks))
|
||||
if proto.num_tpu_devices_per_task < 0:
|
||||
raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
|
||||
proto.num_tpu_devices_per_task))
|
||||
|
||||
expected_coordinates_size = (
|
||||
proto.num_tasks * proto.num_tpu_devices_per_task * len(
|
||||
proto.mesh_shape))
|
||||
if len(proto.device_coordinates) != expected_coordinates_size:
|
||||
raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
|
||||
"num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
|
||||
"got shape {}".format(proto.num_tasks,
|
||||
proto.num_tpu_devices_per_task,
|
||||
proto.mesh_shape,
|
||||
len(proto.device_coordinates)))
|
||||
|
||||
coords = np.array(proto.device_coordinates, dtype=np.int32)
|
||||
if any(coords < 0):
|
||||
raise ValueError("`device_coordinates` must be >= 0")
|
||||
coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
|
||||
len(proto.mesh_shape)))
|
||||
self._device_coordinates = coords
|
||||
|
||||
def _invert_topology(self):
|
||||
"""Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps."""
|
||||
tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
|
||||
devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
|
||||
for task in xrange(self.device_coordinates.shape[0]):
|
||||
for device in xrange(self.device_coordinates.shape[1]):
|
||||
x, y, z = self.device_coordinates[task, device, :]
|
||||
tasks[x, y, z] = task
|
||||
devices[x, y, z] = device
|
||||
return tasks, devices
|
||||
|
||||
@property
|
||||
def mesh_shape(self):
|
||||
"""A rank 1 int32 array describing the shape of the TPU topology."""
|
||||
return self._mesh_shape
|
||||
|
||||
@property
|
||||
def mesh_rank(self):
|
||||
"""Returns the number of dimensions in the mesh."""
|
||||
return len(self._mesh_shape)
|
||||
|
||||
@property
|
||||
def device_coordinates(self):
|
||||
"""Describes the mapping from TPU devices to topology coordinates.
|
||||
|
||||
Returns:
|
||||
A rank 3 int32 array with shape `[tasks, devices, axis]`.
|
||||
`tasks` is the number of tasks in the TPU cluster, `devices` is the number
|
||||
of TPU devices per task, and `axis` is the number of axes in the TPU
|
||||
cluster topology. Each entry gives the `axis`-th coordinate in the
|
||||
topology of a task/device pair. TPU topologies are 3-dimensional, with
|
||||
dimensions `(x, y, core number)`.
|
||||
"""
|
||||
return self._device_coordinates
|
||||
|
||||
def task_ordinal_at_coordinates(self, device_coordinates):
|
||||
"""Returns the TensorFlow task number attached to `device_coordinates`.
|
||||
|
||||
Args:
|
||||
device_coordinates: An integer sequence describing a device's physical
|
||||
coordinates in the TPU fabric.
|
||||
|
||||
Returns:
|
||||
Returns the TensorFlow task number that contains the TPU device with those
|
||||
physical coordinates.
|
||||
"""
|
||||
return self._topology_tasks[tuple(device_coordinates)]
|
||||
|
||||
def tpu_device_ordinal_at_coordinates(self, device_coordinates):
|
||||
"""Returns the TensorFlow device number at `device_coordinates`.
|
||||
|
||||
Args:
|
||||
device_coordinates: An integer sequence describing a device's physical
|
||||
coordinates in the TPU fabric.
|
||||
|
||||
Returns:
|
||||
Returns the TensorFlow device number within the task attached to the
device with those physical coordinates.
|
||||
"""
|
||||
return self._topology_devices[tuple(device_coordinates)]
|
||||
|
||||
def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
|
||||
"""Returns the CPU device attached to a logical core."""
|
||||
return _tpu_host_device_name(
|
||||
job, self._topology_tasks[tuple(device_coordinates)])
|
||||
|
||||
def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
|
||||
"""Returns the name of the TPU device assigned to a logical core."""
|
||||
return _tpu_device_name(job,
|
||||
self._topology_tasks[tuple(device_coordinates)],
|
||||
self._topology_devices[tuple(device_coordinates)])
|
||||
|
||||
@property
|
||||
def num_tasks(self):
|
||||
"""Returns the number of TensorFlow tasks in the TPU slice."""
|
||||
return self._device_coordinates.shape[0]
|
||||
|
||||
@property
|
||||
def num_tpus_per_task(self):
|
||||
"""Returns the number of TPU devices per task in the TPU slice."""
|
||||
return self._device_coordinates.shape[1]
|
||||
|
||||
def serialized(self):
|
||||
"""Returns the serialized form of the topology."""
|
||||
if self._serialized is None:
|
||||
proto = topology_pb2.TopologyProto()
|
||||
proto.mesh_shape[:] = list(self._mesh_shape)
|
||||
proto.num_tasks = self._device_coordinates.shape[0]
|
||||
proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
|
||||
proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
|
||||
self._serialized = proto.SerializeToString()
|
||||
|
||||
return self._serialized
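
# --- Illustrative usage sketch (added for exposition; not part of the original
# file). The mesh and coordinates below are assumptions chosen only to
# exercise the constructor on a tiny single-task slice.
def _example_topology():
  """Minimal sketch: build a Topology by hand and look up a device name."""
  device_coordinates = np.array(
      [[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]], dtype=np.int32)
  topology = Topology(mesh_shape=[2, 2, 1],
                      device_coordinates=device_coordinates)
  # Resolves fabric coordinate (1, 1, 0) to '/job:tpu_worker/task:0/device:TPU:3'.
  return topology.tpu_device_name_at_coordinates((1, 1, 0), job='tpu_worker')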
|
||||
# pylint: disable=wildcard-import,unused-import,redefined-builtin
|
||||
from tensorflow.python.tpu.topology import *
|
||||
# pylint: enable=wildcard-import,unused-import,redefined-builtin
|
||||
|
File diff suppressed because it is too large
@ -1,276 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""A RunConfig subclass with TPU support."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""

from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import util as util_lib
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
|
||||
_SERVICE_KEY = run_config_lib._SERVICE_KEY
|
||||
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class InputPipelineConfig(object):
|
||||
r"""Please see the definition of these values in TPUConfig."""
|
||||
PER_SHARD_V1 = 1
|
||||
PER_HOST_V1 = 2
|
||||
PER_HOST_V2 = 3
|
||||
BROADCAST = 4
|
||||
|
||||
|
||||
class TPUConfig(
|
||||
collections.namedtuple('TPUConfig', [
|
||||
'iterations_per_loop',
|
||||
'num_shards',
|
||||
'num_cores_per_replica',
|
||||
'per_host_input_for_training',
|
||||
'tpu_job_name',
|
||||
'initial_infeed_sleep_secs',
|
||||
'input_partition_dims',
|
||||
])):
|
||||
r"""TPU related configuration required by `TPUEstimator`.
|
||||
|
||||
Args:
|
||||
iterations_per_loop: This is the number of train steps running in TPU
|
||||
system before returning to CPU host for each `Session.run`. This means
|
||||
global step is increased `iterations_per_loop` times in one `Session.run`.
|
||||
It is recommended to be set as number of global steps for next checkpoint.
|
||||
num_shards: (Deprecated, ignored by TPUEstimator).
|
||||
The number of model replicas in the system. For non-model-parallelism
|
||||
case, this number equals the total number of TPU cores. For
|
||||
model-parallelism, the total number of TPU cores equals
|
||||
num_cores_per_replica * num_shards.
|
||||
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
|
||||
An integer which describes the number of TPU cores per model replica. This
|
||||
is required by model-parallelism which enables partitioning
|
||||
the model to multiple cores. Currently num_cores_per_replica must be
|
||||
1, 2, 4, 8, or 16.
|
||||
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
|
||||
`input_fn` is invoked once on each host. With the per-core input pipeline
|
||||
configuration, it is invoked once for each core.
|
||||
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
|
||||
the batch size for each shard is `train_batch_size` // #hosts in the
|
||||
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
|
||||
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
|
||||
invoked once on host 0 and the tensors are broadcasted to all other
|
||||
replicas. The batch size equals `train_batch_size`. With the per-core
|
||||
input pipeline configuration, the shard batch size is also
|
||||
`train_batch_size` // #cores.
|
||||
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
|
||||
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
|
||||
within TPUEstimator, however when using ClusterSpec propagation in more
|
||||
esoteric cluster configurations, you may need to specify the job name as a
|
||||
string.
|
||||
initial_infeed_sleep_secs: The number of seconds the infeed thread should
|
||||
wait before enqueueing the first batch. This helps avoid timeouts for
|
||||
models that require a long compilation time.
|
||||
input_partition_dims: A nested list to describe the partition dims
|
||||
for all the tensors from input_fn(). The structure of
|
||||
input_partition_dims must match the structure of `features` and
|
||||
`labels` from input_fn(). The total number of partitions must match
|
||||
`num_cores_per_replica`. For example, if input_fn() returns two tensors:
|
||||
images with shape [N, H, W, C] and labels [N].
|
||||
input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4
|
||||
pieces and feed into 4 TPU cores. labels tensor are directly broadcasted
|
||||
to all the TPU cores since the partition dims is `None`.
|
||||
Current limitations: This feature is only supported with the PER_HOST_V2
|
||||
input mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
iterations_per_loop=2,
|
||||
num_shards=None,
|
||||
num_cores_per_replica=None,
|
||||
per_host_input_for_training=True,
|
||||
tpu_job_name=None,
|
||||
initial_infeed_sleep_secs=None,
|
||||
input_partition_dims=None):
|
||||
|
||||
# Check iterations_per_loop.
|
||||
util_lib.check_positive_integer(iterations_per_loop,
|
||||
'TPUConfig iterations_per_loop')
|
||||
|
||||
# Check num_shards.
|
||||
if num_shards is not None:
|
||||
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
|
||||
|
||||
if input_partition_dims is not None:
|
||||
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
|
||||
raise ValueError(
|
||||
'input_partition_dims must be a list/tuple with one or two'
|
||||
' elements.')
|
||||
|
||||
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
|
||||
raise ValueError(
|
||||
'input_partition_dims is only supported in PER_HOST_V2 mode.')
|
||||
|
||||
if num_cores_per_replica is None:
|
||||
raise ValueError(
|
||||
'input_partition_dims requires setting num_cores_per_replica.')
|
||||
|
||||
# Check num_cores_per_replica
|
||||
if num_cores_per_replica is not None:
|
||||
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
|
||||
raise ValueError(
|
||||
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
|
||||
str(num_cores_per_replica)))
|
||||
|
||||
# per_host_input_for_training may be True, False, or an integer in [1..4].
|
||||
# Map legacy values (True, False) to numeric values.
|
||||
if per_host_input_for_training is False:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
|
||||
elif per_host_input_for_training is True:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
|
||||
|
||||
# Check initial_infeed_sleep_secs.
|
||||
if initial_infeed_sleep_secs:
|
||||
util_lib.check_positive_integer(initial_infeed_sleep_secs,
|
||||
'TPUConfig initial_infeed_sleep_secs')
|
||||
|
||||
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
|
||||
|
||||
return super(TPUConfig, cls).__new__(
|
||||
cls,
|
||||
iterations_per_loop=iterations_per_loop,
|
||||
num_shards=num_shards,
|
||||
num_cores_per_replica=num_cores_per_replica,
|
||||
per_host_input_for_training=per_host_input_for_training,
|
||||
tpu_job_name=tpu_job_name,
|
||||
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
|
||||
input_partition_dims=input_partition_dims)
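
# --- Illustrative usage sketch (added for exposition; not part of the original
# file). The step count, shard count, and master address are assumptions for a
# small single-host setup.
def _example_run_config():
  """Minimal sketch: a TPUConfig wrapped in a TPU RunConfig."""
  config = TPUConfig(
      iterations_per_loop=100,  # return to the coordinator every 100 steps
      num_shards=8,             # deprecated but still accepted
      per_host_input_for_training=InputPipelineConfig.PER_HOST_V2)
  return RunConfig(tpu_config=config, master='grpc://10.0.0.2:8470')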
|
||||
|
||||
|
||||
class RunConfig(run_config_lib.RunConfig):
|
||||
"""RunConfig with TPU support."""
|
||||
|
||||
def __init__(self,
|
||||
tpu_config=None,
|
||||
evaluation_master=None,
|
||||
master=None,
|
||||
cluster=None,
|
||||
**kwargs):
|
||||
"""Constructs a RunConfig.
|
||||
|
||||
Args:
|
||||
tpu_config: the TPUConfig that specifies TPU-specific configuration.
|
||||
evaluation_master: a string. The address of the master to use for eval.
|
||||
Defaults to master if not set.
|
||||
master: a string. The address of the master to use for training.
|
||||
cluster: a ClusterResolver
|
||||
**kwargs: keyword config parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: if cluster is not None and the provided session_config has a
|
||||
cluster_def already.
|
||||
"""
|
||||
super(RunConfig, self).__init__(**kwargs)
|
||||
self._tpu_config = tpu_config or TPUConfig()
|
||||
self._cluster = cluster
|
||||
|
||||
# If user sets master and/or evaluation_master explicitly, including empty
|
||||
# string '', take it. Otherwise, take the values set by parent class.
|
||||
if master is not None:
|
||||
if cluster is not None:
|
||||
raise ValueError('Both master and cluster are set.')
|
||||
self._master = master
|
||||
else:
|
||||
if cluster:
|
||||
self._master = cluster.master()
|
||||
|
||||
if evaluation_master is not None:
|
||||
self._evaluation_master = evaluation_master
|
||||
elif (not self._evaluation_master and
|
||||
self.task_type != run_config_lib.TaskType.EVALUATOR):
|
||||
# If the task type is EVALUATOR, it means some cluster manager sets the
|
||||
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
|
||||
#
|
||||
# Otherwise, it means user executes the code without external cluster
|
||||
# manager. For that, we optimize the user experience by setting
|
||||
# evaluation_master to master, unless user overwrites it.
|
||||
self._evaluation_master = self._master
|
||||
|
||||
# Set the ClusterSpec to use
|
||||
if cluster:
|
||||
self._cluster_spec = cluster.cluster_spec()
|
||||
|
||||
# Merge the cluster_def into the ConfigProto.
|
||||
if self._session_config is None: # pylint: disable=access-member-before-definition
|
||||
self._session_config = config_pb2.ConfigProto(
|
||||
allow_soft_placement=True, isolate_session_state=True)
|
||||
if self._session_config.HasField('cluster_def'):
|
||||
raise ValueError(
|
||||
'You cannot provide a ClusterResolver and '
|
||||
'session_config.cluster_def.')
|
||||
if self._cluster_spec:
|
||||
self._session_config.cluster_def.CopyFrom(
|
||||
self._cluster_spec.as_cluster_def())
|
||||
|
||||
def _maybe_overwrite_session_config_for_distributed_training(self):
|
||||
# Overrides the parent class session_config overwrite for between-graph. TPU
|
||||
# runs with in-graph, which should not have device filter. Doing nothing
|
||||
# ("pass") basically disables it.
|
||||
pass
|
||||
|
||||
@property
|
||||
def evaluation_master(self):
|
||||
return self._evaluation_master
|
||||
|
||||
@property
|
||||
def master(self):
|
||||
return self._master
|
||||
|
||||
@property
|
||||
def tpu_config(self):
|
||||
return self._tpu_config
|
||||
|
||||
@property
|
||||
def cluster(self):
|
||||
return self._cluster
|
||||
|
||||
def replace(self, **kwargs):
|
||||
if 'tpu_config' not in kwargs:
|
||||
return super(RunConfig, self).replace(**kwargs)
|
||||
|
||||
tpu_config = kwargs.pop('tpu_config')
|
||||
new_instance = super(RunConfig, self).replace(**kwargs)
|
||||
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
|
||||
return new_instance
|
||||
|
||||
|
||||
def _get_tpu_job_name_from_tf_config():
|
||||
"""Extracts the TPU job name from TF_CONFIG env variable."""
|
||||
# TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster
|
||||
# spec propagation.
|
||||
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
|
||||
if tpu_job_name:
|
||||
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
|
||||
return tpu_job_name
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.tpu_config import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@ -1,763 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""

from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding
|
||||
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_config
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
_DEFAULT_JOB_NAME = 'tpu_worker'
|
||||
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
|
||||
_LOCAL_MASTERS = ('', 'local')
|
||||
_NUM_CORES_TO_COMPUTATION_SHAPE = {
|
||||
1: [1, 1, 1],
|
||||
2: [1, 1, 2],
|
||||
4: [1, 2, 2],
|
||||
8: [2, 2, 2],
|
||||
16: [4, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
class TPUContext(object):
|
||||
"""A context that holds the current configuration of the TPU computation."""
|
||||
|
||||
def __init__(self,
|
||||
internal_ctx,
|
||||
input_device=None,
|
||||
invocation_index=None,
|
||||
call_from_input_fn=True):
|
||||
self._internal_ctx = internal_ctx
|
||||
self._input_device = input_device
|
||||
self._invocation_index = invocation_index
|
||||
self._call_from_input_fn = call_from_input_fn
|
||||
|
||||
def current_input_fn_deployment(self):
|
||||
"""The configuration of the current input_fn invocation.
|
||||
|
||||
The configuration depends on `TPUConfig.per_host_input_for_training`. See
|
||||
`TPUConfig` for details.
|
||||
|
||||
Only set in params dict of input_fn
|
||||
|
||||
Returns:
|
||||
A tuple of
|
||||
1. Device spec string: String, the current CPU host where the
|
||||
input_fn is invoked.
|
||||
2. Current invocation index: Int, 0-based index of the input_fn
|
||||
invocation. See next item for details.
|
||||
3. Total invocation count: Int, the total number of times to invoke the
|
||||
input_fn on all CPU hosts. Each invocation will be passed with a new
|
||||
`TPUContext` instance with current invocation index set properly.
|
||||
4. Total number of replicas consumed by current_invocation: Int, the
|
||||
number of replicas fed by the data returned by current input_fn. For
|
||||
example, for per_core input pipeline deployment
|
||||
and non-model-parallelism, total invocation count is equal to
|
||||
the number of cores in the system and num replicas consumed by
|
||||
current invocation is 1. For per-host v2 input pipeline deployment,
|
||||
total invocation count is equal to the number of hosts in the system
|
||||
and num replicas consumed by current invocation is equal to number of
|
||||
cores per host.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If this method is called from model_fn instead of input_fn.
|
||||
"""
|
||||
if not self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' model_fn.')
|
||||
|
||||
if self._internal_ctx.is_input_sharded_per_core():
|
||||
total_invocation_count = (self._internal_ctx.num_hosts
|
||||
* self._internal_ctx.num_of_replicas_per_host)
|
||||
replicas_consumed = 1
|
||||
elif self._internal_ctx.is_input_broadcast_with_iterators():
|
||||
total_invocation_count = 1
|
||||
replicas_consumed = self._internal_ctx.num_replicas
|
||||
else:
|
||||
total_invocation_count = self._internal_ctx.num_hosts
|
||||
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
|
||||
return (self._input_device, self._invocation_index,
|
||||
total_invocation_count, replicas_consumed)
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
"""The total number of replicas.
|
||||
|
||||
For non-model-parallelism, num_replicas should be the total num of TPU
|
||||
cores in the system.
|
||||
|
||||
Returns:
|
||||
The number of replicas.
|
||||
"""
|
||||
return self._internal_ctx.num_replicas
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
"""The number of hosts for the TPU system."""
|
||||
return self._internal_ctx.num_hosts
|
||||
|
||||
@property
|
||||
def current_host(self):
|
||||
"""The current host index for the TPU system."""
|
||||
return self._invocation_index
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""The number of replicas for each host."""
|
||||
if self._internal_ctx.model_parallelism_enabled:
|
||||
raise ValueError(
|
||||
'num_of_replicas_per_host is not supported for model_parallelism')
|
||||
return self._internal_ctx.num_of_replicas_per_host
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
"""Returns device_assignment object."""
|
||||
if self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' input_fn.')
|
||||
return self._internal_ctx.device_assignment
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replicate for non-model-parallelism.
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
# Note that: For the non-model parallelism, the mapping could be
|
||||
# a random permutation. The order should not matter in most cases
|
||||
# as long as the model is replicated to all cores in the system.
|
||||
return self._internal_ctx.device_for_replica(replica_id)
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function.
|
||||
|
||||
The place function takes host_id as the input and returns the TF device
|
||||
for the corresponding host.
|
||||
"""
|
||||
|
||||
def _placement_function(host_id):
|
||||
"""Return the host device given host_id."""
|
||||
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
|
||||
|
||||
return _placement_function
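
# --- Illustrative usage sketch (added for exposition; not part of the original
# file). Shows how an input_fn might consume the TPUContext passed via
# `params` (assuming the conventional 'context' key); the dataset contents are
# stand-ins.
def _example_input_fn(params):
  """Minimal sketch: shard an input pipeline using the TPUContext."""
  context = params['context']
  _, invocation_index, total_invocations, _ = (
      context.current_input_fn_deployment())
  import tensorflow as tf  # local import keeps the sketch self-contained
  dataset = tf.data.Dataset.range(1024)
  # Give each input_fn invocation a disjoint shard of the data.
  dataset = dataset.shard(total_invocations, invocation_index)
  return dataset.batch(params['batch_size'], drop_remainder=True)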
|
||||
|
||||
|
||||
class _InternalTPUContext(object):
|
||||
"""A context holds immutable states of TPU computation.
|
||||
|
||||
This immutable object holds TPUEstimator config, train/eval batch size, and
|
||||
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
|
||||
provides utility functions, based on the current state, to determine other
|
||||
information commonly required by TPU computation, such as TPU device names,
|
||||
TPU hosts, shard batch size, etc.
|
||||
|
||||
if eval_on_tpu is False, then execution of eval on TPU is disabled.
|
||||
if eval_on_tpu is True, but use_tpu is False, a warning is issued,
|
||||
and TPU execution is disabled for all modes.
|
||||
|
||||
N.B. As `mode` is not immutable state in Estimator, but essential to
|
||||
distinguish between TPU training and evaluation, a common usage for
|
||||
_InternalTPUContext with `mode` is as follows:
|
||||
```
|
||||
with _ctx.with_mode(mode) as ctx:
|
||||
if ctx.is_running_on_cpu():
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
train_batch_size,
|
||||
eval_batch_size,
|
||||
predict_batch_size,
|
||||
use_tpu,
|
||||
eval_on_tpu=True,
|
||||
embedding_config_spec=None):
|
||||
self._config = config
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._predict_batch_size = predict_batch_size
|
||||
self._use_tpu = use_tpu
|
||||
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
|
||||
if not use_tpu and eval_on_tpu:
|
||||
logging.warning('eval_on_tpu ignored because use_tpu is False.')
|
||||
|
||||
self._eval_on_tpu = eval_on_tpu
|
||||
self._model_parallelism_enabled = (
|
||||
use_tpu and config.tpu_config.num_cores_per_replica)
|
||||
self._mode = None
|
||||
num_cores_per_replica = config.tpu_config.num_cores_per_replica
|
||||
if self._model_parallelism_enabled:
|
||||
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
|
||||
num_cores_per_replica]
|
||||
else:
|
||||
self._computation_shape = None
|
||||
self._lazy_tpu_system_metadata_dict = {} # key by master address
|
||||
self._lazy_device_assignment_dict = {} # key by master address
|
||||
self._lazy_validation_dict = {} # key by ModeKeys
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._lazy_embedding_config_dict = {} # key by master address
|
||||
|
||||
def _assert_mode(self):
|
||||
if self._mode is None:
|
||||
raise RuntimeError(
|
||||
'`mode` needs to be set via contextmanager `with_mode`.')
|
||||
return self._mode
|
||||
|
||||
@contextmanager
|
||||
def with_mode(self, mode):
|
||||
# NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries,
|
||||
# such as _lazy_tpu_system_metadata_dict between new copy and the original
|
||||
# one. Note that all lazy states stored in properties _lazy_foo are sort of
|
||||
# immutable as they should be same for the process lifetime.
|
||||
new_ctx = copy.copy(self)
|
||||
new_ctx._mode = mode # pylint: disable=protected-access
|
||||
yield new_ctx
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._assert_mode()
|
||||
|
||||
def _get_master_address(self):
|
||||
mode = self._assert_mode()
|
||||
config = self._config
|
||||
master = (
|
||||
config.master
|
||||
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
|
||||
return master
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
cluster_def = None
|
||||
if (self._config.session_config and
|
||||
self._config.session_config.cluster_def.job):
|
||||
cluster_def = self._config.session_config.cluster_def
|
||||
|
||||
# pylint: disable=protected-access
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._query_tpu_system_metadata(
|
||||
master,
|
||||
cluster_def=cluster_def,
|
||||
query_topology=self.model_parallelism_enabled))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
def _get_device_assignment(self):
|
||||
"""Gets the (maybe cached) TPU device assignment."""
|
||||
master = self._get_master_address()
|
||||
device_assignment = self._lazy_device_assignment_dict.get(master)
|
||||
if device_assignment is not None:
|
||||
return device_assignment
|
||||
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
|
||||
device_assignment = tpu_device_assignment.device_assignment(
|
||||
tpu_system_metadata.topology,
|
||||
computation_shape=self._computation_shape,
|
||||
num_replicas=self.num_replicas)
|
||||
|
||||
logging.info('num_cores_per_replica: %s',
|
||||
str(self._config.tpu_config.num_cores_per_replica))
|
||||
logging.info('computation_shape: %s', str(self._computation_shape))
|
||||
logging.info('num_replicas: %d', self.num_replicas)
|
||||
logging.info('device_assignment.topology.device_coordinates: %s',
|
||||
str(device_assignment.topology.device_coordinates))
|
||||
logging.info('device_assignment.core_assignment: %s',
|
||||
str(device_assignment.core_assignment))
|
||||
|
||||
self._lazy_device_assignment_dict[master] = device_assignment
|
||||
return device_assignment
|
||||
|
||||
@property
|
||||
def embedding_config(self):
|
||||
"""Returns the embedding config based on current mode."""
|
||||
master = self._get_master_address()
|
||||
if master in self._lazy_embedding_config_dict:
|
||||
embedding_config = self._lazy_embedding_config_dict[master]
|
||||
else:
|
||||
embedding_config = None
|
||||
if self._use_tpu and self._embedding_config_spec:
|
||||
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
|
||||
self._embedding_config_spec, self._train_batch_size,
|
||||
self._eval_batch_size, self.num_hosts, self.num_cores, master)
|
||||
if not embedding_config.has_embedding_tables():
|
||||
embedding_config = None
|
||||
self._lazy_embedding_config_dict[master] = embedding_config
|
||||
|
||||
if embedding_config is not None:
|
||||
mode = self._assert_mode()
|
||||
# Dynamically attach tpu_embedding based on mode. With
|
||||
# this, we could keep embedding_config immutable but call site always
|
||||
# accesses the unified API '.tpu_embedding'.
|
||||
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
|
||||
return embedding_config
|
||||
|
||||
@property
|
||||
def model_parallelism_enabled(self):
|
||||
return self._model_parallelism_enabled
|
||||
|
||||
@property
|
||||
def input_partition_dims(self):
|
||||
return self._config.tpu_config.input_partition_dims
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
return (self._get_device_assignment()
|
||||
if self._model_parallelism_enabled else None)
|
||||
|
||||
@property
|
||||
def num_of_cores_per_host(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_cores(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_cores
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""Return the number of replicas per host."""
|
||||
if self.model_parallelism_enabled:
|
||||
return self.num_replicas // self.num_hosts
|
||||
else:
|
||||
return self.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
num_cores_in_system = self.num_cores
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
if num_cores_per_replica > num_cores_in_system:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the total num of '
|
||||
'TPU cores in the system. num_cores_per_replica: {}, num cores '
|
||||
'in the system: {}'.format(num_cores_per_replica,
|
||||
num_cores_in_system))
|
||||
|
||||
if num_cores_in_system % num_cores_per_replica != 0:
|
||||
raise RuntimeError(
|
||||
'The num of cores in the system ({}) is not divisible by the num '
|
||||
'of cores ({}) required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
|
||||
num_cores_in_system, num_cores_per_replica))
|
||||
|
||||
return num_cores_in_system // num_cores_per_replica
|
||||
else:
|
||||
return num_cores_in_system
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_hosts
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
def is_input_sharded_per_core(self):
|
||||
"""Return true if input_fn is invoked per-core (other than per-host)."""
|
||||
mode = self._assert_mode()
|
||||
return (mode == model_fn_lib.ModeKeys.TRAIN and
|
||||
(self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_SHARD_V1))
|
||||
|
||||
def is_input_per_host_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the per-host v2 config."""
|
||||
return (self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_HOST_V2)
|
||||
|
||||
def is_input_broadcast_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the full_replicae config."""
|
||||
return (self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.BROADCAST)
|
||||
|
||||
def is_running_on_cpu(self, is_export_mode=False):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU.
|
||||
|
||||
This API also validates user provided configuration, such as batch size,
|
||||
according the lazy initialized TPU system metadata.
|
||||
|
||||
Args:
|
||||
is_export_mode: Indicates whether the current mode is for exporting the
|
||||
model, when mode == PREDICT. Only with this bool can we
tell whether the user is calling Estimator.predict or
|
||||
Estimator.export_savedmodel, which run on TPU and CPU
|
||||
respectively. Parent class Estimator does not distinguish these two.
|
||||
|
||||
Returns:
|
||||
bool, whether current input_fn or model_fn should be running on CPU.
|
||||
|
||||
Raises:
|
||||
ValueError: any configuration is invalid.
|
||||
"""
|
||||
|
||||
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
|
||||
if not is_running_on_cpu:
|
||||
self._validate_tpu_configuration()
|
||||
return is_running_on_cpu
|
||||
|
||||
def _is_running_on_cpu(self, is_export_mode):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
|
||||
mode = self._assert_mode()
|
||||
|
||||
if not self._use_tpu:
|
||||
return True
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
|
||||
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
||||
return True
|
||||
|
||||
if is_export_mode:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def global_batch_size(self):
|
||||
mode = self._assert_mode()
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
return self._train_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return self._eval_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return self._predict_batch_size
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def batch_size_for_input_fn(self):
|
||||
"""Returns the shard batch size for `input_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU
|
||||
if self.is_input_sharded_per_core() or (
|
||||
self.is_input_per_host_with_iterators()):
|
||||
return global_batch_size // self.num_replicas
|
||||
else:
|
||||
return global_batch_size // self.num_hosts
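
# Worked example (illustrative, not from the original file): with a global
# train batch size of 1024 on a 32-core, 4-host slice and no model
# parallelism, PER_HOST_V2 and per-core sharding feed 1024 // 32 = 32
# examples per invocation, while PER_HOST_V1 feeds 1024 // 4 = 256 per host.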
|
||||
|
||||
@property
|
||||
def batch_size_for_model_fn(self):
|
||||
"""Returns the shard batch size for `model_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU. always sharded per shard.
|
||||
return global_batch_size // self.num_replicas
|
||||
|
||||
@property
|
||||
def master_job(self):
|
||||
"""Returns the job name to use to place TPU computations on.
|
||||
|
||||
Returns:
|
||||
A string containing the job name, or None if no job should be specified.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user needs to specify a tpu_job_name, because we are
|
||||
unable to infer the job name automatically, or if the user-specified job
|
||||
names are inappropriate.
|
||||
"""
|
||||
run_config = self._config
|
||||
# If the user specifies the tpu_job_name, use that.
|
||||
if run_config.tpu_config.tpu_job_name:
|
||||
return run_config.tpu_config.tpu_job_name
|
||||
|
||||
# The tpu job is determined by the run_config. Right now, this method is
|
||||
# required as tpu_config is not part of the RunConfig.
|
||||
mode = self._assert_mode()
|
||||
master = (
|
||||
run_config.evaluation_master
|
||||
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
|
||||
if master in _LOCAL_MASTERS:
|
||||
return None
|
||||
|
||||
if (not run_config.session_config or
|
||||
not run_config.session_config.cluster_def.job):
|
||||
return _DEFAULT_JOB_NAME
|
||||
cluster_def = run_config.session_config.cluster_def
|
||||
job_names = set([job.name for job in cluster_def.job])
|
||||
if _DEFAULT_JOB_NAME in job_names:
|
||||
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
|
||||
raise ValueError('Currently, tpu_worker is not an allowed job name.')
|
||||
if len(job_names) == 1:
|
||||
return cluster_def.job[0].name
|
||||
if len(job_names) == 2:
|
||||
if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
|
||||
job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
|
||||
return job_names.pop()
|
||||
# TODO(b/67716447): Include more sophisticated heuristics.
|
||||
raise ValueError(
|
||||
'Could not infer TPU job name. Please specify a tpu_job_name as part '
|
||||
'of your TPUConfig.')
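# --- Illustrative note (not in the original file): when the heuristics above
# cannot infer the job name, the caller is expected to pin it explicitly. A
# hedged sketch, with the address and job name purely hypothetical:
#
#   run_config = tpu_config.RunConfig(
#       master='grpc://10.0.0.1:8470',
#       tpu_config=tpu_config.TPUConfig(tpu_job_name='tpu_worker'))
#
# With tpu_job_name set, this property returns it directly.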
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function."""
|
||||
|
||||
master = self.master_job
|
||||
|
||||
def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
|
||||
"""Return the host device given replica_id or host_id."""
|
||||
assert _sentinal is None
|
||||
if replica_id is not None and host_id is not None:
|
||||
raise RuntimeError(
|
||||
'replica_id and host_id can have only one non-None value.')
|
||||
|
||||
if master is None:
|
||||
return '/replica:0/task:0/device:CPU:0'
|
||||
else:
|
||||
if replica_id is not None:
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.host_device(
|
||||
replica=replica_id, job=master)
|
||||
else:
|
||||
host_id = replica_id / self.num_of_cores_per_host
|
||||
|
||||
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
@property
|
||||
def tpu_device_placement_function(self):
|
||||
"""Returns a TPU device placement Fn."""
|
||||
master = self.master_job
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
def _placement_function(i):
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.tpu_device(replica=i, job=master)
|
||||
else:
|
||||
num_of_cores_per_host = self.num_of_cores_per_host
|
||||
host_id = i / num_of_cores_per_host
|
||||
ordinal_id = i % num_of_cores_per_host
|
||||
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
def tpu_ordinal_function(self, host_id):
|
||||
"""Returns the TPU ordinal fn."""
|
||||
|
||||
def _tpu_ordinal_function(shard_index_in_host):
|
||||
"""Return the TPU ordinal associated with a shard.
|
||||
|
||||
Required because the enqueue ops are placed on CPU.
|
||||
|
||||
Args:
|
||||
shard_index_in_host: the shard index
|
||||
|
||||
Returns:
|
||||
The ordinal of the TPU device the shard's infeed should be placed on.
|
||||
"""
|
||||
if self.model_parallelism_enabled:
|
||||
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
|
||||
replica = self.device_assignment.lookup_replicas(host_id,
|
||||
0)[shard_index_in_host]
|
||||
return self.device_assignment.tpu_ordinal(replica=replica)
|
||||
else:
|
||||
return shard_index_in_host % self.num_of_cores_per_host
|
||||
|
||||
return _tpu_ordinal_function
|
||||
|
||||
def _validate_tpu_configuration(self):
|
||||
"""Validates the configuration based on the TPU system metadata."""
|
||||
mode = self._assert_mode()
|
||||
if self._lazy_validation_dict.get(mode):
|
||||
return
|
||||
|
||||
# All following information is obtained from TPU system metadata.
|
||||
num_cores = self.num_cores
|
||||
num_replicas = self.num_replicas
|
||||
num_hosts = self.num_hosts
|
||||
|
||||
if not num_cores:
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
raise RuntimeError(
|
||||
'Cannot find any TPU cores in the system. Please double check '
|
||||
'Tensorflow master address and TPU worker(s). Available devices '
|
||||
'are {}.'.format(tpu_system_metadata.devices))
|
||||
|
||||
if self._config.tpu_config.num_shards:
|
||||
user_provided_num_replicas = self._config.tpu_config.num_shards
|
||||
if user_provided_num_replicas != num_replicas:
|
||||
message = (
|
||||
'TPUConfig.num_shards is not set correctly. According to TPU '
|
||||
'system metadata for Tensorflow master ({}): num_replicas should '
|
||||
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
|
||||
'be the total num of TPU cores in the system. For '
|
||||
'model-parallelism, the total number of TPU cores should be '
|
||||
'num_cores_per_replica * num_replicas. Please set it '
|
||||
'accordingly or leave it as `None`'.format(
|
||||
self._get_master_address(), num_replicas,
|
||||
user_provided_num_replicas))
|
||||
|
||||
raise ValueError(message)
|
||||
|
||||
if self._config.tpu_config.num_cores_per_replica:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
|
||||
if num_cores_per_replica > num_cores_per_host:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the '
|
||||
'num_cores_per_host. num_cores_per_replica: {}, '
|
||||
'num_cores_per_host: {}'.format(num_cores_per_replica,
|
||||
num_cores_per_host))
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
if (self._train_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'train batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._train_batch_size, num_replicas))
|
||||
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
if self._eval_batch_size is None:
|
||||
raise ValueError(
|
||||
'eval_batch_size in TPUEstimator constructor cannot be `None` '
|
||||
'if .evaluate is running on TPU.')
|
||||
if (self._eval_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'eval batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._eval_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.evaluate should be running on single TPU'
|
||||
' instead of a Pod.')
|
||||
else:
|
||||
assert mode == model_fn_lib.ModeKeys.PREDICT
|
||||
if self._predict_batch_size is None:
|
||||
raise ValueError(
|
||||
'predict_batch_size in TPUEstimator constructor should not be '
|
||||
'`None` if .predict is running on TPU.')
|
||||
if (self._predict_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'predict batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._predict_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.predict should be running on single TPU worker. '
|
||||
'got {}.'.format(num_hosts))
|
||||
|
||||
# Record the state "validated" into lazy dictionary.
|
||||
self._lazy_validation_dict[mode] = True
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replication (non-model-parallelism).
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
master = self.master_job
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
return (self.device_assignment.host_device(
|
||||
replica=replica_id, job=master),
|
||||
self.device_assignment.tpu_ordinal(replica=replica_id))
|
||||
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
num_of_replicas_per_host = self.num_of_replicas_per_host
|
||||
host_id = replica_id / num_of_replicas_per_host
|
||||
ordinal_id = replica_id % num_of_replicas_per_host
|
||||
|
||||
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
|
||||
return (host_device, ordinal_id)
|
||||
|
||||
|
||||
class _OneCoreTPUContext(_InternalTPUContext):
|
||||
"""Special _InternalTPUContext for one core usage."""
|
||||
|
||||
def __init__(self, config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu):
|
||||
|
||||
super(_OneCoreTPUContext, self).__init__(
|
||||
config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
|
||||
num_cores=1,
|
||||
num_hosts=1,
|
||||
num_of_cores_per_host=1,
|
||||
topology=None,
|
||||
devices=[]))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
|
||||
def _get_tpu_context(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec):
|
||||
"""Returns an instance of `_InternalTPUContext`."""
|
||||
|
||||
if (config.tpu_config.num_shards == 1 and
|
||||
config.tpu_config.num_cores_per_replica is None):
|
||||
if embedding_config_spec is not None:
|
||||
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
|
||||
'when embedding_config_spec is not None.')
|
||||
logging.warning(
|
||||
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
|
||||
'Please fix as soon as possible (leaving num_shards as None.)')
|
||||
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec)
|
||||
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_context import *
# pylint: enable=wildcard-import,unused-import
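A minimal sketch of what the stub redirection above implies for callers, under
the assumption that both import paths remain available after the move; the
asserted symbol is only an example:

    # Hypothetical equivalence check: the contrib module now re-exports the
    # implementation that lives under tensorflow/python/tpu.
    from tensorflow.contrib.tpu.python.tpu import tpu_context as contrib_tpu_context
    from tensorflow.python.tpu import tpu_context as core_tpu_context

    assert contrib_tpu_context.TPUContext is core_tpu_context.TPUContext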
File diff suppressed because it is too large
@@ -1,153 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Optional helper for gradient handling."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
def get_gradients_through_compute_gradients(optimizer, loss, activations):
|
||||
"""Compute gradients to send to TPU embedding.
|
||||
|
||||
Args:
|
||||
optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
|
||||
Used to call compute_gradients().
|
||||
loss: a Tensor to call optimizer.compute_gradients() on.
|
||||
activations: an OrderedDict mapping feature_name to Tensors of activations.
|
||||
|
||||
Returns:
|
||||
An OrderedDict mapping from feature name Strings to Tensors of gradients of
|
||||
the loss wrt the activations of the features.
|
||||
"""
|
||||
activation_list = activations.values()
|
||||
grads_and_vars = optimizer.compute_gradients(loss, activation_list)
|
||||
grads = [grad for grad, _ in grads_and_vars]
|
||||
feature_to_gradient_dict = collections.OrderedDict(
|
||||
zip(activations.keys(), grads))
|
||||
return feature_to_gradient_dict
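# --- Illustrative note (not in the original file): a hedged usage sketch.
# `loss` and `activations` are assumed to come from the caller's model, e.g.
# an OrderedDict of embedding activations:
#
#   optimizer = tf.contrib.tpu.CrossShardOptimizer(
#       tf.train.AdagradOptimizer(0.1))
#   grads = get_gradients_through_compute_gradients(
#       optimizer, loss, activations)
#   # grads['feature_a'] is d(loss)/d(activations['feature_a'])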
|
||||
|
||||
|
||||
def create_dummy_table_variables(tpu_embedding):
|
||||
"""Create dummy embedding table variables.
|
||||
|
||||
The sole purpose of these dummy variables is to trigger gradient
calculation wrt them so that the gradients wrt activation can be captured
and later sent to TPU embedding.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, dummy table variables will be created for use
|
||||
with tpu_embedding.
|
||||
|
||||
Returns:
|
||||
A tuple of dummy variables and their initializer.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if collection to store gradients already exists and is not
|
||||
empty.
|
||||
"""
|
||||
dummy_table_variables = collections.OrderedDict()
|
||||
for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
|
||||
dummy_table_variables[table] = (
|
||||
# Explicitly specifying collections prevents this variable from
|
||||
# being added to the GLOBAL_VARIABLES collection, so that Saver()
|
||||
# ignores it.
|
||||
# But Tensorflow optimizer creates slot variable for these dummy
|
||||
# variable, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
|
||||
# which will be in GLOBAL_VARIABLES collection,
|
||||
variable_scope.get_variable(
|
||||
'tpu_embedding_dummy_table_variable_{}'.format(table),
|
||||
dtype=dtypes.float32,
|
||||
shape=[1],
|
||||
use_resource=True,
|
||||
trainable=True,
|
||||
collections=['tpu_embedding_dummy_table_variables']))
|
||||
|
||||
g = ops.get_default_graph()
|
||||
table_gradients = g.get_collection_ref(
|
||||
'tpu_embedding_gradients_table_{}'.format(table_id))
|
||||
if table_gradients:
|
||||
raise RuntimeError(
|
||||
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
|
||||
table_gradients.extend(
|
||||
[None] * len(tpu_embedding.table_to_features_dict[table]))
|
||||
|
||||
return (dummy_table_variables,
|
||||
variables.variables_initializer(
|
||||
dummy_table_variables.values(),
|
||||
name='tpu_embedding_dummy_table_variables_init'))
|
||||
|
||||
|
||||
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
|
||||
dummy_table_variables):
|
||||
"""Have activations depend on dummy table variables for gradient intercept.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
|
||||
tpu_embedding.
|
||||
activations: An OrderedDict of feature name String to activation tensors.
|
||||
dummy_table_variables: An OrderedDict of table name String to dummy table
|
||||
variables.
|
||||
|
||||
Returns:
|
||||
An OrderedDict of feature name String to activation tensors, which can be
|
||||
used just as the activations input.
|
||||
"""
|
||||
new_activations = collections.OrderedDict()
|
||||
for feature in activations:
|
||||
table = tpu_embedding.feature_to_table_dict[feature]
|
||||
new_activations[feature] = tpu_ops.tpu_embedding_activations(
|
||||
dummy_table_variables[table],
|
||||
activations[feature],
|
||||
table_id=tpu_embedding.table_to_config_dict.keys().index(table),
|
||||
lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
|
||||
return new_activations
|
||||
|
||||
|
||||
def get_gradients_through_dummy_table_variables(tpu_embedding):
|
||||
"""Get gradients wrt the activations of each feature.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, create dummy table variable to be used with
|
||||
tpu_embedding.
|
||||
|
||||
Returns:
|
||||
An OrderedDict mapping feature name to gradient.
|
||||
|
||||
Raises:
|
||||
ValueError: if some gradients are not defined.
|
||||
"""
|
||||
g = ops.get_default_graph()
|
||||
feature_to_gradient_dict = collections.OrderedDict()
|
||||
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
|
||||
table_gradients = g.get_collection(
|
||||
'tpu_embedding_gradients_table_{}'.format(table_id))
|
||||
if any(gradient is None for gradient in table_gradients):
|
||||
raise ValueError(
|
||||
'Table {} with id {} has undefined gradients: this is probably '
|
||||
'because the model asked TPUEmbedding to compute activations that '
|
||||
'were not used.'.format(table, table_id))
|
||||
for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
|
||||
table_gradients):
|
||||
feature_to_gradient_dict[feature] = gradient
|
||||
return feature_to_gradient_dict
|
||||
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_embedding_gradient import *
# pylint: enable=wildcard-import,unused-import
File diff suppressed because it is too large
@@ -1,919 +1,25 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Helper library for handling infeed between hosts and TPUs.
|
||||
"""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def partition_or_replicate_on_host(tensor, dims):
|
||||
"""Partitions or replicates the input tensor.
|
||||
|
||||
The ops inside this function are placed on the host side.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor which will be partitioned or replicated.
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Returns:
|
||||
An iterator of `Tensor`s or a list of partitioned tensors.
|
||||
"""
|
||||
if dims is None:
|
||||
return itertools.repeat(tensor)
|
||||
dims = np.array(dims)
|
||||
output = [tensor]
|
||||
shape_list = np.array(tensor.shape.as_list())
|
||||
quotients, remainders = np.divmod(shape_list, dims)
|
||||
for axis, (quotient, remainder, dim, original_size) in enumerate(
|
||||
zip(quotients, remainders, dims, shape_list)):
|
||||
if dim <= 1:
|
||||
continue
|
||||
if remainder > 0:
|
||||
# For each dimension, when it cannot be evenly partitioned, XLA assumes
|
||||
# tensors are partitioned in a greedy manner by using
|
||||
# ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
|
||||
# are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
|
||||
# [[(3, 4), (3, 4), (3, 4), (3, 2)],
|
||||
# [(2, 4), (2, 4), (2, 4), (2, 2)]]
|
||||
ceil_ratio = quotient + 1
|
||||
num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
|
||||
num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
|
||||
if len(num_or_size_splits) < dim:
|
||||
num_or_size_splits += [0] * (dim - len(num_or_size_splits))
|
||||
new_output = []
|
||||
for x in output:
|
||||
new_output.append(
|
||||
array_ops.split(
|
||||
x, num_or_size_splits=num_or_size_splits, axis=axis))
|
||||
output = new_output
|
||||
else:
|
||||
output = [array_ops.split(x, dim, axis=axis) for x in output]
|
||||
output = nest.flatten(output)
|
||||
return output
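# --- Illustrative note (not in the original file): a worked example of the
# greedy split sizes computed above for a (5, 14) tensor with dims = (2, 4):
#
#   axis 0: divmod(5, 2) -> (2, 1); ceil_ratio = 3; divmod(5, 3) -> (1, 2)
#           so num_or_size_splits = [3, 2]
#   axis 1: divmod(14, 4) -> (3, 2); ceil_ratio = 4; divmod(14, 4) -> (3, 2)
#           so num_or_size_splits = [4, 4, 4, 2]
#
# yielding eight partitions of shapes (3, 4), (3, 4), (3, 4), (3, 2),
# (2, 4), (2, 4), (2, 4), (2, 2).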
|
||||
|
||||
|
||||
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
|
||||
"""Tags appropriate XLA sharding attribute to the dequeued tensor.
|
||||
|
||||
Args:
|
||||
tensor: The dequeued tensor on TPU.
|
||||
dims: A list of integers describing how the tensor is partitioned.
|
||||
|
||||
Returns:
|
||||
The same tensor with the xla_sharding attribute.
|
||||
"""
|
||||
if dims is None:
|
||||
return xla_sharding.replicate(tensor)
|
||||
elif np.prod(dims) == 1:
|
||||
return xla_sharding.assign_device(tensor, 0)
|
||||
else:
|
||||
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
|
||||
return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
|
||||
|
||||
|
||||
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
|
||||
"""Tags appropriate XLA sharding attribute to the dequeued tensors.
|
||||
|
||||
Args:
|
||||
dequeues: A list of dequeued tensors on TPU.
|
||||
dims: A list of integers describing how the tensor is partitioned.
|
||||
|
||||
Returns:
|
||||
The same dequeues with appropriate xla_sharding attribute.
|
||||
"""
|
||||
nest.assert_shallow_structure(dequeues, dims)
|
||||
return nest.map_structure_up_to(
|
||||
dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
|
||||
|
||||
|
||||
class InfeedQueue(object):
|
||||
"""A helper object to build a device infeed queue.
|
||||
|
||||
The InfeedQueue builds the host-side and device-side Ops to enqueue and
|
||||
dequeue elements, respectively, and ensures that their types and
|
||||
shapes match.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
number_of_tuple_elements=None,
|
||||
tuple_types=None,
|
||||
tuple_shapes=None,
|
||||
shard_dimensions=None,
|
||||
name=None):
|
||||
"""Creates a new InfeedQueue with the given configuration.
|
||||
|
||||
The configuration need not be fully specified at creation since it
|
||||
can be modified subsequently by methods that set the values
|
||||
explicitly or infer them from the shapes of inputs.
|
||||
|
||||
Args:
|
||||
number_of_tuple_elements: the number of Tensors fed atomically through the
|
||||
queue, must be present unless it can be inferred from other arguments.
|
||||
tuple_types: if not None, a list of types of the elements of the queue.
|
||||
tuple_shapes: if not None, a list of shapes of the elements of the queue.
|
||||
shard_dimensions: if not None, a list of dimensions on which the
|
||||
elements of the queue should be sharded during automatic
|
||||
parallelization.
|
||||
name: the name of the queue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_tuple_elements <= 0; or
|
||||
number_of_tuple_elements, tuple_types, tuple_shapes, and
|
||||
shard_dimensions are all None; or the length of tuple_types,
|
||||
tuple_shapes, or shard_dimensions is not equal to
|
||||
number_of_tuple_elements; or any element of shard_dimensions
|
||||
can't be converted to a Dimension.
|
||||
TypeError: if any element of tuple_types or tuple_shapes can't
|
||||
be converted to a dtype or TensorShape, respectively.
|
||||
"""
|
||||
self._frozen = False
|
||||
self._generated_enqueue_ops = False
|
||||
self._generated_dequeue_op = False
|
||||
self._name = "InfeedQueue" if name is None else name
|
||||
if number_of_tuple_elements is None:
|
||||
if tuple_types is not None:
|
||||
number_of_tuple_elements = len(tuple_types)
|
||||
elif tuple_shapes is not None:
|
||||
number_of_tuple_elements = len(tuple_shapes)
|
||||
elif shard_dimensions is not None:
|
||||
number_of_tuple_elements = len(shard_dimensions)
|
||||
else:
|
||||
raise ValueError(
|
||||
"number of tuple elements cannot be inferred from InfeedQueue "
|
||||
"constructor")
|
||||
if number_of_tuple_elements <= 0:
|
||||
raise ValueError("number_of_tuple_elements %d must be > 0" %
|
||||
number_of_tuple_elements)
|
||||
# Make an empty sharding policy for each tuple element.
|
||||
self._sharding_policies = [
|
||||
tpu_sharding.ShardingPolicy()
|
||||
for _ in xrange(number_of_tuple_elements)
|
||||
]
|
||||
if tuple_types is not None:
|
||||
self.set_tuple_types(tuple_types)
|
||||
else:
|
||||
self._tuple_types = None
|
||||
if tuple_shapes is not None:
|
||||
self.set_tuple_shapes(tuple_shapes)
|
||||
else:
|
||||
self._tuple_shapes = None
|
||||
if shard_dimensions is not None:
|
||||
self.set_shard_dimensions(shard_dimensions)
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
"""Checks that the configuration is self-consistent.
|
||||
|
||||
Raises:
|
||||
ValueError: if the shapes and sharding policies don't match.
|
||||
"""
|
||||
if self.tuple_shapes is not None:
|
||||
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
|
||||
# Raise an error if the policy is incompatible with the shape.
|
||||
_ = policy.get_sharded_shape(shape)
|
||||
|
||||
@property
|
||||
def number_of_tuple_elements(self):
|
||||
"""Returns the number of InfeedQueue tuple elements."""
|
||||
return len(self._sharding_policies)
|
||||
|
||||
@property
|
||||
def tuple_types(self):
|
||||
"""Returns the types of the InfeedQueue tuple elements."""
|
||||
return self._tuple_types
|
||||
|
||||
def set_tuple_types(self, tuple_types):
|
||||
"""Sets the type of each element of the queue.
|
||||
|
||||
tuple_types must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a dtype.
|
||||
|
||||
Args:
|
||||
tuple_types: the types of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_types is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_types cannot be converted to a
|
||||
dtype.
|
||||
"""
|
||||
if len(tuple_types) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_types is %s, but must be a list of length %d" %
|
||||
(str(tuple_types), self.number_of_tuple_elements))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_types, tuple_types):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible type. Frozen types are %s, updated types are %s" % (
|
||||
str(self._tuple_types), str(tuple_types)))
|
||||
else:
|
||||
try:
|
||||
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
|
||||
except (TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_types is %s, but must be a list of elements each "
|
||||
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
|
||||
|
||||
@property
|
||||
def tuple_shapes(self):
|
||||
"""Returns the shapes of the InfeedQueue tuple elements."""
|
||||
return self._tuple_shapes
|
||||
|
||||
def set_tuple_shapes(self, tuple_shapes):
|
||||
"""Sets the shape of each element of the queue.
|
||||
|
||||
tuple_shapes must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a TensorShape.
|
||||
|
||||
Args:
|
||||
tuple_shapes: the shapes of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_shapes is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_shapes cannot be converted to
|
||||
a TensorShape.
|
||||
"""
|
||||
if len(tuple_shapes) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
|
||||
(str(tuple_shapes), self.number_of_tuple_elements))
|
||||
try:
|
||||
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_shapes is %s, but must be a list of elements each "
|
||||
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
|
||||
str(e)))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
|
||||
% (str(self._tuple_shapes), str(tuple_shapes)))
|
||||
else:
|
||||
self._tuple_shapes = tuple_shapes
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def sharding_policies(self):
|
||||
"""Returns the sharding policies of the InfeedQueue tuple elements."""
|
||||
return self._sharding_policies
|
||||
|
||||
@property
|
||||
def shard_dimensions(self):
|
||||
"""Gets the shard dimension of each tuple element.
|
||||
|
||||
Returns:
|
||||
A list of length number_of_tuple_elements, where each list entry
|
||||
is the shard dimension of that tuple element or None if the
|
||||
shard dimension has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return [policy.shard_dimension for policy in self._sharding_policies]
|
||||
|
||||
def set_shard_dimensions(self, shard_dimensions):
|
||||
"""Sets the shard_dimension of each element of the queue.
|
||||
|
||||
shard_dimensions must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a Dimension compatible with self.tuple_shapes.
|
||||
|
||||
Args:
|
||||
shard_dimensions: the dimensions of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if shard_dimensions is not of length
|
||||
self.number_of_tuple_elements; or an element of
|
||||
shard_dimensions cannot be converted to a Dimension; or an
|
||||
element of shard_dimensions is a Dimension that is out of
|
||||
range for the corresponding tuple element shape.
|
||||
"""
|
||||
if len(shard_dimensions) != self.number_of_tuple_elements:
|
||||
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
|
||||
% (str(shard_dimensions),
|
||||
self.number_of_tuple_elements))
|
||||
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
|
||||
policy.set_shard_dimension(dimension)
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
"""Gets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Returns:
|
||||
Number of shards or None if the number of shards has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return self._sharding_policies[0].number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
"""Sets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Args:
|
||||
number_of_shards: number of ways to shard the InfeedQueue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_shards is not > 0; or the policies have
|
||||
been frozen and number_of_shards was already set to something
|
||||
else.
|
||||
"""
|
||||
for policy in self._sharding_policies:
|
||||
policy.set_number_of_shards(number_of_shards)
|
||||
self._validate()
|
||||
|
||||
def set_configuration_from_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of Tensors whose types and shapes are used
|
||||
to set the queue configuration.
|
||||
|
||||
Args:
|
||||
input_tensors: list of Tensors of the same types and shapes as
|
||||
the desired queue Tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: if input_tensors is not a list of length
|
||||
self.number_of_tuple_elements
|
||||
"""
|
||||
if len(input_tensors) != self.number_of_tuple_elements:
|
||||
raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
|
||||
% (str(input_tensors), self.number_of_tuple_elements))
|
||||
self.set_tuple_shapes([t.shape for t in input_tensors])
|
||||
self.set_tuple_types([t.dtype for t in input_tensors])
|
||||
|
||||
def set_configuration_from_sharded_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of lists of Tensors whose types and shapes are used
|
||||
to set the queue configuration. The length of the outer list is the number
|
||||
of shards required, and each inner list is the tuple of Tensors to use to
|
||||
determine the types and shapes of the corresponding shard. This method
|
||||
depends on the shard dimension, and calling it freezes the shard policy.
|
||||
|
||||
Args:
|
||||
input_tensors: list of lists of Tensors. The outer list length corresponds
|
||||
to the desired number of shards, and each inner list is the size
|
||||
and shape of the desired configuration of the corresponding shard.
|
||||
|
||||
Raises:
|
||||
ValueError: if any inner list is not a list of length
|
||||
self.number_of_tuple_elements; or the inner lists do not combine to
|
||||
form a consistent unsharded shape.
|
||||
TypeError: if the types of the Tensors in the inner lists do not match.
|
||||
"""
|
||||
if not self._frozen:
|
||||
# Unset the tuple shapes in case the configuration becomes
|
||||
# transiently inconsistent.
|
||||
self._tuple_shapes = None
|
||||
number_of_shards = len(input_tensors)
|
||||
self.set_number_of_shards(number_of_shards)
|
||||
for t in input_tensors:
|
||||
if len(t) != self.number_of_tuple_elements:
|
||||
raise ValueError(
|
||||
"input_tensors is %s but must be a list of lists, where each inner"
|
||||
" list has length number_of_tuple_elements=%d" % (
|
||||
str(input_tensors), self.number_of_tuple_elements))
|
||||
# Transpose the inputs to make a list of shard shapes for each tuple
|
||||
# element.
|
||||
sharded_shapes = [[t[i].shape for t in input_tensors]
|
||||
for i in xrange(self.number_of_tuple_elements)]
|
||||
# For each tuple, get the unsharded shape using that tuple's policy.
|
||||
unsharded_shapes = [
|
||||
policy.get_unsharded_shape(s)
|
||||
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
|
||||
]
|
||||
self.set_tuple_shapes(unsharded_shapes)
|
||||
for i in xrange(1, self.number_of_shards):
|
||||
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
|
||||
if t1.dtype != t2.dtype:
|
||||
raise TypeError(
|
||||
"types of the tuple elements of input_tensors %s are not "
|
||||
"consistent" % str(input_tensors))
|
||||
self.set_tuple_types([t.dtype for t in input_tensors[0]])
|
||||
|
||||
def freeze(self):
|
||||
"""Freezes the InfeedQueue so it can no longer be modified.
|
||||
|
||||
The configuration is implicitly frozen before any host-side or
|
||||
device-side Ops are generated. The configuration cannot be frozen
|
||||
until the types and shapes of the tuple elements have been set.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set.
|
||||
"""
|
||||
self._frozen = True
|
||||
if self._tuple_types is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple types.")
|
||||
if self._tuple_shapes is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for shape in self._tuple_shapes:
|
||||
if shape.dims is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for policy in self._sharding_policies:
|
||||
policy.freeze()
|
||||
self._validate()
|
||||
|
||||
def generate_dequeue_op(self, tpu_device=0):
|
||||
"""Generates the device-side Op to dequeue a tuple from the queue.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen, which will raise errors if the shapes and types have not
|
||||
been fully specified.
|
||||
|
||||
Args:
|
||||
tpu_device: The TPU device ordinal where the infeed instruction should be
|
||||
placed. If None, no explicit placement will be performed, and it is up
|
||||
to the user to call this API from within a proper TPU device scope.
|
||||
The XLA code will fail if the TPU dequeue instruction is not bound to
|
||||
any device.
|
||||
|
||||
Returns:
|
||||
A list of Outputs corresponding to a shard of infeed dequeued
|
||||
into XLA, suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set; or if a dequeue op has already been generated.
|
||||
"""
|
||||
self.freeze()
|
||||
if self._generated_dequeue_op:
|
||||
raise ValueError("Can't generate two dequeue Ops from the same queue")
|
||||
self._generated_dequeue_op = True
|
||||
full_name = "%s/dequeue" % self._name
|
||||
sharded_shapes = [
|
||||
policy.get_sharded_shape(shape)
|
||||
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
|
||||
]
|
||||
if tpu_device is not None:
|
||||
with ops.device(tpu.core(tpu_device)):
|
||||
return tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
else:
|
||||
return tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
|
||||
def _generate_enqueue_op(self,
|
||||
inputs,
|
||||
name_prefix,
|
||||
index,
|
||||
device=None,
|
||||
tpu_ordinal=-1):
|
||||
"""Generate a host-side Op to enqueue a tuple to the queue.
|
||||
|
||||
If device is None the inputs are all required to have the same
|
||||
device specification, and the enqueue Op is colocated with
|
||||
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors with the types and shapes of the tuple elements.
|
||||
name_prefix: the base name for the Op.
|
||||
index: the shard index, used to uniquify the Op name.
|
||||
device: device to place the Op on, or None if it should be
|
||||
colocated with the inputs.
|
||||
tpu_ordinal: ordinal of the TPU device on the host to use for
|
||||
infeed if device is a CPU device. Should be set to -1 if device
|
||||
is a TPU device.
|
||||
|
||||
Returns:
|
||||
An Op corresponding to a shard of infeed enqueued at the host,
|
||||
suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if device is None and inputs do not all have the
|
||||
same device specification.
|
||||
"""
|
||||
full_name = "%s/%d" % (name_prefix, index)
|
||||
shapes = [t.shape for t in inputs]
|
||||
if device is None:
|
||||
devices = [t.device for t in inputs]
|
||||
for i in xrange(1, self.number_of_tuple_elements):
|
||||
if devices[0] != devices[i]:
|
||||
raise ValueError(
|
||||
"input devices for shard %d are %s, but should all be the same" %
|
||||
(index, str(devices)))
|
||||
with ops.colocate_with(inputs[0]):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
else:
|
||||
with ops.device(device):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
|
||||
def generate_enqueue_ops(self,
|
||||
sharded_inputs,
|
||||
tpu_ordinal_function=None,
|
||||
placement_function=None):
|
||||
"""Generates the host-side Ops to enqueue the shards of a tuple.
|
||||
|
||||
sharded_inputs is a list, one for each shard, of lists of
|
||||
Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
|
||||
shard 0 of the queue. Returns the host-side Ops that must be run to
|
||||
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
|
||||
for shard i.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of sharded_inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
sharded_inputs: a list of lists of Tensors. The length of the outer list
|
||||
determines the number of shards. Each inner list indicates the types
|
||||
and shapes of the tuples in the corresponding shard.
|
||||
tpu_ordinal_function: if not None, a function that takes the
|
||||
shard index as input and returns the ordinal of the TPU device
|
||||
the shard's infeed should be placed on. tpu_ordinal_function must be
|
||||
set if the inputs are placed on CPU devices.
|
||||
placement_function: if not None, a function that takes the shard index as
|
||||
input and returns the host device on which the enqueue op should be
placed.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the shapes of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple; or if the elements of a tuple
|
||||
have different device constraints.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the types of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple.
|
||||
"""
|
||||
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = lambda index: -1
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(
|
||||
shard,
|
||||
name_prefix,
|
||||
index,
|
||||
tpu_ordinal=tpu_ordinal_function(index),
|
||||
device=placement_function(index) if placement_function else None)
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
|
||||
|
||||
# TODO(misard) Generalize this to the case of systems that don't
|
||||
# have 8 devices per host, and figure out what to do with
|
||||
# model-parallelism.
|
||||
def _default_placement_function(self, index):
|
||||
return "/task:%d/device:CPU:0" % (index / 8)
|
||||
|
||||
def _default_ordinal_function(self, index):
|
||||
return index % 8
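# --- Illustrative note (not in the original file): with the 8-devices-per-host
# assumption above, shard index 11 maps to
#
#   _default_placement_function(11) -> "/task:1/device:CPU:0"   # 11 // 8 == 1
#   _default_ordinal_function(11)   -> 3                        # 11 % 8
#
# i.e. the enqueue op runs on host task 1 and feeds TPU ordinal 3.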
|
||||
|
||||
# TODO(b/36470756) remove this from tutorials once we have a better story
|
||||
# for automatic placement of input pipelines.
|
||||
def split_inputs_and_generate_enqueue_ops(self,
|
||||
inputs,
|
||||
device_assignment=None,
|
||||
placement_function=None,
|
||||
tpu_ordinal_function=None):
|
||||
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
|
||||
|
||||
Generates the host-side Ops to enqueue a tuple.
|
||||
|
||||
This method performs poorly because it takes an entire input on a single
|
||||
host, splits it, and distributes it to all of the cores. It is present only
|
||||
to simplify tutorial examples.
|
||||
|
||||
inputs is a list of Tensors to use to feed the queue. Each input is split
|
||||
into self.number_of_shards shards. Returns an Op for each shard to enqueue
|
||||
the shard. The Op for shard i is placed on device placement_function(i).
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors which indicates the types and shapes of the
|
||||
queue tuple.
|
||||
device_assignment: if not `None`, a TPU `DeviceAssignment`. If
|
||||
device_assignment is not `None`, but `placement_function` and
|
||||
`ordinal_function` are None, then `device_assignment` will be used to
|
||||
place infeeds on the first k TPU shards, where k is the number of shards
|
||||
in the queue. If all three are `None`, then default placement and
|
||||
ordinal functions are used.
|
||||
placement_function: if not None, a function that takes the shard
|
||||
index as input and returns a device string indicating which
|
||||
device the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
tpu_ordinal_function: if not None, a function that takes the
|
||||
shard index as input and returns the ordinal of the TPU device
|
||||
the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
"""
|
||||
if device_assignment is None:
|
||||
if placement_function is None:
|
||||
placement_function = self._default_placement_function
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = self._default_ordinal_function
|
||||
else:
|
||||
|
||||
def _placement_function_from_map(index):
|
||||
return device_assignment.host_device(replica=index)
|
||||
|
||||
def _ordinal_function_from_map(index):
|
||||
return device_assignment.tpu_ordinal(replica=index)
|
||||
|
||||
if placement_function is None:
|
||||
placement_function = _placement_function_from_map
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = _ordinal_function_from_map
|
||||
self.set_configuration_from_input_tensors(inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
split_name_prefix = "%s/split" % self._name
|
||||
if self.number_of_shards == 1:
|
||||
transposed_sharded_inputs = [[inp] for inp in inputs]
|
||||
else:
|
||||
|
||||
def split_fn(inp, num_shards, axis, name):
|
||||
with ops.colocate_with(inp):
|
||||
return array_ops.split(inp, num_shards, axis=axis, name=name)
|
||||
|
||||
transposed_sharded_inputs = [
|
||||
split_fn(
|
||||
inp,
|
||||
self.number_of_shards,
|
||||
axis=policy.shard_dimension,
|
||||
name="%s/%d" % (split_name_prefix, index))
|
||||
for (inp, policy, index) in zip(inputs, self._sharding_policies,
|
||||
xrange(self.number_of_tuple_elements))
|
||||
]
|
||||
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
|
||||
for i in xrange(self.number_of_shards)]
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(
|
||||
shard,
|
||||
name_prefix,
|
||||
index,
|
||||
device=placement_function(index),
|
||||
tpu_ordinal=tpu_ordinal_function(index))
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
|
||||
|
||||
|
||||
class _PartitionedInfeedQueue(InfeedQueue):
|
||||
"""A helper object to build a device infeed queue with input partition.
|
||||
|
||||
Args:
|
||||
number_of_tuple_elements: the number of Tensors fed atomically through the
|
||||
queue, must be present unless it can be inferred from other arguments.
|
||||
device_assignment: A TPU `DeviceAssignment` which is used to place all the
|
||||
partitions to different TPU infeed queues.
|
||||
host_id: The id of the host machine.
|
||||
input_partition_dims: A nested list/tuple of integers. Each inner
|
||||
list/tuple describes how to partition the corresponding input tensor.
|
||||
tuple_types: If not None, a list of types of the elements of the queue.
|
||||
tuple_shapes: If not None, a list of shapes of the elements of the queue.
|
||||
name: The name of the queue.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
number_of_tuple_elements,
|
||||
device_assignment,
|
||||
host_id,
|
||||
input_partition_dims=None,
|
||||
tuple_types=None,
|
||||
tuple_shapes=None,
|
||||
name=None):
|
||||
super(_PartitionedInfeedQueue, self).__init__(
|
||||
number_of_tuple_elements=number_of_tuple_elements,
|
||||
tuple_types=tuple_types,
|
||||
tuple_shapes=None,
|
||||
shard_dimensions=None,
|
||||
name="PartitionedInfeedQueue" if name is None else name)
|
||||
self._input_partition_dims = input_partition_dims
|
||||
self._host_id = host_id
|
||||
self._device_assignment = device_assignment
|
||||
|
||||
def generate_dequeue_op(self, tpu_device=0):
|
||||
"""Generate TPU dequeue ops.
|
||||
|
||||
Args:
|
||||
tpu_device: The TPU device ordinal where the infeed instruction should be
|
||||
placed.
|
||||
|
||||
Returns:
|
||||
A list of Outputs corresponding to a partition of infeed dequeued
|
||||
into XLA, suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set; or if a dequeue op has already been generated.
|
||||
"""
|
||||
self.freeze()
|
||||
if self._generated_dequeue_op:
|
||||
raise ValueError("Can't generate two dequeue Ops from the same queue")
|
||||
self._generated_dequeue_op = True
|
||||
full_name = "%s/dequeue" % self._name
|
||||
sharded_shapes = [
|
||||
policy.get_sharded_shape(shape)
|
||||
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
|
||||
]
|
||||
with ops.device(tpu.core(tpu_device)):
|
||||
values = tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
return tag_sharding_attribute_for_dequeued_tensors(
|
||||
values, self._input_partition_dims)
|
||||
|
||||
def generate_enqueue_ops(self, per_host_sharded_inputs):
|
||||
"""Generates the host-side Ops to enqueue the partitioned inputs.
|
||||
|
||||
per_host_sharded_inputs is a list, one for each replica, of lists of
|
||||
Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
replica i.
per_host_sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
[[A, B, C, D],
[E, F, G, H]]
self._input_partition_dims[j] is [2, 4].

per_host_sharded_inputs[i][j] will be partitioned and flattened into:
[A, B, C, D, E, F, G, H] and fed into the logical core ids:
[0, 1, 2, 3, 4, 5, 6, 7] respectively.
|
||||
|
||||
Args:
|
||||
per_host_sharded_inputs: a list of lists of Tensors. The length of the
|
||||
outer list determines the number of shards. Each inner list indicates
|
||||
the types and shapes of the tuples in the corresponding shard.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the shapes of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple; or if the elements of a tuple
|
||||
have different device constraints; or if the partition dims are invalid.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the types of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple.
|
||||
"""
|
||||
self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
|
||||
number_of_replicas_per_host = len(per_host_sharded_inputs)
|
||||
number_of_tuple_elements = len(per_host_sharded_inputs[0])
|
||||
|
||||
assert len(self._input_partition_dims) == number_of_tuple_elements
|
||||
per_host_enqueue_ops = []
|
||||
|
||||
for replica_index in range(number_of_replicas_per_host):
|
||||
flattened_inputs = per_host_sharded_inputs[replica_index]
|
||||
inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
|
||||
self._input_partition_dims)
|
||||
inputs_parted_iters = [
|
||||
iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
|
||||
for x, dims in zip(per_host_sharded_inputs[replica_index],
|
||||
inputs_part_dims_flat)
|
||||
]
|
||||
|
||||
for logical_core in xrange(self._device_assignment.num_cores_per_replica):
|
||||
# Places different partitions on different logical cores.
|
||||
replica_id = self._device_assignment.lookup_replicas(
|
||||
self._host_id, logical_core)[replica_index]
|
||||
ordinal = self._device_assignment.tpu_ordinal(
|
||||
replica=replica_id, logical_core=logical_core)
|
||||
infeed_inputs = []
|
||||
for it in inputs_parted_iters:
|
||||
input_for_device = next(it, None)
|
||||
if input_for_device is not None:
|
||||
infeed_inputs.append(input_for_device)
|
||||
|
||||
if infeed_inputs:
|
||||
per_host_enqueue_ops.append(
|
||||
tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=infeed_inputs,
|
||||
shapes=[x.shape for x in infeed_inputs],
|
||||
name="enqueue/replica_{0}/input_{1}".format(
|
||||
replica_index, logical_core),
|
||||
device_ordinal=ordinal))
|
||||
return per_host_enqueue_ops
|
||||
|
||||
def _check_input_partition_dims(self, tensor, dims):
|
||||
"""Checks that input partition dims are valid for the `Tensor`.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor for partitioning.
|
||||
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tensor can't be partitioned by dims or the
|
||||
num_cores_per_replica doesn't match the number of
|
||||
partitions(dims.prod()).
|
||||
"""
|
||||
# No partitioning specified, so don't perform further checks.
|
||||
if dims is None:
|
||||
return
|
||||
|
||||
dims = np.array(dims)
|
||||
|
||||
if (dims < 1).any():
|
||||
raise ValueError("All input partition dims must be >= 1.")
|
||||
|
||||
# No partitioning, so don't perform further checks.
|
||||
if dims.prod() == 1:
|
||||
return
|
||||
|
||||
if dims.prod() != self._device_assignment.num_cores_per_replica:
|
||||
raise ValueError(
|
||||
"The product of each input parition dim should equal to "
|
||||
"num_cores_per_replica. (dim = {}, num_cores_per_replica "
|
||||
"= {})".format(dims, self._device_assignment.num_cores_per_replica))
|
||||
if dims.shape[0] != tensor.shape.ndims:
|
||||
raise ValueError(
|
||||
"Input partition dims must have the same number of dimensions "
|
||||
"as the `Tensor` to be partitioned. (tensor shape = {}, input "
|
||||
"partition dims = {}).".format(tensor.shape.as_list(), dims))
|
||||
|
||||
tensor.shape.assert_is_fully_defined()
|
||||
|
||||
def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
|
||||
"""Checks dims and partitions or replicates the input tensor.
|
||||
|
||||
The ops inside this function are placed on the host side.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor which will be partitioned or replicated.
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Returns:
|
||||
An iterator of `Tensor`s or a list of partitioned tensors.
|
||||
"""
|
||||
self._check_input_partition_dims(tensor, dims)
|
||||
return partition_or_replicate_on_host(tensor, dims)
|
||||
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.tpu_feed import *
# used by tests
from tensorflow.python.tpu.tpu_feed import _PartitionedInfeedQueue
# pylint: enable=wildcard-import,unused-import,redefined-builtin
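A hedged usage sketch for the InfeedQueue defined above; the tensor names
(features_0, labels_0, ...) and the two-shard layout are assumptions for
illustration only:

    from tensorflow.python.tpu import tpu_feed

    queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    # One inner list per shard; each carries a (features, labels) tuple.
    enqueue_ops = queue.generate_enqueue_ops(
        [[features_0, labels_0], [features_1, labels_1]],
        tpu_ordinal_function=lambda index: index)
    # Inside the replicated TPU computation:
    features, labels = queue.generate_dequeue_op(tpu_device=0)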
@@ -1,66 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper library for functions used during TPU compilation."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
class TpuContext(object):
|
||||
"""A context object holding state about the TPU computation being built."""
|
||||
|
||||
def __init__(self):
|
||||
"""Creates a new TpuContext."""
|
||||
self._number_of_shards = None
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
return self._number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
self._number_of_shards = number_of_shards
|
||||
|
||||
|
||||
# The Tpu context holds the number of shards when a sharded computation is
|
||||
# being built, or None if no computation is being built.
|
||||
_current_tpu_context = TpuContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tpu_shard_context(number_of_shards):
|
||||
if _current_tpu_context.number_of_shards is not None:
|
||||
raise NotImplementedError("tpu_shard_context cannot be nested.")
|
||||
try:
|
||||
_current_tpu_context.set_number_of_shards(number_of_shards)
|
||||
yield
|
||||
finally:
|
||||
_current_tpu_context.set_number_of_shards(None)
|
||||
|
||||
|
||||
def get_tpu_context():
|
||||
return _current_tpu_context
|
||||
|
||||
|
||||
# Decorator function for tpu computation func that was passed to tpu.rewrite()
|
||||
# if there is an embedded training loop in this func, trace tools will generate
|
||||
# step markers for each iteration.
|
||||
def on_device_training_loop(func):
|
||||
# Value for this attribute is from xla.DebugOptions.StepMarkerLocation.
|
||||
setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")
|
||||
return func
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.tpu_function import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
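# A minimal sketch (not part of this diff) of how the shard context above is
# used: tpu.replicate()/tpu.shard() enter tpu_shard_context so that helpers
# such as CrossShardOptimizer can read the shard count while the graph is
# being built. The import path assumes the new tensorflow.python.tpu location.
from tensorflow.python.tpu import tpu_function

with tpu_function.tpu_shard_context(8):
  assert tpu_function.get_tpu_context().number_of_shards == 8
# Outside the context the shard count is cleared again.
assert tpu_function.get_tpu_context().number_of_shards is None
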
@ -1,203 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Optimizer that implements cross-shard gradient reduction for TPU."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import optimizer
|
||||
|
||||
|
||||
class CrossShardOptimizer(optimizer.Optimizer):
|
||||
"""An optimizer that averages gradients across TPU shards."""
|
||||
|
||||
def __init__(self,
|
||||
opt,
|
||||
reduction=losses.Reduction.MEAN,
|
||||
name="CrossShardOptimizer",
|
||||
group_assignment=None):
|
||||
"""Construct a new cross-shard optimizer.
|
||||
|
||||
Args:
|
||||
opt: An existing `Optimizer` to encapsulate.
|
||||
reduction: The reduction to apply to the shard losses.
|
||||
name: Optional name prefix for the operations created when applying
|
||||
gradients. Defaults to "CrossShardOptimizer".
|
||||
group_assignment: Optional 2d int32 lists with shape
[num_groups, num_replicas_per_group] which describes how to apply the
optimizer to subgroups.
|
||||
|
||||
Raises:
|
||||
ValueError: If reduction is not a valid cross-shard reduction.
|
||||
"""
|
||||
if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
|
||||
raise ValueError("Unsupported reduction: %s." % reduction)
|
||||
|
||||
super(CrossShardOptimizer, self).__init__(False, name)
|
||||
self._opt = opt
|
||||
self._reduction = reduction
|
||||
self._group_assignment = group_assignment
|
||||
|
||||
def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
|
||||
"""Verify group_assignment and get the subgroup size".
|
||||
|
||||
Args:
|
||||
group_assignment: list of group ids for applying the optimizer
|
||||
to subgroups.
|
||||
num_shards: The number of TPU shards.
|
||||
|
||||
Returns:
|
||||
The size of one subgroup in group_assignment.
|
||||
|
||||
Raises:
|
||||
ValueError: If group_assignment is invalid.
|
||||
"""
|
||||
if not group_assignment:
|
||||
return None
|
||||
if not (isinstance(group_assignment, list) and
|
||||
all(isinstance(i, list) for i in group_assignment)):
|
||||
raise ValueError("group_assignment must be a list of list. Got {}".format(
|
||||
group_assignment))
|
||||
|
||||
replica_ids = set()
|
||||
for g in group_assignment:
|
||||
for i in g:
|
||||
replica_ids.add(i)
|
||||
|
||||
if set(range(num_shards)) != replica_ids:
|
||||
raise ValueError("group_assignment must be a permutation of range({0})."
|
||||
" Got group_assignment={1}".format(
|
||||
num_shards, group_assignment))
|
||||
|
||||
subgroup_size_list = [len(group) for group in group_assignment]
|
||||
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
|
||||
return subgroup_size_list[0]
|
||||
else:
|
||||
raise ValueError("The size of each subgroup in group_assignment must "
|
||||
"be equal. Got group_assignment={}".format(
|
||||
self._group_assignment))
|
||||
|
||||
def compute_gradients(self, loss, var_list=None, **kwargs):
|
||||
"""Compute gradients of "loss" for the variables in "var_list".
|
||||
|
||||
This simply wraps the compute_gradients() from the real optimizer. The
|
||||
gradients will be aggregated in the apply_gradients() so that user can
|
||||
modify the gradients like clipping with per replica global norm if needed.
|
||||
The global norm with aggregated gradients can be bad as one replica's huge
|
||||
gradients can hurt the gradients from other replicas.
|
||||
|
||||
Args:
|
||||
loss: A Tensor containing the value to minimize.
|
||||
var_list: Optional list or tuple of `tf.Variable` to update to minimize
|
||||
`loss`. Defaults to the list of variables collected in the graph
|
||||
under the key `GraphKey.TRAINABLE_VARIABLES`.
|
||||
**kwargs: Keyword arguments for compute_gradients().
|
||||
|
||||
Returns:
|
||||
A list of (gradient, variable) pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If not within a tpu_shard_context or group_assignment is
|
||||
invalid.
|
||||
"""
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
logging.warning(
|
||||
"CrossShardOptimizer should be used within a tpu_shard_context, but "
|
||||
"got unset number_of_shards. Assuming 1.")
|
||||
num_shards = 1
|
||||
|
||||
subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
|
||||
num_shards)
|
||||
|
||||
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
|
||||
if self._group_assignment:
|
||||
scale = 1.0 / subgroup_size
|
||||
else:
|
||||
scale = 1.0 / num_shards
|
||||
loss *= scale
|
||||
|
||||
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
"""Apply gradients to variables.
|
||||
|
||||
Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
|
||||
replicas, and then applies the real optimizer.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs as returned by
|
||||
compute_gradients().
|
||||
global_step: Optional Variable to increment by one after the
|
||||
variables have been updated.
|
||||
name: Optional name for the returned operation. Default to the
|
||||
name passed to the Optimizer constructor.
|
||||
|
||||
Returns:
|
||||
An `Operation` that applies the gradients. If `global_step` was not None,
|
||||
that operation also increments `global_step`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the grads_and_vars is malformed.
|
||||
"""
|
||||
summed_grads_and_vars = []
|
||||
for (grad, var) in grads_and_vars:
|
||||
if grad is None:
|
||||
summed_grads_and_vars.append((grad, var))
|
||||
else:
|
||||
with ops.colocate_with(grad):
|
||||
summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
|
||||
grad, self._group_assignment), var))
|
||||
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
|
||||
|
||||
def get_slot(self, *args, **kwargs):
|
||||
"""Return a slot named "name" created for "var" by the Optimizer.
|
||||
|
||||
This simply wraps the get_slot() from the actual optimizer.
|
||||
|
||||
Args:
|
||||
*args: Arguments for get_slot().
|
||||
**kwargs: Keyword arguments for get_slot().
|
||||
|
||||
Returns:
|
||||
The `Variable` for the slot if it was created, `None` otherwise.
|
||||
"""
|
||||
return self._opt.get_slot(*args, **kwargs)
|
||||
|
||||
def get_slot_names(self, *args, **kwargs):
|
||||
"""Return a list of the names of slots created by the `Optimizer`.
|
||||
|
||||
This simply wraps the get_slot_names() from the actual optimizer.
|
||||
|
||||
Args:
|
||||
*args: Arguments for get_slot().
|
||||
**kwargs: Keyword arguments for get_slot().
|
||||
|
||||
Returns:
|
||||
A list of strings.
|
||||
"""
|
||||
return self._opt.get_slot_names(*args, **kwargs)
|
||||
|
||||
def variables(self):
|
||||
"""Forwarding the variables from the underlying optimizer."""
|
||||
return self._opt.variables()
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.tpu_optimizer import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
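# A minimal usage sketch (not part of this diff): CrossShardOptimizer wraps an
# existing optimizer; apply_gradients() sums gradients across replicas with
# cross_replica_sum before delegating to the wrapped optimizer. The learning
# rate and the call site below are illustrative only.
from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import gradient_descent

base_opt = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
optimizer = tpu_optimizer.CrossShardOptimizer(base_opt)  # MEAN reduction by default
# Inside a model_fn running under tpu.rewrite()/TPUEstimator:
# train_op = optimizer.minimize(loss, global_step=training_util.get_global_step())
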
@ -1,253 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper library for sharding during TPU compilation."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
_DEFAULT_NUMBER_OF_SHARDS = 1
|
||||
_DEFAULT_SHARD_DIMENSION = 0
|
||||
|
||||
|
||||
# TODO(b/36777903) change other parts of tpu.py to use this class.
|
||||
class ShardingPolicy(object):
|
||||
"""An object use to hold the sharding policy for a Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._number_of_shards = None
|
||||
self._shard_dimension = None
|
||||
self._frozen = False
|
||||
|
||||
def __str__(self):
|
||||
if self.number_of_shards is None or self.shard_dimension is None:
|
||||
return "ShardingPolicy(unset)"
|
||||
else:
|
||||
return ("ShardingPolicy(%d shards dimension %d)" %
|
||||
(self.number_of_shards, self.shard_dimension))
|
||||
|
||||
def _fill_default_values(self):
|
||||
if self._number_of_shards is None:
|
||||
self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
|
||||
if self._shard_dimension is None:
|
||||
self._shard_dimension = tensor_shape.as_dimension(
|
||||
_DEFAULT_SHARD_DIMENSION)
|
||||
|
||||
def freeze(self):
|
||||
"""Prevents further modification to the sharding policy.
|
||||
|
||||
Any values that have not been set when freeze is called are set to
|
||||
defaults. If the ShardingPolicy is already frozen, this is a NoOp.
|
||||
"""
|
||||
if not self._frozen:
|
||||
self._fill_default_values()
|
||||
self._frozen = True
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
"""Returns the number of shards in the policy or None if unspecified."""
|
||||
return self._number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
"""Sets the number of shards for the current policy.
|
||||
|
||||
If the policy has been frozen then number_of_shards must match the
|
||||
existing setting.
|
||||
|
||||
Args:
|
||||
number_of_shards: The number of shards to use in the policy.
|
||||
|
||||
Raises:
|
||||
ValueError: If the policy has been frozen and number_of_shards
|
||||
differs from the frozen value; or number_of_shards <= 0.
|
||||
"""
|
||||
if self._frozen:
|
||||
if self._number_of_shards != number_of_shards:
|
||||
raise ValueError(
|
||||
"Can't set sharding policy to use %d shards since it has been "
|
||||
"frozen to use %d." % (number_of_shards, self._number_of_shards))
|
||||
else:
|
||||
if number_of_shards > 0:
|
||||
self._number_of_shards = number_of_shards
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can't set sharding policy to use %s shards; value must be >0",
|
||||
str(number_of_shards))
|
||||
|
||||
@property
|
||||
def shard_dimension(self):
|
||||
"""Returns the shard dimension of the policy or None if unspecified."""
|
||||
return self._shard_dimension
|
||||
|
||||
def set_shard_dimension(self, shard_dimension):
|
||||
"""Sets the shard dimension for the current policy.
|
||||
|
||||
If the policy has been frozen then shard_dimension must match the
|
||||
existing setting.
|
||||
|
||||
Args:
|
||||
shard_dimension: The shard dimension to use in the policy.
|
||||
|
||||
Raises:
|
||||
ValueError: If the policy has been frozen and shard_dimension
|
||||
differs from the frozen value, or shard_dimension can't be
|
||||
interpreted as a Dimension.
|
||||
"""
|
||||
if self._frozen:
|
||||
if self._shard_dimension != shard_dimension:
|
||||
raise ValueError(
|
||||
"Can't set shard dimension to %d since it has been frozen to "
|
||||
"use %d." % (shard_dimension, self._shard_dimension))
|
||||
else:
|
||||
self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
|
||||
|
||||
def merge(self, other):
|
||||
"""Merges the policy of another policy into the current policy.
|
||||
|
||||
Args:
|
||||
other: The policy to merge into this one.
|
||||
|
||||
Raises:
|
||||
ValueError: If this policy has been frozen and the merge conflicts with
|
||||
the frozen policy.
|
||||
"""
|
||||
if other.number_of_shards is not None:
|
||||
self.set_number_of_shards(other.number_of_shards)
|
||||
if other.shard_dimension is not None:
|
||||
self.set_shard_dimension(other.shard_dimension)
|
||||
|
||||
def get_sharded_shape(self, shape, shard_index=None):
|
||||
"""Returns the shape of a shard of a full Tensor.
|
||||
|
||||
When given the shape of a 'full-size' Tensor, returns the shape of
|
||||
the sub-Tensor after it has been sharded. Freezes the policy if it
|
||||
has not yet been frozen.
|
||||
|
||||
Args:
|
||||
shape: The shape of the full-size Tensor to be sharded.
|
||||
shard_index: The index of the shard whose shape should be returned.
|
||||
shard_index can be None for sharding policies that use the same
|
||||
shape for every shard.
|
||||
|
||||
Returns:
|
||||
The shape of the sharded version of the Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If shard_index is None when shards are of different
|
||||
shapes; or shard_index is not None and
|
||||
!(0<=shard_index<number_of_shards); or shape does not have at
|
||||
least self.shard_dimension+1 dimensions; or the value of
|
||||
shape's shard dimension is not a multiple of
|
||||
self.number_of_shards
|
||||
"""
|
||||
if self._shard_dimension is None or self._number_of_shards is None:
|
||||
# Don't raise an error if the config is unset.
|
||||
return None
|
||||
if shard_index is not None:
|
||||
if shard_index < 0 or shard_index >= self.number_of_shards:
|
||||
raise ValueError("shard_index %d, but must be in [0,%d)." %
|
||||
(shard_index, self._number_of_shards))
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if self._number_of_shards == 1:
|
||||
# Don't do anything when there's only one shard.
|
||||
return shape
|
||||
ndims = shape.ndims
|
||||
if ndims is None:
|
||||
raise ValueError("shape must be a specified shape not Unknown")
|
||||
if ndims <= self._shard_dimension:
|
||||
raise ValueError("shape %s does not contain shard_dimension %d" %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
dims = shape.as_list()
|
||||
if dims[self._shard_dimension] is None:
|
||||
raise ValueError("shape %s must have a fixed size for dimension %d "
|
||||
"that is known at graph construction time." %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
if (dims[self._shard_dimension] % self._number_of_shards) != 0:
|
||||
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
|
||||
(shape.as_list(), self._number_of_shards,
|
||||
self._shard_dimension))
|
||||
dims[self._shard_dimension] /= self._number_of_shards
|
||||
return tensor_shape.as_shape(dims)
|
||||
|
||||
def _unshard_shape(self, shape):
|
||||
"""Return the unsharded shape that would generate a given sharded shape.
|
||||
|
||||
Args:
|
||||
shape: the sharded shape to unshard
|
||||
|
||||
Returns:
|
||||
The unsharded shape.
|
||||
|
||||
Raises:
|
||||
ValueError: if shape is unknown or does not contain
|
||||
self.shard_dimension
|
||||
TypeError: if shape is not convertible to a TensorShape
|
||||
"""
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if self._number_of_shards == 1:
|
||||
# Don't do anything when there's only one shard.
|
||||
return shape
|
||||
ndims = shape.ndims
|
||||
if ndims is None:
|
||||
raise ValueError("shape must be a specified shape not Unknown")
|
||||
if ndims <= self._shard_dimension:
|
||||
raise ValueError("shape %s does not contain shard_dimension %d" %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
dims = shape.as_list()
|
||||
dims[self._shard_dimension] *= self._number_of_shards
|
||||
return tensor_shape.as_shape(dims)
|
||||
|
||||
def get_unsharded_shape(self, shapes):
|
||||
"""Returns the shape of an unsharded Tensor given a list of shards.
|
||||
|
||||
When given a list of shapes of shards, returns the shape of the
|
||||
unsharded Tensor that would generate the shards. Sets defaults for the
|
||||
policy if number_of_shards or shard_dimension is None.
|
||||
|
||||
Args:
|
||||
shapes: The shapes of the Tensor shards to be combined.
|
||||
|
||||
Returns:
|
||||
The shape of the unsharded version of the Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: if shapes is not a list of length
|
||||
self.number_of_shards; or any element of shapes is not a valid
|
||||
shape consistent with the sharding policy; or the list of
|
||||
shapes is not a valid sharding of a full shape.
|
||||
TypeError: if an element of shapes is not convertible to a
|
||||
TensorShape
|
||||
"""
|
||||
self._fill_default_values()
|
||||
if len(shapes) != self.number_of_shards:
|
||||
raise ValueError(
|
||||
"shapes is %s but must be a list of length number_of_shards=%d" % (
|
||||
str(shapes), self.number_of_shards))
|
||||
unsharded_shapes = [self._unshard_shape(s) for s in shapes]
|
||||
for i in xrange(self.number_of_shards - 1):
|
||||
if not unsharded_shapes[i].is_compatible_with(
|
||||
unsharded_shapes[self.number_of_shards - 1]):
|
||||
raise ValueError(
|
||||
"sharded shapes %s are not consistent shards of a full shape "
|
||||
"sharded %d ways along dimension %d" % (
|
||||
str(shapes), self.number_of_shards, self.shard_dimension))
|
||||
return unsharded_shapes[0]
|
||||
# pylint: disable=wildcard-import,unused-import,redefined-builtin
|
||||
from tensorflow.python.tpu.tpu_sharding import *
|
||||
# pylint: enable=wildcard-import,unused-import,redefined-builtin
|
||||
|
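# A minimal sketch (not part of this diff) of the ShardingPolicy defined
# above: 4 shards along dimension 0 turn a [8, 16] tensor into [2, 16] shards,
# and get_unsharded_shape reverses the mapping.
from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(4)
policy.set_shard_dimension(0)
policy.freeze()
print(policy.get_sharded_shape([8, 16]))           # -> (2, 16)
print(policy.get_unsharded_shape([[2, 16]] * 4))   # -> (8, 16)
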
@ -1,156 +1,25 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""TPU system metadata and associated tooling."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min
|
||||
_RETRY_TIMES = 120
|
||||
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
|
||||
|
||||
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
|
||||
|
||||
# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
|
||||
# including num_cores and num_hosts.
|
||||
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
|
||||
'num_cores',
|
||||
'num_hosts',
|
||||
'num_of_cores_per_host',
|
||||
'topology',
|
||||
'devices',
|
||||
])
|
||||
|
||||
|
||||
def _query_tpu_system_metadata(master_address, cluster_def=None,
|
||||
query_topology=False):
|
||||
"""Automatically detects the TPU system metadata in the system."""
|
||||
tpu_core_count = 0
|
||||
devices = []
|
||||
device_dict = collections.defaultdict(list)
|
||||
|
||||
# TODO(b/120564445): Replace with standard library for retries.
|
||||
retry_count = 1
|
||||
while True:
|
||||
logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
|
||||
master_address)
|
||||
try:
|
||||
with ops.Graph().as_default():
|
||||
with session_lib.Session(
|
||||
master_address,
|
||||
config=get_session_config_with_timeout(
|
||||
_PINGING_MASTER_TIMEOUT_IN_MS,
|
||||
cluster_def)) as sess:
|
||||
devices = sess.list_devices()
|
||||
for device in devices:
|
||||
match = _TPU_DEVICE_REG.match(device.name)
|
||||
if match:
|
||||
host_id = match.group(1)
|
||||
core_id = match.group(2)
|
||||
device_dict[host_id].append(core_id)
|
||||
tpu_core_count += 1
|
||||
break
|
||||
except errors.DeadlineExceededError:
|
||||
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
|
||||
'not be ready (still scheduling) or the Tensorflow master address '
|
||||
'is incorrect: got (%s).' %
|
||||
(master_address))
|
||||
|
||||
# TODO(xiejw): For local or grpc master we might not need retry logic
|
||||
# here.
|
||||
if retry_count <= _RETRY_TIMES:
|
||||
logging.warning('%s', msg)
|
||||
logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
|
||||
retry_count += 1
|
||||
else:
|
||||
raise ValueError(msg)
|
||||
|
||||
num_of_cores_per_host = 0
|
||||
if tpu_core_count:
|
||||
num_cores_per_host_set = set(
|
||||
[len(core_ids) for core_ids in device_dict.values()])
|
||||
if len(num_cores_per_host_set) != 1:
|
||||
raise RuntimeError(
|
||||
'The number of TPU cores on each host is not the same. This '
'should not happen. devices: {}'.format(devices))
|
||||
num_of_cores_per_host = num_cores_per_host_set.pop()
|
||||
|
||||
topology = None
|
||||
if query_topology:
|
||||
if not tpu_core_count:
|
||||
raise RuntimeError(
|
||||
'Cannot find any TPU cores in the system (master address {}). '
|
||||
'This usually means the master address is incorrect or the '
|
||||
'TPU worker has some problems. Available devices: {}'.format(
|
||||
master_address, devices))
|
||||
|
||||
topology = _obtain_topology(master_address, cluster_def)
|
||||
|
||||
metadata = _TPUSystemMetadata(
|
||||
num_cores=tpu_core_count,
|
||||
num_hosts=len(device_dict),
|
||||
num_of_cores_per_host=num_of_cores_per_host,
|
||||
topology=topology,
|
||||
devices=devices)
|
||||
|
||||
if tpu_core_count:
|
||||
logging.info('Found TPU system:')
|
||||
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
|
||||
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
|
||||
logging.info('*** Num TPU Cores Per Worker: %d',
|
||||
metadata.num_of_cores_per_host)
|
||||
for device in metadata.devices:
|
||||
logging.info('*** Available Device: %s', device)
|
||||
else:
|
||||
logging.info('Failed to find TPU: %s', metadata)
|
||||
return metadata
|
||||
|
||||
|
||||
def _obtain_topology(master_address, cluster_def):
|
||||
"""Obtains TPU fabric topology."""
|
||||
try:
|
||||
logging.info('Initializing TPU system (master: %s) to fetch topology '
|
||||
'for model parallelism. This might take a while.',
|
||||
master_address)
|
||||
with ops.Graph().as_default():
|
||||
session_config = get_session_config_with_timeout(
|
||||
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
|
||||
with session_lib.Session(
|
||||
master_address, config=session_config) as sess:
|
||||
topology = sess.run(tpu.initialize_system())
|
||||
return topology
|
||||
except errors.DeadlineExceededError:
|
||||
raise ValueError(
|
||||
'Failed to initialize TPU system with master (%s). '
'Please double-check that the TPU system is functional.' % (
|
||||
master_address))
|
||||
|
||||
|
||||
def get_session_config_with_timeout(timeout_in_secs, cluster_def):
|
||||
"""Returns a session given a timeout and a cluster configuration."""
|
||||
config = config_pb2.ConfigProto(
|
||||
operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
|
||||
return config
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.tpu_system_metadata import *
|
||||
# used by tests
|
||||
from tensorflow.python.tpu.tpu_system_metadata import _query_tpu_system_metadata
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
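# A minimal sketch (not part of this diff) of the session-config helper above.
# Note that despite the parameter name, the value it is called with elsewhere
# in this file (e.g. _PINGING_MASTER_TIMEOUT_IN_MS) is already in milliseconds.
from tensorflow.python.tpu import tpu_system_metadata as tpu_md

config = tpu_md.get_session_config_with_timeout(60 * 1000, cluster_def=None)
print(config.operation_timeout_in_ms)  # -> 60000
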
@ -1,223 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Library for constructing a training loop, suitable for TPUs."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.compiler import xla
|
||||
from tensorflow.contrib.tpu.python.tpu import tensor_tracer
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
||||
|
||||
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||
"""Builds a training loop for TPUs.
|
||||
|
||||
The set of loop-carried tensors corresponds to `inputs`. Both
|
||||
`condition` and `body` take the current value of the loop-carried
|
||||
tensors. 'body' additionally takes a tuple of infeed from
|
||||
infeed_queue if infeed_queue is not None. `condition` must return a
|
||||
single boolean value that determines whether iteration
|
||||
continues. `body` must return an updated list of values for the
|
||||
loop-carried tensors.
|
||||
|
||||
Args:
|
||||
condition: a Python function that builds the loop condition.
|
||||
body: a Python function that builds the loop body.
|
||||
inputs: a list of initial values passed into the training loop, or
|
||||
None (equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple
|
||||
of arguments as inputs to `body`.
|
||||
name: (Deprecated) Does nothing.
|
||||
|
||||
Returns:
|
||||
The final values of the loop-carried tensors.
|
||||
|
||||
Raises:
|
||||
TypeError: if body or condition has the wrong signature.
|
||||
"""
|
||||
del name
|
||||
# Converts inputs to Tensors.
|
||||
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
|
||||
x in inputs]
|
||||
input_types = [x.dtype for x in inputs]
|
||||
input_arity = len(inputs)
|
||||
|
||||
body_arg_error = xla.check_function_argument_count(
|
||||
body, input_arity, infeed_queue)
|
||||
if body_arg_error is not None:
|
||||
if infeed_queue is None:
|
||||
raise TypeError(
|
||||
"Supplied loop body function cannot be called with the specified "
|
||||
"inputs. You specified %d inputs: %s, but the loop body needs %s" % (
|
||||
input_arity, str([i.name for i in inputs]), body_arg_error))
|
||||
else:
|
||||
raise TypeError(
|
||||
"Supplied loop body function cannot be called with the specified "
|
||||
"inputs. You specified %d inputs: %s and %d additional inputs from "
|
||||
"infeed, but the computation needs %s" % (input_arity, str(
|
||||
[i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
|
||||
body_arg_error))
|
||||
condition_arg_error = xla.check_function_argument_count(
|
||||
condition, input_arity, None)
|
||||
if condition_arg_error is not None:
|
||||
if infeed_queue is None:
|
||||
raise TypeError(
|
||||
"Supplied loop condition function cannot be called with the "
|
||||
"specified inputs. You specified %d inputs: %s, but the loop "
|
||||
"condition needs %s" % (input_arity, str([i.name for i in inputs]),
|
||||
condition_arg_error))
|
||||
else:
|
||||
raise TypeError(
|
||||
"Supplied loop condition function cannot be called with the "
|
||||
"specified inputs. You specified %d inputs: %s, but the loop "
|
||||
"condition needs %s. Note that infeed is not passed to the loop "
|
||||
"condition." % (input_arity, str([i.name for i in inputs]),
|
||||
condition_arg_error))
|
||||
|
||||
def condition_wrapper(*inputs):
|
||||
# Discards the dummy output added for arity-0 loops.
|
||||
if input_arity == 0:
|
||||
inputs = []
|
||||
return condition(*inputs)
|
||||
|
||||
def body_wrapper(*inputs):
|
||||
"""Wrapper around `body` that handles infeed queues and control deps."""
|
||||
inputs = list(inputs)
|
||||
|
||||
# Discards the dummy output added for arity-0 loops.
|
||||
if input_arity == 0:
|
||||
inputs = []
|
||||
|
||||
# Runs `body` with the dequeue_ops appended.
|
||||
if infeed_queue:
|
||||
number_of_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if number_of_shards is None:
|
||||
raise ValueError("Can't build training loop with infeed when there is "
|
||||
"no tpu_shard_context. Are you building a loop or "
|
||||
"graph directly rather than from inside tpu.rewrite, "
|
||||
"tpu.batch_parallel, tpu.shard, or tpu.replicate?")
|
||||
infeed_queue.set_number_of_shards(number_of_shards)
|
||||
dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
|
||||
else:
|
||||
dequeue_ops = []
|
||||
outputs = body(*(inputs + dequeue_ops))
|
||||
|
||||
# If the computation only returned one value, make it a tuple.
|
||||
if not isinstance(outputs, (list, tuple)):
|
||||
outputs = (outputs,)
|
||||
|
||||
outputs = [
|
||||
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
|
||||
for o in outputs
|
||||
]
|
||||
|
||||
# Separates the returned Operations and Tensors.
|
||||
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
|
||||
output_tensors = [o for o in outputs
|
||||
if not isinstance(o, ops.Operation)]
|
||||
|
||||
if outputs != output_tensors + output_operations:
|
||||
raise ValueError(
|
||||
"TPU training loop body must return zero or more Tensor values "
|
||||
"followed by zero or more Operations.")
|
||||
|
||||
output_types = [op.dtype for op in output_tensors]
|
||||
if input_types != output_types:
|
||||
raise TypeError(
|
||||
"Mismatch between input types and output types for training loop "
|
||||
"body: {} vs {}".format(input_types, output_types))
|
||||
|
||||
# Add the dequeue operations to output_operations to ensure they are run
|
||||
# by the loop, even if the programmer's loop body does not use them.
|
||||
output_operations += dequeue_ops
|
||||
|
||||
# Add a dummy output, if needed.
|
||||
if not output_tensors:
|
||||
output_tensors = array_ops.constant(0)
|
||||
|
||||
if output_operations:
|
||||
# TODO(phawkins): in principle this is too restrictive since it serializes
|
||||
# the training loop steps. In practice it does not matter since this loop
|
||||
# will be compiled by XLA.
|
||||
output_tensors = control_flow_ops.tuple(output_tensors,
|
||||
control_inputs=output_operations)
|
||||
|
||||
if tensor_tracer.TensorTracer.is_enabled():
|
||||
num_replicas = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_replicas is None:
|
||||
num_replicas = 1
|
||||
tt = tensor_tracer.TensorTracer()
|
||||
output_tensors = tt.trace_tpu(ops.get_default_graph(),
|
||||
output_tensors, None,
|
||||
num_replicas)
|
||||
return output_tensors
|
||||
|
||||
# If the body has arity 0, add a dummy loop-carried value to which we can add
|
||||
# control dependencies from any side-effecting operations.
|
||||
if input_arity == 0:
|
||||
inputs = [array_ops.constant(0)]
|
||||
return control_flow_ops.while_loop(
|
||||
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
|
||||
|
||||
|
||||
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
|
||||
"""Builds a training loop that executes a fixed number of iterations.
|
||||
|
||||
The set of loop-carried tensors corresponds to `inputs`.
|
||||
`body` must be a function that takes and returns the values of the
|
||||
loop-carried tensors.
|
||||
|
||||
Args:
|
||||
n: the number of loop iterations
|
||||
body: a Python function that builds the loop body.
|
||||
inputs: a list of initial values passed into the training loop or
|
||||
None (equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple
|
||||
of arguments as inputs to `body`.
|
||||
name: (Deprecated) Does nothing.
|
||||
Returns:
|
||||
The final values of the loop-carried tensors.
|
||||
Raises:
|
||||
ValueError: if there is a type error.
|
||||
"""
|
||||
def _convert_to_list(xs):
|
||||
if not isinstance(xs, (list, tuple)):
|
||||
return [xs]
|
||||
else:
|
||||
return list(xs)
|
||||
|
||||
def cond(i, *args):
|
||||
del args
|
||||
return i < n
|
||||
|
||||
def body_wrapper(i, *args):
|
||||
return [i + 1] + _convert_to_list(body(*args))
|
||||
|
||||
inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
|
||||
outputs = while_loop(
|
||||
cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
|
||||
outputs = _convert_to_list(outputs)
|
||||
if len(outputs) == 1:
|
||||
# Returns the Op rather than an empty list.
|
||||
return outputs[0].op
|
||||
else:
|
||||
return outputs[1:]
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.training_loop import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
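# A minimal sketch (not part of this diff) of the repeat() helper defined
# above: run a body a fixed number of times, carrying a scalar through the
# loop. The loop must be built inside tpu.rewrite()/tpu.shard() so that a
# tpu_shard_context is active; the call site shown last is hypothetical.
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import training_loop

def computation():
  def body(v):
    return math_ops.add(v, 1.0)   # one "step": v <- v + 1
  return training_loop.repeat(10, body, inputs=[array_ops.constant(0.0)])

# outputs = tpu.rewrite(computation)  # hypothetical call site on a TPU worker
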
@ -1,51 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Utilities for the functionalities."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import six
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import training
|
||||
|
||||
def check_positive_integer(value, name):
|
||||
"""Checks whether `value` is a positive integer."""
|
||||
if not isinstance(value, six.integer_types):
|
||||
raise TypeError('{} must be int, got {}'.format(name, type(value)))
|
||||
|
||||
if value <= 0:
|
||||
raise ValueError('{} must be positive, got {}'.format(name, value))
|
||||
|
||||
|
||||
# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
|
||||
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
|
||||
# python/estimator/util.py.
|
||||
class MultiHostDatasetInitializerHook(training.SessionRunHook):
|
||||
"""Creates a SessionRunHook that initializes all passed iterators."""
|
||||
|
||||
def __init__(self, dataset_initializers):
|
||||
self._initializers = dataset_initializers
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
del coord
|
||||
start = time.time()
|
||||
session.run(self._initializers)
|
||||
logging.info('Initialized dataset iterators in %d seconds',
|
||||
time.time() - start)
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.tpu.util import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
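# A minimal sketch (not part of this diff) of the validation helper above:
# it raises TypeError for non-integers and ValueError for values <= 0.
from tensorflow.python.tpu import util as tpu_util

tpu_util.check_positive_integer(8, 'iterations_per_loop')    # passes silently
# tpu_util.check_positive_integer(0, 'iterations_per_loop')  # -> ValueError
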
@ -2010,6 +2010,7 @@ tf_gen_op_wrapper_private_py(
visibility = [
"//smartass/brain/configure/python:__pkg__",
"//tensorflow/contrib/tpu:__pkg__",
"//tensorflow/python/tpu:__pkg__",
],
deps = [
"//tensorflow/core:tpu_configuration_ops_op_lib",

334	tensorflow/python/tpu/BUILD (new file)
@ -0,0 +1,334 @@
|
||||
# Description: Operations defined for Cloud TPUs
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_py_test",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//cloud/vmm/testing/tests/tpu:__subpackages__",
|
||||
"//knowledge/cerebra/sense/im2query:__subpackages__",
|
||||
"//learning/brain:__subpackages__",
|
||||
"//learning/deepmind:__subpackages__",
|
||||
"//medical/pathology:__subpackages__",
|
||||
"//tensorflow:__subpackages__",
|
||||
"//vr/perception:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_py",
|
||||
srcs = ["ops/tpu_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "async_checkpoint",
|
||||
srcs = ["async_checkpoint.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_estimator",
|
||||
srcs = [
|
||||
"_tpu_estimator_embedding.py",
|
||||
"error_handling.py",
|
||||
"tpu_config.py",
|
||||
"tpu_context.py",
|
||||
"tpu_estimator.py",
|
||||
"util.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":async_checkpoint",
|
||||
":feature_column",
|
||||
":functional",
|
||||
":tpu_embedding",
|
||||
":tpu_lib",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/estimator:util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "functional",
|
||||
srcs = ["functional.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":feature_column",
|
||||
":tpu_embedding",
|
||||
":tpu_estimator",
|
||||
":tpu_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"bfloat16.py",
|
||||
"device_assignment.py",
|
||||
"session_support.py",
|
||||
"tensor_tracer.py",
|
||||
"topology.py",
|
||||
"tpu.py",
|
||||
"tpu_feed.py",
|
||||
"tpu_function.py",
|
||||
"tpu_optimizer.py",
|
||||
"tpu_sharding.py",
|
||||
"tpu_system_metadata.py",
|
||||
"training_loop.py",
|
||||
"xla.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":datasets",
|
||||
":functional",
|
||||
":tpu_py",
|
||||
"//tensorflow/compiler/xla/experimental/xla_sharding",
|
||||
"//tensorflow/compiler/xla/python_api:xla_shape",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/tpu/profiler",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "datasets",
|
||||
srcs = [
|
||||
"datasets.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "datasets_test",
|
||||
size = "medium",
|
||||
srcs = ["datasets_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
":datasets",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
shard_count = 4,
|
||||
tags = ["no_oss"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:layers",
|
||||
],
|
||||
tags = ["no_windows"], # TODO: needs investigation on Windows
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_sharding_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_sharding_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "bfloat16_test",
|
||||
size = "small",
|
||||
srcs = ["bfloat16_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_infeed_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_infeed_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_config_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_config_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_estimator",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_estimator_signals_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_estimator_signals_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_estimator",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
# TODO(jhseu): Remove. Fails in OSS on Python 3.
|
||||
tags = ["no_oss"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "topology_test",
|
||||
size = "medium",
|
||||
srcs = ["topology_test.py"],
|
||||
additional_deps = [
|
||||
":tpu",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_embedding",
|
||||
srcs = [
|
||||
"tpu_embedding.py",
|
||||
"tpu_embedding_gradient.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_lib",
|
||||
"//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:partitioned_variables",
|
||||
"//tensorflow/python:tpu_ops_gen",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "feature_column",
|
||||
srcs = ["feature_column.py"],
|
||||
deps = [
|
||||
":tpu_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/feature_column:feature_column_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "feature_column_test",
|
||||
srcs = [
|
||||
"feature_column_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":feature_column",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/feature_column:feature_column_py",
|
||||
],
|
||||
main = "feature_column_test.py",
|
||||
)
|
20	tensorflow/python/tpu/__init__.py (new file)
@ -0,0 +1,20 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Ops related to Tensor Processing Units."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
334	tensorflow/python/tpu/_tpu_estimator_embedding.py (new file)
@ -0,0 +1,334 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Tooling for support TPU embedding in TPUEstimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.feature_column import feature_column as core_fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
|
||||
from tensorflow.python.tpu import feature_column as tpu_fc
|
||||
from tensorflow.python.tpu import tpu_embedding
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
|
||||
tpu_fc._TPUSharedEmbeddingColumn)
|
||||
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
|
||||
core_fc_lib.EmbeddingColumn,
|
||||
core_fc._SharedEmbeddingColumn)
|
||||
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_TABLE_NAME_PREFIX = 'tbl_'
|
||||
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
|
||||
|
||||
|
||||
def _get_table_name_from_embedding_var_name(embedding_var_name):
|
||||
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
|
||||
|
||||
|
||||
def _get_embedding_var_name_from_table_name(table_name):
|
||||
return table_name[_LEN_TABLE_NAME_PREFIX:]
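# A quick illustration (not part of this file) of the naming convention the
# two helpers above implement: table names are the embedding variable name
# with a 'tbl_' prefix, and the mapping round-trips. 'video_id' is made up.
assert _get_table_name_from_embedding_var_name('video_id') == 'tbl_video_id'
assert _get_embedding_var_name_from_table_name('tbl_video_id') == 'video_id'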
|
||||
|
||||
|
||||
def _get_embedding_variable_name(scope_name, var_name):
|
||||
return '{}/{}'.format(scope_name, var_name)
|
||||
|
||||
|
||||
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
|
||||
"""Return embedding variable names which are consistent with CPU runs."""
|
||||
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
|
||||
return tpu_embedding.AdagradSlotVariableName(
|
||||
'{}/{}/Adagrad'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
|
||||
return tpu_embedding.AdamSlotVariableNames(
|
||||
'{}/{}/Adam/m'.format(scope_name, var_name),
|
||||
'{}/{}/Adam/v'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters,
|
||||
tpu_embedding.StochasticGradientDescentParameters):
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Support for inferring the full variable name '
|
||||
'for optimization_parameter {} has not been added.'
|
||||
.format(optimization_parameters))
|
||||
|
||||
|
||||
def get_full_variable_names(
|
||||
graph, table_to_config_dict, optimization_parameters):
|
||||
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
|
||||
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
if not collection:
|
||||
raise RuntimeError(
|
||||
'Embedding feature column did not capture anything. Make sure the '
'feature columns passed to the TPUEstimator constructor are properly '
'used in model_fn.')
|
||||
|
||||
embedding_variable_name_by_table = {}
|
||||
slot_variable_names_by_table = {}
|
||||
for table_name in table_to_config_dict:
|
||||
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
|
||||
(scope_name, var_name) = collection[0][embedding_var_name]
|
||||
embedding_variable_name_by_table[table_name] = (
|
||||
_get_embedding_variable_name(scope_name, var_name))
|
||||
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
|
||||
scope_name, var_name, optimization_parameters)
|
||||
|
||||
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
return embedding_variable_name_by_table, slot_variable_names_by_table
|
||||
|
||||
|
||||
def get_tpu_embedding_config_from_feature_columns(feature_columns):
|
||||
"""Create configs for TPUEmbedding from a list of feature columns.
|
||||
|
||||
This function will place one embedding tensor per table and the return is
|
||||
intended to be used as input to TPUEmbedding.
|
||||
|
||||
Args:
|
||||
feature_columns: a list of supported feature columns.
|
||||
|
||||
Returns:
|
||||
A pair of dicts, the first maps tables to their config, the second maps
|
||||
features to tables.
|
||||
"""
|
||||
|
||||
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, allowed):
|
||||
raise TypeError(
|
||||
'Unsupported feature column {}. Supported types are {}.'.format(
|
||||
type(column), allowed))
|
||||
|
||||
table_to_config = {}
|
||||
feature_to_table = {}
|
||||
for column in feature_columns:
|
||||
feature_name = column.get_feature_key_name()
|
||||
table_name = _get_table_name_from_embedding_var_name(
|
||||
column.get_embedding_var_name())
|
||||
if feature_name in feature_to_table:
|
||||
raise ValueError(
|
||||
'Feature column {} is used with multiple embeddings and this is '
|
||||
'not supported.'.format(feature_name))
|
||||
feature_to_table[feature_name] = table_name
|
||||
vocabulary_size, dimension = column.get_embedding_table_size()
|
||||
table_to_config[table_name] = tpu_embedding.TableConfig(
|
||||
vocabulary_size=vocabulary_size,
|
||||
dimension=dimension,
|
||||
initializer=column.get_initializer(),
|
||||
combiner=column.get_combiner())
|
||||
|
||||
return table_to_config, feature_to_table
|
||||
|
||||
|
||||
def _get_tpu_embedding_optimization_parameters(embedding_config_spec):
|
||||
"""Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec."""
|
||||
if embedding_config_spec.optimizer_type == 'adagrad':
|
||||
return tpu_embedding.AdagradParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.adagrad_initial_accumulator,
|
||||
embedding_config_spec.use_gradient_accumulation)
|
||||
elif embedding_config_spec.optimizer_type == 'sgd':
|
||||
return tpu_embedding.StochasticGradientDescentParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.use_gradient_accumulation)
|
||||
elif embedding_config_spec.optimizer_type == 'adam':
|
||||
return tpu_embedding.AdamParameters(
|
||||
embedding_config_spec.learning_rate,
|
||||
embedding_config_spec.adam_parameters.beta1,
|
||||
embedding_config_spec.adam_parameters.beta2,
|
||||
embedding_config_spec.adam_parameters.epsilon,
|
||||
use_gradient_accumulation=embedding_config_spec
|
||||
.use_gradient_accumulation)
|
||||
else:
|
||||
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
|
||||
|
||||
|
||||
AdamParameters = collections.namedtuple('AdamParameters',
|
||||
['beta1', 'beta2', 'epsilon'])
|
||||
|
||||
|
||||
# TODO(shizhiw): Improve the API to support more optimizer parameters in API.
|
||||
class EmbeddingConfigSpec(
|
||||
collections.namedtuple('EmbeddingConfigSpec', [
|
||||
'feature_columns', 'learning_rate', 'optimizer_type',
|
||||
'adagrad_initial_accumulator', 'clipping_limit',
|
||||
'use_gradient_accumulation', 'adam_parameters'
|
||||
])):
|
||||
"""Class to keep track of embedding config specification."""
|
||||
|
||||
def __new__(cls,
|
||||
feature_columns,
|
||||
learning_rate,
|
||||
optimizer_type='adagrad',
|
||||
adagrad_initial_accumulator=None,
|
||||
clipping_limit=None,
|
||||
use_gradient_accumulation=False,
|
||||
adam_parameters=None):
|
||||
"""Creates an EmbeddingConfigSpec instance.
|
||||
|
||||
Args:
|
||||
feature_columns: All `FeatureColumn`s used by model.
|
||||
learning_rate: embedding optimizer learning rate.
|
||||
optimizer_type: (String) Name of the optimizer for embedding gradients
|
||||
updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default
|
||||
value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam'
|
||||
(`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will
|
||||
be applied to all embedding variables specified by `feature_columns`.
|
||||
adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when
|
||||
optimizer_type is 'adagrad'. Default is `0.1`.
|
||||
clipping_limit: (Optional) Clipping limit (absolute value).
|
||||
use_gradient_accumulation: (Experimental) Whether to accumulate the
|
||||
gradients across TPU embedding mini-batches. Gradient accumulation does
|
||||
not affect SGD and therefore this is applicable only for Adagrad.
|
||||
adam_parameters: AdamParameters. Used when optimizer_type is 'adam'.
|
||||
Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon.
|
||||
|
||||
Returns:
|
||||
An EmbeddingConfigSpec instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the feature_columns are not specified.
|
||||
TypeError: If the feature columns are not of the correct type (one of
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or
_EMBEDDING_COLUMN_CLASSES).
|
||||
ValueError: If use_gradient_accumulation is True for SGD.
|
||||
ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or
|
||||
"adam".
|
||||
"""
|
||||
if not feature_columns:
|
||||
raise ValueError('`feature_columns` cannot be `None` or empty.')
|
||||
|
||||
# It is unknown at this moment whether the TPUEstimator is running in CPU
|
||||
# or TPU mode. So allow non-TPU embedding columns also.
|
||||
supported_classes = tuple(
|
||||
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
|
||||
list(_EMBEDDING_COLUMN_CLASSES))
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, supported_classes):
|
||||
raise TypeError(
|
||||
'All feature columns must be supported types in {}. Got {}'.format(
|
||||
supported_classes, type(column)))
|
||||
|
||||
if optimizer_type == 'adagrad':
|
||||
if adagrad_initial_accumulator is None:
|
||||
adagrad_initial_accumulator = 0.1
|
||||
if adagrad_initial_accumulator <= 0:
|
||||
raise ValueError('Adagrad initial_accumulator must be positive')
|
||||
elif optimizer_type == 'sgd':
|
||||
if use_gradient_accumulation:
|
||||
raise ValueError('Gradient accumulation makes sense for Adagrad only.')
|
||||
elif optimizer_type == 'adam':
|
||||
if adam_parameters is None:
|
||||
adam_parameters = AdamParameters(0.9, 0.999, 1e-8)
|
||||
if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.:
|
||||
raise ValueError('beta1 must be between 0. and 1; got {}.'.format(
|
||||
adam_parameters.beta1))
|
||||
if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.:
|
||||
raise ValueError('beta2 must be between 0. and 1; got {}.'.format(
|
||||
adam_parameters.beta2))
|
||||
if adam_parameters.epsilon <= 0.:
|
||||
raise ValueError('epsilon must be positive; got {}.'.format(
|
||||
adam_parameters.epsilon))
|
||||
else:
|
||||
raise ValueError('optimizer_type must be adagrad or sgd or adam for now.')
|
||||
|
||||
return super(EmbeddingConfigSpec, cls).__new__(
|
||||
cls,
|
||||
feature_columns=feature_columns,
|
||||
learning_rate=learning_rate,
|
||||
optimizer_type=optimizer_type,
|
||||
adagrad_initial_accumulator=adagrad_initial_accumulator,
|
||||
clipping_limit=clipping_limit,
|
||||
use_gradient_accumulation=use_gradient_accumulation,
|
||||
adam_parameters=adam_parameters)
|
||||
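# Illustrative sketch (added for exposition, not part of the original file):
# constructing an EmbeddingConfigSpec for Adagrad-optimized embeddings.
# `tpu_columns` is a hypothetical list of TPU embedding feature columns (see
# tensorflow/python/tpu/feature_column.py).
def _example_embedding_config_spec(tpu_columns):
  """Returns a sample EmbeddingConfigSpec; assumes `tpu_columns` is valid."""
  return EmbeddingConfigSpec(
      feature_columns=tpu_columns,
      learning_rate=0.05,
      optimizer_type='adagrad',
      adagrad_initial_accumulator=0.1)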
|
||||
|
||||
class EmbeddingConfig(object):
|
||||
"""This is the internal immutable object for embedding config.
|
||||
|
||||
`EmbeddingConfig` is responsible for translating the user-provided
`EmbeddingConfigSpec` into internal data structures, mostly constructor
arguments of `TPUEmbedding`.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
|
||||
num_hosts, num_cores, master):
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._num_hosts = num_hosts
|
||||
self._num_cores = num_cores
|
||||
self._master = master
|
||||
|
||||
self._table_to_config_dict, self._feature_to_table_dict = (
|
||||
get_tpu_embedding_config_from_feature_columns(
|
||||
embedding_config_spec.feature_columns))
|
||||
self._optimization_parameters = _get_tpu_embedding_optimization_parameters(
|
||||
self._embedding_config_spec)
|
||||
self._mode_to_tpu_embedding_dict = {}
|
||||
self.dummy_table_variables = None
|
||||
|
||||
def has_embedding_tables(self):
|
||||
return bool(self._table_to_config_dict)
|
||||
|
||||
def _create_tpu_embedding(self, mode):
|
||||
"""Create tpu_embedding.TPUEmbedding based on mode."""
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
batch_size = self._train_batch_size
|
||||
else:
|
||||
batch_size = self._eval_batch_size
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
tpu_embedding_mode = tpu_embedding.TRAINING
|
||||
elif (mode == model_fn_lib.ModeKeys.EVAL or
|
||||
mode == model_fn_lib.ModeKeys.PREDICT):
|
||||
tpu_embedding_mode = tpu_embedding.INFERENCE
|
||||
else:
|
||||
raise ValueError('Mode {} is not supported.'.format(mode))
|
||||
|
||||
tpu_embedding_ = tpu_embedding.TPUEmbedding(
|
||||
self._table_to_config_dict,
|
||||
self._feature_to_table_dict,
|
||||
batch_size,
|
||||
tpu_embedding_mode,
|
||||
self._master,
|
||||
self._optimization_parameters,
|
||||
)
|
||||
return tpu_embedding_
|
||||
|
||||
def get_tpu_embedding(self, mode):
|
||||
if mode not in self._mode_to_tpu_embedding_dict:
|
||||
self._mode_to_tpu_embedding_dict[mode] = (
|
||||
self._create_tpu_embedding(mode))
|
||||
return self._mode_to_tpu_embedding_dict[mode]
|
||||
|
||||
|
||||
def split_inputs(ctx, features, labels):
|
||||
"""Splits the dense and sparse tensors inside the features and labels."""
|
||||
sparse_features = collections.OrderedDict()
|
||||
if ctx.embedding_config:
|
||||
tpu_embedding_ = ctx.embedding_config.tpu_embedding
|
||||
for feature_key in tpu_embedding_.feature_to_table_dict:
|
||||
sparse_features[feature_key] = features.pop(feature_key)
|
||||
|
||||
return features, labels, sparse_features
|
tensorflow/python/tpu/async_checkpoint.py (new file)
@ -0,0 +1,212 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Hook for asynchronous checkpointing.
|
||||
|
||||
This hook dispatches checkpoint writing operations in a separate thread to
|
||||
allow execution to continue on the main thread.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tensorflow.core.util.event_pb2 import SessionLog
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
||||
from tensorflow.python.training.summary_io import SummaryWriterCache
|
||||
|
||||
|
||||
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
"""Saves checkpoints every N steps or seconds."""
|
||||
|
||||
def __init__(self,
|
||||
checkpoint_dir,
|
||||
save_secs=None,
|
||||
save_steps=None,
|
||||
saver=None,
|
||||
checkpoint_basename="model.ckpt",
|
||||
scaffold=None,
|
||||
listeners=None):
|
||||
"""Initializes a `CheckpointSaverHook`.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: `str`, base directory for the checkpoint files.
|
||||
save_secs: `int`, save every N secs.
|
||||
save_steps: `int`, save every N steps.
|
||||
saver: `Saver` object, used for saving.
|
||||
checkpoint_basename: `str`, base name for the checkpoint files.
|
||||
scaffold: `Scaffold`, use to get saver object.
|
||||
listeners: List of `CheckpointSaverListener` subclass instances. Used for
|
||||
callbacks that run immediately before or after this hook saves the
|
||||
checkpoint.
|
||||
|
||||
Raises:
|
||||
ValueError: One of `save_steps` or `save_secs` should be set.
|
||||
ValueError: At most one of `saver` or `scaffold` should be set.
|
||||
"""
|
||||
logging.info("Create AsyncCheckpointSaverHook.")
|
||||
if saver is not None and scaffold is not None:
|
||||
raise ValueError("You cannot provide both saver and scaffold.")
|
||||
self._saver = saver
|
||||
self._save_thread = None
|
||||
self._write_graph_thread = None
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
|
||||
self._scaffold = scaffold
|
||||
self._timer = basic_session_run_hooks.SecondOrStepTimer(
|
||||
every_secs=save_secs, every_steps=save_steps)
|
||||
self._listeners = listeners or []
|
||||
self._steps_per_run = 1
|
||||
self._summary_writer = None
|
||||
self._global_step_tensor = None
|
||||
|
||||
self._last_checkpoint_step = None
|
||||
|
||||
def _set_steps_per_run(self, steps_per_run):
|
||||
self._steps_per_run = steps_per_run
|
||||
|
||||
def begin(self):
|
||||
self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
|
||||
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
|
||||
if self._global_step_tensor is None:
|
||||
raise RuntimeError(
|
||||
"Global step should be created to use CheckpointSaverHook.")
|
||||
for l in self._listeners:
|
||||
l.begin()
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
global_step = session.run(self._global_step_tensor)
|
||||
|
||||
# Write the graph and saver_def once the session has been created. We
# cannot do this in begin, since other hooks may still modify the graph and
# add variables in begin. The graph is finalized after all begin calls.
|
||||
def _write_graph_fn(self):
|
||||
training_util.write_graph(
|
||||
ops.get_default_graph().as_graph_def(add_shapes=True),
|
||||
self._checkpoint_dir, "graph.pbtxt")
|
||||
self._write_graph_thread = threading.Thread(target=_write_graph_fn,
|
||||
args=[self])
|
||||
self._write_graph_thread.start()
|
||||
|
||||
saver_def = self._get_saver().saver_def if self._get_saver() else None
|
||||
graph = ops.get_default_graph()
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
|
||||
self._summary_writer.add_graph(graph)
|
||||
self._summary_writer.add_meta_graph(meta_graph_def)
|
||||
# The checkpoint saved here is the state at step "global_step".
|
||||
self._save(session, global_step)
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
|
||||
def before_run(self, run_context): # pylint: disable=unused-argument
|
||||
return SessionRunArgs(self._global_step_tensor)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
if self._timer.should_trigger_for_step(global_step):
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
logging.info("Triggering checkpoint. %s", global_step)
|
||||
if self._save(run_context.session, global_step):
|
||||
run_context.request_stop()
|
||||
|
||||
def end(self, session):
|
||||
if self._save_thread:
|
||||
logging.info("Waiting for any pending checkpoints to finish.")
|
||||
self._save_thread.join()
|
||||
if self._write_graph_thread:
|
||||
logging.info("Waiting for any pending write_graph to finish.")
|
||||
self._write_graph_thread.join()
|
||||
|
||||
last_step = session.run(self._global_step_tensor)
|
||||
|
||||
if self._last_checkpoint_step != last_step:
|
||||
self._save(session, last_step, asynchronous=False)
|
||||
|
||||
for l in self._listeners:
|
||||
l.end(session, last_step)
|
||||
|
||||
def _save(self, session, step, asynchronous=True):
|
||||
"""Saves the latest checkpoint, returns should_stop."""
|
||||
|
||||
# Skip saving on step 0
|
||||
if step == 0:
|
||||
return
|
||||
|
||||
def _save_fn():
|
||||
"""Run the saver process."""
|
||||
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
|
||||
|
||||
start_time = time.time()
|
||||
for l in self._listeners:
|
||||
l.before_save(session, step)
|
||||
|
||||
self._get_saver().save(session, self._save_path, global_step=step)
|
||||
self._summary_writer.add_session_log(
|
||||
SessionLog(
|
||||
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
|
||||
step)
|
||||
|
||||
for l in self._listeners:
|
||||
l.after_save(session, step)
|
||||
|
||||
end_time = time.time()
|
||||
logging.info("Checkpoint actual writing time: (%.3f sec)",
|
||||
end_time - start_time)
|
||||
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
|
||||
|
||||
if not asynchronous:
|
||||
self._last_checkpoint_step = step
|
||||
_save_fn()
|
||||
return
|
||||
|
||||
if self._save_thread is not None:
|
||||
self._save_thread.join(timeout=0.1)
|
||||
if self._save_thread.is_alive():
|
||||
logging.info("Saver thread still in progress, skipping checkpoint.")
|
||||
return
|
||||
|
||||
self._last_checkpoint_step = step
|
||||
self._save_thread = threading.Thread(target=_save_fn)
|
||||
self._save_thread.start()
|
||||
|
||||
def _get_saver(self):
|
||||
if self._saver is not None:
|
||||
return self._saver
|
||||
elif self._scaffold is not None:
|
||||
return self._scaffold.saver
|
||||
|
||||
# Get saver from the SAVERS collection if present.
|
||||
collection_key = ops.GraphKeys.SAVERS
|
||||
savers = ops.get_collection(collection_key)
|
||||
if not savers:
|
||||
raise RuntimeError(
|
||||
"No items in collection {}. Please add a saver to the collection "
|
||||
"or provide a saver or scaffold.".format(collection_key))
|
||||
elif len(savers) > 1:
|
||||
raise RuntimeError(
|
||||
"More than one item in collection {}. "
|
||||
"Please indicate which one to use by passing it to the constructor."
|
||||
.format(collection_key))
|
||||
|
||||
self._saver = savers[0]
|
||||
return savers[0]
|
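# Illustrative sketch (added for exposition, not part of the original file):
# attaching the hook to a MonitoredTrainingSession. The checkpoint directory,
# step count, and `train_op` are hypothetical; disabling the built-in saver is
# one way to avoid writing checkpoints twice.
def _example_async_checkpointing(train_op):
  """Trains with asynchronous checkpoints every 1000 steps (assumed setup)."""
  from tensorflow.python.training import monitored_session

  hook = AsyncCheckpointSaverHook(checkpoint_dir='/tmp/model', save_steps=1000)
  with monitored_session.MonitoredTrainingSession(
      checkpoint_dir='/tmp/model',
      save_checkpoint_secs=None,  # Disable the default synchronous saver.
      hooks=[hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)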
tensorflow/python/tpu/bfloat16.py (new file)
@ -0,0 +1,77 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper context for running models with bfloat16."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
def _get_custom_getter():
|
||||
"""Returns a custom getter that this class's methods must be called under.
|
||||
|
||||
All methods of this class must be called under a variable scope that was
|
||||
passed this custom getter. Example:
|
||||
|
||||
```python
|
||||
network = ConvNetBuilder(...)
|
||||
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
|
||||
network.conv(...)
|
||||
# Call more methods of network here
|
||||
```
|
||||
|
||||
Currently, this custom getter only does anything if self.use_tf_layers is
|
||||
True. In that case, it causes variables to be stored as dtype
|
||||
self.variable_type, then casted to the requested dtype, instead of directly
|
||||
storing the variable as the requested dtype.
|
||||
"""
|
||||
|
||||
def inner_custom_getter(getter, *args, **kwargs):
|
||||
"""Custom getter that forces variables to have type self.variable_type."""
|
||||
cast_to_bfloat16 = False
|
||||
requested_dtype = kwargs['dtype']
|
||||
if requested_dtype == dtypes.bfloat16:
|
||||
# Only change the variable dtype if doing so does not decrease variable
|
||||
# precision.
|
||||
kwargs['dtype'] = dtypes.float32
|
||||
cast_to_bfloat16 = True
|
||||
var = getter(*args, **kwargs)
|
||||
# This if statement is needed to guard the cast, because batch norm
|
||||
# assigns directly to the return value of this custom getter. The cast
|
||||
# makes the return value not a variable so it cannot be assigned. Batch
|
||||
# norm variables are always in fp32 so this if statement is never
|
||||
# triggered for them.
|
||||
if cast_to_bfloat16:
|
||||
var = math_ops.cast(var, dtypes.bfloat16)
|
||||
return var
|
||||
|
||||
return inner_custom_getter
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def bfloat16_scope():
|
||||
"""Scope class for bfloat16 variables so that the model uses custom getter.
|
||||
|
||||
This enables variables to be read as bfloat16 type when using get_variable.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
'', custom_getter=_get_custom_getter()) as varscope:
|
||||
yield varscope
|
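# Illustrative sketch (added for exposition, not part of the original file):
# under bfloat16_scope, get_variable stores the weight as float32 but returns
# it cast to bfloat16. The variable name and shape below are hypothetical.
def _example_bfloat16_variable():
  """Creates a bfloat16-read variable backed by float32 storage."""
  with bfloat16_scope():
    return variable_scope.get_variable(
        'weights', shape=[128, 64], dtype=dtypes.bfloat16)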
@ -19,11 +19,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import bfloat16
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import bfloat16
|
||||
|
||||
|
||||
class BFloat16ScopeTest(test.TestCase):
|
tensorflow/python/tpu/datasets.py (new file)
@ -0,0 +1,191 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Library of Cloud TPU helper functions for data loading."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
|
||||
|
||||
def _TextLineDataset(filename):
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
|
||||
|
||||
def _TFRecordDataset(filename):
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
|
||||
|
||||
_FILETYPE_MAP = {
|
||||
'tfrecord': _TFRecordDataset,
|
||||
'textline': _TextLineDataset,
|
||||
'text': _TextLineDataset,
|
||||
}
|
||||
|
||||
|
||||
def StreamingFilesDataset(files,
|
||||
filetype=None,
|
||||
file_reader_job=None,
|
||||
worker_job=None,
|
||||
num_epochs=None,
|
||||
filename_shuffle_buffer_size=None,
|
||||
num_parallel_reads=None,
|
||||
batch_transfer_size=None,
|
||||
sloppy=None):
|
||||
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
|
||||
|
||||
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
|
||||
files local to your GCE VM. In order to train using files stored on your local
|
||||
VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
|
||||
helper to generate a dataset to feed your Cloud TPU with files from your GCE
|
||||
VM.
|
||||
|
||||
The resulting dataset may return an OutOfRangeError if there are no files
|
||||
found as a result of the fileglob expansion.
|
||||
|
||||
Note: StreamingFilesDataset assumes that the session is using a
|
||||
TPUClusterResolver and therefore has a worker and a coordinator job. File
|
||||
loading will be done on the coordinator job.
|
||||
|
||||
Args:
|
||||
files: A string glob to match files, or a `tf.data.Dataset` generating file
|
||||
names.
|
||||
filetype: A string (one of 'tfrecord', or 'textline') or a single-argument
|
||||
TensorFlow function that when given a filename returns a dataset.
|
||||
file_reader_job: An optional string that corresponds to the job that should
|
||||
perform the file reads.
|
||||
worker_job: An optional string that corresponds to the job that should
|
||||
process the tensors (i.e. your GPU or TPU worker).
|
||||
num_epochs: The number of epochs through the training set that should be
|
||||
generated. By default, it will repeat infinitely.
|
||||
filename_shuffle_buffer_size: An optional integer whose value controls the
|
||||
shuffling of the file names. If you would like to read from the files in
|
||||
the same order, set to 0 or False.
|
||||
num_parallel_reads: An optional integer controlling the number of files to
|
||||
read from concurrently. (Set to 1 for no parallelism.)
|
||||
batch_transfer_size: An optional integer controlling the batching used to
|
||||
amortize the remote function invocation overhead. Set to a very large
|
||||
number to increase throughput. Set to a very small number to reduce memory
|
||||
consumption. Set to False to skip batching.
|
||||
sloppy: (Optional.) If `False`, read input data while maintaining a
|
||||
deterministic order. (This may have significant performance impacts.)
|
||||
sloppy defaults to True.
|
||||
Returns:
|
||||
A `tf.data.Dataset` with an infinite stream of elements generated by a
parallel interleaving of the set of files matched (or generated) by `files`,
with elements whose type is the output of the dataset specified by `filetype`.
|
||||
|
||||
Raises:
|
||||
ValueError: if any argument is not of the expected type.
|
||||
"""
|
||||
if filetype is None:
|
||||
filetype = 'tfrecord'
|
||||
|
||||
if isinstance(filetype, str):
|
||||
if filetype not in _FILETYPE_MAP:
|
||||
raise ValueError('Unexpected filetype: %s' % filetype)
|
||||
reader_fn = _FILETYPE_MAP[filetype]
|
||||
elif callable(filetype):
|
||||
reader_fn = filetype
|
||||
else:
|
||||
raise ValueError('filetype should be a string or a callable')
|
||||
|
||||
file_reader_job = file_reader_job or 'coordinator'
|
||||
|
||||
worker_job = worker_job or 'worker'
|
||||
|
||||
if filename_shuffle_buffer_size is None:
|
||||
filename_shuffle_buffer_size = 4096
|
||||
|
||||
num_parallel_reads = num_parallel_reads or 8
|
||||
|
||||
if batch_transfer_size is None:
|
||||
batch_transfer_size = 256
|
||||
|
||||
if sloppy is None:
|
||||
sloppy = True
|
||||
|
||||
with ops.device('/job:%s' % file_reader_job):
|
||||
if isinstance(files, str):
|
||||
source_dataset = dataset_ops.Dataset.list_files(files)
|
||||
elif isinstance(files, dataset_ops.DatasetV2):
|
||||
source_dataset = files
|
||||
else:
|
||||
raise ValueError('files was not a string or a dataset: %s' % files)
|
||||
|
||||
if filename_shuffle_buffer_size:
|
||||
source_dataset = source_dataset.shuffle(
|
||||
buffer_size=filename_shuffle_buffer_size)
|
||||
|
||||
source_dataset = source_dataset.apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
|
||||
|
||||
source_dataset = source_dataset.repeat(num_epochs)
|
||||
|
||||
if batch_transfer_size:
|
||||
source_dataset = source_dataset.batch(batch_transfer_size)
|
||||
|
||||
source_dataset = source_dataset.prefetch(1)
|
||||
|
||||
source_iterator = dataset_ops.make_one_shot_iterator(source_dataset)
|
||||
source_handle = source_iterator.string_handle()
|
||||
|
||||
@function.Defun(dtypes.string)
|
||||
def LoadingFunc(h):
|
||||
remote_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
h, source_dataset.output_types, source_dataset.output_shapes)
|
||||
return remote_iterator.get_next()
|
||||
|
||||
def MapFn(unused_input):
|
||||
if isinstance(source_dataset.output_types, dtypes.DType):
|
||||
output_types = [source_dataset.output_types]
|
||||
elif isinstance(source_dataset.output_types, (list, tuple)):
|
||||
output_types = source_dataset.output_types
|
||||
else:
|
||||
raise ValueError('source dataset has invalid output types')
|
||||
remote_calls = functional_ops.remote_call(
|
||||
args=[source_handle],
|
||||
Tout=output_types,
|
||||
f=LoadingFunc,
|
||||
target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
|
||||
if len(remote_calls) == 1:
|
||||
return remote_calls[0]
|
||||
else:
|
||||
return remote_calls
|
||||
|
||||
with ops.device('/job:%s' % worker_job):
|
||||
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
|
||||
MapFn, num_parallel_calls=4 if sloppy else None)
|
||||
output_dataset = output_dataset.prefetch(1)
|
||||
|
||||
if batch_transfer_size:
|
||||
# Undo the batching used during the transfer.
|
||||
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
|
||||
|
||||
return output_dataset
|
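# Illustrative sketch (added for exposition, not part of the original file):
# streaming TFRecord files stored on the coordinator VM to a TPU worker. The
# glob pattern is hypothetical; the job names match the documented defaults.
def _example_streaming_files_dataset():
  """Builds a dataset that reads on 'coordinator' and is consumed on 'worker'."""
  return StreamingFilesDataset(
      files='/data/train-*.tfrecord',
      filetype='tfrecord',
      file_reader_job='coordinator',
      worker_job='worker',
      num_parallel_reads=8)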
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import datasets
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
@ -31,6 +30,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import datasets
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
|
tensorflow/python/tpu/device_assignment.py (new file)
@ -0,0 +1,313 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Library of TPU helper functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.tpu.topology import Topology
|
||||
|
||||
|
||||
SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]]
|
||||
|
||||
|
||||
def _compute_task_and_cores_to_replicas(core_assignment, topology):
|
||||
"""Computes a nested dict which maps task and logical core to replicas."""
|
||||
task_and_cores_to_replicas = {}
|
||||
for replica in xrange(core_assignment.shape[0]):
|
||||
for logical_core in xrange(core_assignment.shape[1]):
|
||||
coordinates = core_assignment[replica, logical_core, :]
|
||||
task_id = topology.task_ordinal_at_coordinates(coordinates)
|
||||
if task_id not in task_and_cores_to_replicas:
|
||||
task_and_cores_to_replicas[task_id] = {}
|
||||
if logical_core not in task_and_cores_to_replicas[task_id]:
|
||||
task_and_cores_to_replicas[task_id][logical_core] = set()
|
||||
|
||||
task_and_cores_to_replicas[task_id][logical_core].add(replica)
|
||||
|
||||
task_to_sorted_replica_id = {}
|
||||
|
||||
for task, core_to_replicas in task_and_cores_to_replicas.items():
|
||||
core_to_sorted_replicas = {}
|
||||
for core, replicas in core_to_replicas.items():
|
||||
core_to_sorted_replicas[core] = sorted(replicas)
|
||||
|
||||
task_to_sorted_replica_id[task] = core_to_sorted_replicas
|
||||
return task_to_sorted_replica_id
|
||||
|
||||
|
||||
class DeviceAssignment(object):
|
||||
"""Mapping from logical cores in a computation to the physical TPU topology.
|
||||
|
||||
Prefer to use the `device_assignment()` helper to construct a
`DeviceAssignment`; it is easier, if less flexible, than constructing a
`DeviceAssignment` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, topology, core_assignment):
|
||||
"""Constructs a `DeviceAssignment` object.
|
||||
|
||||
Args:
|
||||
topology: A `Topology` object that describes the physical TPU topology.
|
||||
core_assignment: A logical to physical core mapping, represented as a
|
||||
rank 3 numpy array. See the description of the `core_assignment`
|
||||
property for more details.
|
||||
|
||||
Raises:
|
||||
ValueError: If `topology` is not `Topology` object.
|
||||
ValueError: If `core_assignment` is not a rank 3 numpy array.
|
||||
"""
|
||||
if not isinstance(topology, Topology):
|
||||
raise ValueError("topology must be a Topology object, got {}".format(
|
||||
type(topology)))
|
||||
core_assignment = np.asarray(core_assignment, dtype=np.int32)
|
||||
|
||||
self._topology = topology
|
||||
|
||||
if core_assignment.ndim != 3:
|
||||
raise ValueError("core_assignment must be a rank 3 numpy array, "
|
||||
"got shape {}".format(core_assignment.shape))
|
||||
|
||||
self._num_replicas = core_assignment.shape[0]
|
||||
self._num_cores_per_replica = core_assignment.shape[1]
|
||||
|
||||
if core_assignment.shape[-1] != topology.mesh_rank:
|
||||
raise ValueError(
|
||||
"minor dimension of core_assignment must have size equal to topology "
|
||||
"rank ({}), got shape {}".format(topology.mesh_rank,
|
||||
core_assignment.shape))
|
||||
|
||||
self._core_assignment = core_assignment
|
||||
self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
|
||||
self._core_assignment, topology)
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
"""A `Topology` that describes the TPU topology."""
|
||||
return self._topology
|
||||
|
||||
@property
|
||||
def num_cores_per_replica(self):
|
||||
"""The number of cores per replica."""
|
||||
return self._num_cores_per_replica
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
"""The number of replicas of the computation."""
|
||||
return self._num_replicas
|
||||
|
||||
@property
|
||||
def core_assignment(self):
|
||||
"""The logical to physical core mapping.
|
||||
|
||||
Returns:
|
||||
An integer numpy array of rank 3, with shape
|
||||
`[num_replicas, num_cores_per_replica, topology_rank]`. Maps
|
||||
(replica, logical core) pairs to physical topology coordinates.
|
||||
"""
|
||||
return self._core_assignment
|
||||
|
||||
def coordinates(self, replica, logical_core):
|
||||
"""Returns the physical topology coordinates of a logical core."""
|
||||
return tuple(self.core_assignment[replica, logical_core, :])
|
||||
|
||||
def lookup_replicas(self, task_id, logical_core):
|
||||
"""Lookup replica ids by task number and logical core.
|
||||
|
||||
Args:
|
||||
task_id: TensorFlow task number.
|
||||
logical_core: An integer, identifying a logical core.
|
||||
Returns:
|
||||
A sorted list of the replicas that are attached to that task and
|
||||
logical_core.
|
||||
Raises:
|
||||
ValueError: If no replica exists in the task which contains the logical
|
||||
core.
|
||||
"""
|
||||
try:
|
||||
return self._task_and_cores_to_replicas[task_id][logical_core]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Can not find any replica in task: {} contains logical_core: {} ".
|
||||
format(task_id, logical_core))
|
||||
|
||||
def tpu_ordinal(self, replica=0, logical_core=0):
|
||||
"""Returns the ordinal of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
|
||||
|
||||
def host_device(self, replica=0, logical_core=0, job=None):
|
||||
"""Returns the CPU device attached to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
def tpu_device(self, replica=0, logical_core=0, job=None):
|
||||
"""Returns the name of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
|
||||
def device_assignment(topology,
|
||||
computation_shape=None,
|
||||
computation_stride=None,
|
||||
num_replicas=1):
|
||||
"""Computes a device_assignment of a computation across a TPU topology.
|
||||
|
||||
Attempts to choose a compact grid of cores for locality.
|
||||
|
||||
Returns a `DeviceAssignment` that describes the cores in the topology assigned
|
||||
to each core of each replica.
|
||||
|
||||
`computation_shape` and `computation_stride` values should be powers of 2 for
|
||||
optimal packing.
|
||||
|
||||
Args:
|
||||
topology: A `Topology` object that describes the TPU cluster topology.
|
||||
To obtain a TPU topology, evaluate the `Tensor` returned by
|
||||
`initialize_system` using `Session.run`. Either a serialized
|
||||
`TopologyProto` or a `Topology` object may be passed. Note: you must
|
||||
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
|
||||
computation_shape: A rank 1 int32 numpy array with size equal to the
|
||||
topology rank, describing the shape of the computation's block of cores.
|
||||
If None, the `computation_shape` is `[1] * topology_rank`.
|
||||
computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
|
||||
describing the inter-core spacing of the `computation_shape` cores in the
|
||||
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
|
||||
num_replicas: The number of computation replicas to run. The replicas will
|
||||
be packed into the free spaces of the topology.
|
||||
|
||||
Returns:
|
||||
A DeviceAssignment object, which describes the mapping between the logical
|
||||
cores in each computation replica and the physical cores in the TPU
|
||||
topology.
|
||||
|
||||
Raises:
|
||||
ValueError: If `topology` is not a valid `Topology` object.
|
||||
ValueError: If `computation_shape` or `computation_stride` are not 1D int32
numpy arrays of length `topology_rank` with all values positive.
|
||||
ValueError: If computation's replicas cannot fit into the TPU topology.
|
||||
"""
|
||||
# Deserialize the Topology proto, if it is a string.
|
||||
if isinstance(topology, bytes):
|
||||
topology = Topology(serialized=topology)
|
||||
|
||||
if not isinstance(topology, Topology):
|
||||
raise ValueError("`topology` is not a Topology object; got {}".format(
|
||||
type(topology)))
|
||||
|
||||
topology_rank = len(topology.mesh_shape)
|
||||
mesh_shape = topology.mesh_shape
|
||||
if computation_shape is None:
|
||||
computation_shape = np.array([1] * topology_rank, dtype=np.int32)
|
||||
else:
|
||||
computation_shape = np.asarray(computation_shape, dtype=np.int32)
|
||||
|
||||
if computation_stride is None:
|
||||
computation_stride = np.array([1] * topology_rank, dtype=np.int32)
|
||||
else:
|
||||
computation_stride = np.asarray(computation_stride, dtype=np.int32)
|
||||
|
||||
if computation_shape.shape != (topology_rank,):
|
||||
raise ValueError("computation_shape must have shape [{}]; got {}".format(
|
||||
topology_rank, computation_shape.shape))
|
||||
if computation_stride.shape != (topology_rank,):
|
||||
raise ValueError("computation_stride must have shape [{}]; got {}".format(
|
||||
topology_rank, computation_stride.shape))
|
||||
|
||||
if any(computation_shape < 1):
|
||||
raise ValueError(
|
||||
"computation_shape must be positive; got computation_shape={}".format(
|
||||
computation_shape))
|
||||
if any(computation_stride < 1):
|
||||
raise ValueError(
|
||||
"computation_stride must be positive; got computation_stride={}".format(
|
||||
computation_stride))
|
||||
|
||||
# Computes the physical size of one computation instance.
|
||||
computation_footprint = computation_shape * computation_stride
|
||||
if any(computation_footprint > mesh_shape):
|
||||
raise ValueError(
|
||||
"computation footprint {} does not fit in TPU topology shape {}".format(
|
||||
computation_footprint, mesh_shape))
|
||||
|
||||
# Computes how many copies of the computation footprint fit in the mesh.
|
||||
block_counts = mesh_shape // computation_footprint
|
||||
|
||||
replica_counts = block_counts * computation_stride
|
||||
max_replicas = np.prod(replica_counts)
|
||||
if num_replicas > max_replicas:
|
||||
raise ValueError(
|
||||
"requested {} replicas but only {} replicas with shape {} and "
|
||||
"computation_stride {} fit in a TPU mesh of shape {}".format(
|
||||
num_replicas, max_replicas, computation_shape, computation_stride,
|
||||
mesh_shape))
|
||||
|
||||
def ceil_of_ratio(n, m):
|
||||
return (n + m - 1) // m
|
||||
|
||||
replica_shape = [0] * topology_rank
|
||||
if num_replicas > 0:
|
||||
remaining_replicas = num_replicas
|
||||
remaining_dims = topology_rank
|
||||
|
||||
# Choose dimensions as close to an equal cube as possible, in order of
|
||||
# increasing dimension size. By visiting dimensions in increasing size, we
|
||||
# assign the most constrained dimension first, so we won't make infeasible
|
||||
# choices.
|
||||
#
|
||||
# As a secondary sort order, visit the dimensions in reverse order. This
|
||||
# means we try to use both cores on the same chip in preference to two cores
|
||||
# on different chips.
|
||||
for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
|
||||
i = -ni
|
||||
target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
|
||||
replica_shape[i] = min(target_size, x)
|
||||
remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
|
||||
remaining_dims -= 1
|
||||
|
||||
assert remaining_replicas == 1 and remaining_dims == 0
|
||||
|
||||
# Assigns an offset to each replica such that no two replicas overlap.
|
||||
replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
|
||||
for replica in xrange(num_replicas):
|
||||
# Chooses a replica number in each axis.
|
||||
t = replica
|
||||
pos = []
|
||||
for dim in replica_shape[::-1]:
|
||||
pos.append(t % dim)
|
||||
t //= dim
|
||||
replica_pos = np.array(pos[::-1], dtype=np.int32)
|
||||
|
||||
# Determines where that replica starts in each axis.
|
||||
outer = replica_pos // computation_stride
|
||||
inner = replica_pos % computation_stride
|
||||
replica_offsets[replica, :] = outer * computation_footprint + inner
|
||||
|
||||
# Computes a complete logical core -> physical core mapping for each replica.
|
||||
indices = [
|
||||
np.arange(0, computation_shape[i] * computation_stride[i],
|
||||
computation_stride[i]) for i in xrange(topology_rank)
|
||||
]
|
||||
indices = np.concatenate(
|
||||
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
|
||||
axis=-1)
|
||||
indices = indices.reshape((-1, topology_rank))
|
||||
assignment = indices + replica_offsets[:, np.newaxis, :]
|
||||
return DeviceAssignment(topology, core_assignment=assignment)
|
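# Illustrative sketch (added for exposition, not part of the original file):
# computing a DeviceAssignment from a freshly initialized TPU system. `sess`
# and `init_op` (the tensor returned by tpu.initialize_system()) are assumed
# to exist; only the call below comes from this file.
def _example_device_assignment(sess, init_op):
  """Returns a two-replica assignment with one core per replica."""
  serialized_topology = sess.run(init_op)  # Serialized TopologyProto bytes.
  return device_assignment(serialized_topology, num_replicas=2)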
tensorflow/python/tpu/error_handling.py (new file)
@ -0,0 +1,132 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""ErrorRendezvous handler for collecting errors from multiple threads."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_UNINTERESTING_ERRORS = (errors.CancelledError,)
|
||||
|
||||
|
||||
class ErrorRendezvous(object):
|
||||
"""Resolve errors from multiple threads during TPU execution.
|
||||
|
||||
TPU errors can occur on the infeed or outfeed threads as well as the main
|
||||
training thread.
|
||||
|
||||
Depending on which thread "wins" and receives the session error first, we may
|
||||
end up showing users a confusing and non-actionable error message (session
|
||||
cancelled) instead of a root cause (e.g. a bad filename).
|
||||
|
||||
The rendezvous object provides a location to capture these errors until all
|
||||
threads terminate. At that point we can choose the most informative error
|
||||
to report.
|
||||
"""
|
||||
|
||||
def __init__(self, num_sources):
|
||||
# string -> (message, traceback)
|
||||
self._errors = {}
|
||||
self._num_sources = num_sources
|
||||
self._session_cancel_timer = None
|
||||
|
||||
def record_error(self, source, exc_info, session=None):
|
||||
"""Report an exception from the given source.
|
||||
|
||||
If a session is passed, a timer will be registered to close it after a few
|
||||
seconds. This is necessary to ensure the main training loop does not hang
|
||||
if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
|
||||
interesting error from another thread to propagate.
|
||||
|
||||
Args:
|
||||
source: string, source of the error
|
||||
exc_info: Output from `sys.exc_info` (type, value, traceback)
|
||||
session: Session to close after delay.
|
||||
"""
|
||||
_, value, _ = exc_info
|
||||
self._errors[source] = exc_info
|
||||
logging.info('Error recorded from %s: %s', source, value)
|
||||
|
||||
if session is not None and self._session_cancel_timer is None:
|
||||
|
||||
def _cancel_session():
|
||||
time.sleep(5)
|
||||
try:
|
||||
session.close()
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
|
||||
self._session_cancel_timer.daemon = True
|
||||
self._session_cancel_timer.start()
|
||||
|
||||
def record_done(self, source):
|
||||
"""Mark execution source `source` as done.
|
||||
|
||||
If an error was originally reported from `source` it is left intact.
|
||||
|
||||
Args:
|
||||
source: `str`, source being recorded
|
||||
"""
|
||||
logging.info('%s marked as finished', source)
|
||||
if source not in self._errors:
|
||||
self._errors[source] = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def catch_errors(self, source, session=None):
|
||||
"""Context manager to report any errors within a block."""
|
||||
try:
|
||||
yield
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.record_error(source, sys.exc_info(), session)
|
||||
|
||||
def raise_errors(self, timeout_sec=0):
|
||||
"""Wait for up to `timeout` seconds for all error sources to finish.
|
||||
|
||||
Preferentially raise "interesting" errors (errors not in the
|
||||
_UNINTERESTING_ERRORS) set.
|
||||
|
||||
Args:
|
||||
timeout_sec: Seconds to wait for other error sources.
|
||||
"""
|
||||
for _ in range(timeout_sec):
|
||||
if len(self._errors) == self._num_sources:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
|
||||
|
||||
# First check for any interesting errors, then fall back on the session
|
||||
# cancelled errors etc.
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
if isinstance(value, _UNINTERESTING_ERRORS):
|
||||
continue
|
||||
else:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
||||
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
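# Illustrative sketch (added for exposition, not part of the original file):
# funneling errors from an infeed thread and the main loop through a single
# rendezvous. `sess`, `train_op` and `infeed_op` are assumed to exist.
def _example_error_rendezvous(sess, train_op, infeed_op):
  """Runs train and infeed loops, then raises the most informative error."""
  rendezvous = ErrorRendezvous(num_sources=2)

  def _infeed_loop():
    with rendezvous.catch_errors(source='infeed', session=sess):
      for _ in range(100):
        sess.run(infeed_op)
    rendezvous.record_done('infeed')

  infeed_thread = threading.Thread(target=_infeed_loop)
  infeed_thread.start()

  with rendezvous.catch_errors(source='training', session=sess):
    for _ in range(100):
      sess.run(train_op)
  rendezvous.record_done('training')

  infeed_thread.join()
  rendezvous.raise_errors(timeout_sec=10)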
tensorflow/python/tpu/feature_column.py (new file)
@ -0,0 +1,435 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""TPU Feature Column Library."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.feature_column import feature_column as fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as fc_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
|
||||
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
|
||||
fc._VocabularyFileCategoricalColumn,
|
||||
fc._VocabularyListCategoricalColumn,
|
||||
fc._WeightedCategoricalColumn,
|
||||
fc_lib.IdentityCategoricalColumn,
|
||||
fc_lib.VocabularyFileCategoricalColumn,
|
||||
fc_lib.VocabularyListCategoricalColumn,
|
||||
fc_lib.WeightedCategoricalColumn)
|
||||
|
||||
|
||||
def embedding_column(categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None):
|
||||
"""TPU embedding_column for `tf.feature_column.embedding_column`.
|
||||
|
||||
Note that the interface for TPU embedding_column is different from the non-TPU
|
||||
version. The following args available for the non-TPU version are NOT
|
||||
supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.
|
||||
|
||||
Args:
|
||||
categorical_column: A categorical_column returned from
|
||||
categorical_column_with_identity, weighted_categorical_column,
|
||||
categorical_column_with_vocabulary_list or
|
||||
categorical_column_with_vocabulary_file.
|
||||
dimension: An integer specifying dimension of the embedding, must be > 0.
|
||||
combiner: A string specifying how to reduce if there are multiple entries
|
||||
in a single row. For more information, see
|
||||
`tf.feature_column.embedding_column`.
|
||||
initializer: A variable initializer function to be used in embedding
|
||||
variable initialization. If not specified, defaults to
|
||||
`tf.truncated_normal_initializer` with mean `0.0` and standard deviation
|
||||
`1/sqrt(dimension)`.
|
||||
|
||||
Returns:
|
||||
A _TPUEmbeddingColumn.
|
||||
|
||||
Raises:
|
||||
ValueError: if `dimension` not > 0.
|
||||
ValueError: if `initializer` is specified but not callable.
|
||||
"""
|
||||
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
|
||||
raise TypeError(
|
||||
'categorical_column for tpu '
|
||||
' embedding_column must be type %s, got %s.' % (' or '.join([
|
||||
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
|
||||
]), type(categorical_column)))
|
||||
if (dimension is None) or (dimension < 1):
|
||||
raise ValueError('Invalid dimension {}.'.format(dimension))
|
||||
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified. '
|
||||
'Embedding of column_name: {}'.format(
|
||||
categorical_column.name))
|
||||
if initializer is None:
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||
|
||||
embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
|
||||
|
||||
def _creator(weight_collections, scope):
|
||||
embedding_column_layer = fc._EmbeddingColumnLayer(
|
||||
embedding_shape=embedding_shape,
|
||||
initializer=initializer,
|
||||
weight_collections=weight_collections,
|
||||
trainable=True,
|
||||
name='embedding_column_layer')
|
||||
return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
|
||||
|
||||
column = _TPUEmbeddingColumn(
|
||||
categorical_column=categorical_column,
|
||||
dimension=dimension,
|
||||
combiner=combiner,
|
||||
layer_creator=_creator,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
# For Embedding column, the initializer is hidden inside the creator Fn, which
|
||||
# is not accessible later. So, we attach it to a special field. Also note
|
||||
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
|
||||
# initializer differently. See shared_embedding_columns for details.
|
||||
column._tpu_initializer = initializer
|
||||
return column
|
||||
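# Illustrative sketch (added for exposition, not part of the original file):
# wrapping an identity categorical column in a TPU embedding column. The key
# and sizes are hypothetical; `fc` is the core feature_column module imported
# above.
def _example_tpu_embedding_column():
  """Returns a hypothetical 16-dimensional TPU embedding column."""
  video_id = fc.categorical_column_with_identity(
      key='video_id', num_buckets=1000000)
  return embedding_column(video_id, dimension=16, combiner='mean')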
|
||||
|
||||
def shared_embedding_columns(categorical_columns,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None):
|
||||
"""List of dense columns that convert from sparse, categorical input."""
|
||||
for categorical_column in categorical_columns:
|
||||
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
|
||||
raise TypeError(
|
||||
'categorical_column for tpu '
|
||||
' shared_embedding_columns must be type %s, got %s.' % (' or '.join([
|
||||
cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS
|
||||
]), type(categorical_column)))
|
||||
columns = fc_lib.shared_embedding_columns(
|
||||
categorical_columns,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
|
||||
# Use the initializer and shared_embedding_collection_name to create TPU
|
||||
# version
|
||||
initializer = columns[0].initializer
|
||||
shared_embedding_collection_name = columns[0].shared_embedding_collection_name
|
||||
tpu_columns = []
|
||||
|
||||
# Create the state (_SharedEmbeddingColumnLayer) here.
|
||||
for categorical_column in categorical_columns:
|
||||
column = _TPUSharedEmbeddingColumn(
|
||||
categorical_column=categorical_column,
|
||||
dimension=dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True)
|
||||
tpu_columns.append(column)
|
||||
|
||||
return tpu_columns
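A hedged sketch, not part of this change: two categorical features drawn from the same vocabulary can share one TPU embedding table through the helper above (keys and sizes are illustrative, and identity columns are assumed to be among the supported types).

import tensorflow as tf

watched = tf.feature_column.categorical_column_with_identity(
    key='watched_id', num_buckets=100000)
impression = tf.feature_column.categorical_column_with_identity(
    key='impression_id', num_buckets=100000)
watched_emb, impression_emb = tpu_fc.shared_embedding_columns(
    [watched, impression], dimension=32)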
|
||||
|
||||
|
||||
class _TPUBaseEmbeddingColumn(object):
|
||||
"""Base class for TPU Embedding Column."""
|
||||
|
||||
def __init__(self, categorical_column):
|
||||
self._tpu_categorical_column = categorical_column
|
||||
|
||||
def get_combiner(self):
|
||||
"""Returns the embedding combiner."""
|
||||
raise NotImplementedError('not implemented')
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns the embedding table size, tuple of vocab size and dimension."""
|
||||
raise NotImplementedError('not implemented')
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""Returns the feature key name in the features dict."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""Return the key name for weights."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""Returns the embedding variable name.
|
||||
|
||||
Feature key name and embedding variable name are usually one-to-one mapping.
|
||||
But for shared embedding columns, it is many-to-one mapping.
|
||||
"""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def get_initializer(self):
|
||||
"""Returns the initializer."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
raise NotImplementedError('not impl')
|
||||
|
||||
|
||||
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
|
||||
"""Core Embedding Column."""
|
||||
|
||||
def __new__(cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
layer_creator=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
|
||||
# are not supported on TPU. They are solely for matching the signature of
|
||||
# __new__ of parent class fc._EmbeddingColumn.
|
||||
return fc._EmbeddingColumn.__new__(
|
||||
cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
layer_creator=layer_creator,
|
||||
ckpt_to_load_from=ckpt_to_load_from,
|
||||
tensor_name_in_ckpt=tensor_name_in_ckpt,
|
||||
max_norm=max_norm,
|
||||
trainable=trainable)
|
||||
|
||||
def __init__(self,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
layer_creator=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
|
||||
self._key = None
|
||||
|
||||
def get_combiner(self):
|
||||
return self.combiner
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns num_ids and width."""
|
||||
return (self.categorical_column._num_buckets, self.dimension)
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""get_feature_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.categorical_column.name
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""get_weight_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.weight_feature_key
|
||||
return None
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""get_embedding_var_name."""
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_initializer(self):
|
||||
return self._tpu_initializer
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
if isinstance(
|
||||
self.categorical_column,
|
||||
(
|
||||
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
|
||||
fc_lib.WeightedCategoricalColumn)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
if tpu.under_tpu_inference_context():
|
||||
def host_computation():
|
||||
return fc._EmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
return tpu.outside_compilation(host_computation)
|
||||
|
||||
if _is_running_on_cpu():
|
||||
return fc._EmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
|
||||
# TPU mode
|
||||
# Get the embeddings from the LazyBuilder.
|
||||
tensor = inputs.get(self.get_feature_key_name())
|
||||
|
||||
# Add to collection for _create_tpu_embedding_variables_and_ops
|
||||
_record_variable_scope_and_name(self.get_embedding_var_name(),
|
||||
'embedding_weights')
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
|
||||
fc._SharedEmbeddingColumn):
|
||||
"""Core Shared Embedding Column."""
|
||||
|
||||
def __new__(cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
return fc._SharedEmbeddingColumn.__new__(
|
||||
cls,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner=combiner,
|
||||
initializer=initializer,
|
||||
shared_embedding_collection_name=shared_embedding_collection_name,
|
||||
ckpt_to_load_from=ckpt_to_load_from,
|
||||
tensor_name_in_ckpt=tensor_name_in_ckpt,
|
||||
max_norm=max_norm,
|
||||
trainable=trainable)
|
||||
|
||||
def __init__(self,
|
||||
categorical_column,
|
||||
dimension,
|
||||
combiner='mean',
|
||||
initializer=None,
|
||||
shared_embedding_collection_name=None,
|
||||
ckpt_to_load_from=None,
|
||||
tensor_name_in_ckpt=None,
|
||||
max_norm=None,
|
||||
trainable=True):
|
||||
|
||||
_TPUBaseEmbeddingColumn.__init__(self, categorical_column)
|
||||
self._key = None
|
||||
|
||||
def get_combiner(self):
|
||||
return self.combiner
|
||||
|
||||
def get_embedding_table_size(self):
|
||||
"""Returns num_ids and width."""
|
||||
return (self.categorical_column._num_buckets, self.dimension)
|
||||
|
||||
def get_feature_key_name(self):
|
||||
"""get_feature_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.categorical_column.name
|
||||
return self.categorical_column.name
|
||||
|
||||
def get_weight_key_name(self):
|
||||
"""get_weight_key_name."""
|
||||
if self.is_categorical_column_weighted():
|
||||
return self.categorical_column.weight_feature_key
|
||||
return None
|
||||
|
||||
def get_embedding_var_name(self):
|
||||
"""get_embedding_var_name."""
|
||||
return self.shared_embedding_collection_name
|
||||
|
||||
def get_initializer(self):
|
||||
return self.initializer
|
||||
|
||||
def is_categorical_column_weighted(self):
|
||||
"""Check if the categorical column of the embedding column is weighted."""
|
||||
if isinstance(
|
||||
self.categorical_column,
|
||||
(
|
||||
fc._WeightedCategoricalColumn, # pylint: disable=protected-access
|
||||
fc_lib.WeightedCategoricalColumn)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
if tpu.under_tpu_inference_context():
|
||||
def host_computation():
|
||||
return fc._SharedEmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
return tpu.outside_compilation(host_computation)
|
||||
|
||||
if _is_running_on_cpu():
|
||||
return fc._SharedEmbeddingColumn._get_dense_tensor(
|
||||
self, inputs, weight_collections, trainable)
|
||||
|
||||
# TPU mode
|
||||
# Get the embeddings from the LazyBuilder.
|
||||
tensor = inputs.get(self.get_feature_key_name())
|
||||
|
||||
# Add to collection for _create_tpu_embedding_variables_and_ops
|
||||
_record_variable_scope_and_name(
|
||||
self.get_embedding_var_name(),
|
||||
'embedding_weights',
|
||||
is_shared_embedding=True)
|
||||
return tensor
|
||||
|
||||
|
||||
def _record_variable_scope_and_name(embedding_var_name,
|
||||
embedding_var_name_in_fc,
|
||||
is_shared_embedding=False):
|
||||
"""Add embedding variable name and scope to collection."""
|
||||
g = ops.get_default_graph()
|
||||
collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
|
||||
if not collection:
|
||||
collection.append({})
|
||||
|
||||
var_def_dict = collection[0]
|
||||
|
||||
captured_scope = variable_scope.get_variable_scope()
|
||||
captured_scope_name = captured_scope.name
|
||||
|
||||
if embedding_var_name in var_def_dict:
|
||||
if (var_def_dict[embedding_var_name][0] != captured_scope_name
|
||||
and not is_shared_embedding):
|
||||
raise ValueError(
|
||||
'For embedding var name {}, the variable scope name is different, '
|
||||
'got {}; expected {}'.format(embedding_var_name,
|
||||
captured_scope_name,
|
||||
var_def_dict[embedding_var_name][0]))
|
||||
if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
|
||||
raise ValueError(
|
||||
'For embedding var name {}, the embedding name is different, '
|
||||
'got {}; expected {}'.format(embedding_var_name,
|
||||
embedding_var_name_in_fc,
|
||||
var_def_dict[embedding_var_name][1]))
|
||||
else:
|
||||
var_def_dict[embedding_var_name] = (captured_scope_name,
|
||||
embedding_var_name_in_fc)
|
||||
|
||||
|
||||
def _is_running_on_cpu():
|
||||
"""Returns True if the current context is CPU model."""
|
||||
return tpu_function.get_tpu_context().number_of_shards is None
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Tests for contrib.tpu.python.tpu.feature_column."""
|
||||
"""Tests for python.tpu.feature_column."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.feature_column import feature_column as fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as fc_lib
|
||||
@ -31,6 +30,7 @@ from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import feature_column as tpu_fc
|
||||
|
||||
|
||||
def _initialized_session():
|
tensorflow/python/tpu/functional.py (new file, 23 lines)
@ -0,0 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Functional operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name
|
tensorflow/python/tpu/ops/tpu_ops.py (new file, 418 lines)
@ -0,0 +1,418 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Operations for TPUs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
||||
from tensorflow.python.ops import gen_tpu_ops
|
||||
from tensorflow.python.ops.gen_tpu_ops import *
|
||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
||||
|
||||
def _create_default_group_assignment():
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
logging.warning(
|
||||
"cross_replica_sum should be used within a tpu_shard_context, but "
|
||||
"got unset number_of_shards. Assuming 1.")
|
||||
num_shards = 1
|
||||
group_assignment = [list(range(num_shards))]
|
||||
return group_assignment
|
||||
|
||||
def all_to_all(x,
|
||||
concat_dimension,
|
||||
split_dimension,
|
||||
split_count,
|
||||
group_assignment=None,
|
||||
name=None):
|
||||
"""Exchange data across TPU replicas.
|
||||
|
||||
Args:
|
||||
x: The local tensor.
|
||||
concat_dimension: The dimension number to concatenate.
|
||||
split_dimension: The dimension number to split.
|
||||
split_count: The number of splits; this number must equal the sub-group
size (group_assignment.get_shape()[1]).
|
||||
group_assignment: Optional 2d int32 lists with shape [num_groups,
|
||||
num_replicas_per_group]. `group_assignment[i]` represents the replica
|
||||
ids in the ith subgroup.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is concatenated by data from different replicas.
|
||||
"""
|
||||
if group_assignment is None:
|
||||
group_assignment = _create_default_group_assignment()
|
||||
return gen_tpu_ops.all_to_all(
|
||||
x,
|
||||
group_assignment,
|
||||
concat_dimension=concat_dimension,
|
||||
split_dimension=split_dimension,
|
||||
split_count=split_count,
|
||||
name=name)
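A hedged sketch, not part of this change: with two replicas each holding a [2, 4] tensor, splitting along dimension 0 and concatenating along dimension 1 should leave each replica with a [1, 8] slice assembled from both replicas (the shapes here are an assumption for illustration, not taken from the source).

def exchange(x):  # x: per-replica tensor of shape [2, 4], with 2 replicas
  return all_to_all(
      x, concat_dimension=1, split_dimension=0, split_count=2)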
|
||||
|
||||
@ops.RegisterGradient("AllToAll")
|
||||
def _all_to_all_grad(op, grad):
|
||||
# The gradient of an all-to-all is also an all-to-all, but with the
# split_dimension and concat_dimension swapped.
# The gradient with respect to group_assignment is None.
|
||||
return [
|
||||
gen_tpu_ops.all_to_all(
|
||||
grad,
|
||||
op.inputs[1],
|
||||
concat_dimension=op.get_attr("split_dimension"),
|
||||
split_dimension=op.get_attr("concat_dimension"),
|
||||
split_count=op.get_attr("split_count")), None
|
||||
]
|
||||
|
||||
def cross_replica_sum(x, group_assignment=None, name=None):
|
||||
"""Sum the input tensor across replicas according to group_assignment.
|
||||
|
||||
Args:
|
||||
x: The local tensor to sum.
|
||||
group_assignment: Optional 2d int32 lists with shape [num_groups,
|
||||
num_replicas_per_group]. `group_assignment[i]` represents the replica
|
||||
ids in the ith subgroup.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is summed across replicas.
|
||||
"""
|
||||
if group_assignment is None:
|
||||
group_assignment = _create_default_group_assignment()
|
||||
|
||||
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
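A hedged sketch, not part of this change: averaging a per-replica loss over all shards inside a replicated computation; `num_shards` is assumed to come from the surrounding sharding context.

def replica_mean(per_replica_loss, num_shards):
  total = cross_replica_sum(per_replica_loss)
  return total / num_shards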
|
||||
|
||||
def collective_permute(x, source_target_pairs, name=None):
|
||||
"""Permute the input tensor across replicas given source_target_pairs.
|
||||
|
||||
For each source_target_pair <a, b>, we send replica a's input to replica b.
|
||||
Each replica id must only appear once in the source column. Also it must
|
||||
only appear once in the target column.
|
||||
For any replica id not in the target column, this op returns a zero tensor
with the same shape and dtype as the input x.
|
||||
|
||||
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
|
||||
source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
|
||||
`[0, A, B, C]`.
|
||||
|
||||
Args:
|
||||
x: The local tensor to be permuted.
|
||||
source_target_pairs: 2d int lists with shape [num_pairs, 2].
|
||||
source_target_pairs[i][0] represents the source replica id and
|
||||
source_target_pairs[i][1] represents the target replica id.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
A `Tensor` which is permuted.
|
||||
"""
|
||||
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
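A hedged sketch of the docstring's own example, not part of this change: rotating each replica's tensor to the next replica, so replica 0 (absent from the target column) receives zeros.

def rotate(x):  # x: the per-replica tensor to permute
  return collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])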
|
||||
|
||||
@ops.RegisterGradient("CollectivePermute")
|
||||
def _collective_permute_grad(op, grad):
|
||||
# The gradient of a collective permute operation is also a collective
|
||||
# permute, but with source/target pairs reversed. The gradient with respect
|
||||
# to input argument `source_target_pairs` is `None`.
|
||||
source_target_pairs = op.inputs[1][:, ::-1]
|
||||
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
|
||||
|
||||
@ops.RegisterGradient("CrossReplicaSum")
|
||||
def _cross_replica_sum_grad(op, grad):
|
||||
# The gradient of a cross replica sum is also a cross-replica sum.
|
||||
# The gradient with respect to group_assignment is None.
|
||||
return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
|
||||
|
||||
# This extra type checking exists to give a more helpful error message in
|
||||
# the common case that uint8 and int64 values are infed. Remove when both
|
||||
# types are supported.
|
||||
|
||||
_SUPPORTED_INFEED_DTYPES = set([
|
||||
dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
|
||||
dtypes.complex64, dtypes.uint32
|
||||
])
|
||||
|
||||
@ops.RegisterGradient("TPUEmbeddingActivations")
|
||||
def _embedding_activations_grad(activations_op, grad_wrt_activations):
|
||||
"""Saves the gradient of embedding activations ops in a graph collection."""
|
||||
g = ops.get_default_graph()
|
||||
table_id = activations_op.get_attr("table_id")
|
||||
lookup_id = activations_op.get_attr("lookup_id")
|
||||
table_gradients = g.get_collection_ref(
|
||||
"tpu_embedding_gradients_table_%d" % table_id)
|
||||
|
||||
if not table_gradients:
|
||||
raise RuntimeError(
|
||||
"Gradients for TPUEmbedding have been generated in non-training mode."
|
||||
"This is not expected. Consider putting your Optimizer.minimize code "
|
||||
"behind the training mode condition check. For Estimator, you can "
|
||||
"do \n\n"
|
||||
" if mode == tf.estimator.ModeKeys.TRAIN:\n"
|
||||
" train_op = opt.minimize(loss)\n"
|
||||
"\n")
|
||||
|
||||
table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
|
||||
return [
|
||||
# RegisterGradient requires that value be returned for all inputs. Since
|
||||
# the first argument (tpu_gradient_variable_{table_name}) has shape [1],
|
||||
# we will return zeros(shape=[1]). The actual gradient w.r.t. the
|
||||
# embedding activations (grad_wrt_activations) has the same shape as the
|
||||
# activations returned by embedding_activations.
|
||||
array_ops.zeros(arg.shape, dtype=dtypes.float32)
|
||||
for arg in activations_op.inputs
|
||||
]
|
||||
|
||||
def infeed_dequeue(dtype, shape, name=None):
|
||||
"""A placeholder op for a value that will be fed into the computation.
|
||||
|
||||
Args:
|
||||
dtype: A `tf.DType`. The type of elements in the tensor.
|
||||
shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of type `dtype`.
|
||||
A tensor that will be provided using the infeed mechanism.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dtype` is not a supported infeed type.
|
||||
"""
|
||||
if dtype not in _SUPPORTED_INFEED_DTYPES:
|
||||
raise TypeError(
|
||||
"{} is not a supported TPU infeed type. Supported types are: "
|
||||
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
|
||||
|
||||
return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
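A hedged sketch, not part of this change: the device side of an infeed pipeline dequeues one batch of images; the dtype and shape are illustrative, and the matching host-side enqueue is elided.

images = infeed_dequeue(dtype=dtypes.float32, shape=[128, 224, 224, 3])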
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def infeed_dequeue_tuple(dtypes, shapes, name=None):
|
||||
"""A placeholder op for values fed into the TPU simultaneously as a tuple.
|
||||
|
||||
Args:
|
||||
dtypes: A list of `tf.DType`s that has length `>= 1`.
|
||||
The element types of each element in `outputs`.
|
||||
shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
|
||||
The shapes of each tensor in `outputs`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` objects of type `dtypes`.
|
||||
A list of tensors that will be provided using the infeed mechanism.
|
||||
|
||||
Raises:
|
||||
TypeError: If a type in `dtypes` is not a supported infeed type.
|
||||
"""
|
||||
for dtype in dtypes:
|
||||
if dtype not in _SUPPORTED_INFEED_DTYPES:
|
||||
raise TypeError(
|
||||
"{} is not a supported TPU infeed type. Supported types are: "
|
||||
"{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
|
||||
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
|
||||
# pylint: enable=redefined-outer-name
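A hedged sketch, not part of this change: dequeuing a (features, labels) pair in a single op; the dtypes and shapes are illustrative.

features, labels = infeed_dequeue_tuple(
    [dtypes.float32, dtypes.int32], [[128, 64], [128]])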
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def send_tpu_embedding_gradients(inputs,
|
||||
config,
|
||||
learning_rates=None,
|
||||
name=None):
|
||||
"""A placeholder op for feeding per-sample gradients to the embedding layer.
|
||||
|
||||
Args:
|
||||
inputs: A TensorList of gradients with which to update embedding tables.
|
||||
This argument has the same length and shapes as the return value of
|
||||
RecvTPUEmbeddingActivations, but contains gradients of the model's
|
||||
loss with respect to the embedding activations. The embedding tables
|
||||
are updated from these gradients via the optimizers specified in the
|
||||
TPU embedding configuration given to tpu.initialize_system.
|
||||
config: Serialized TPUEmbeddingConfiguration proto.
|
||||
learning_rates: A TensorList of float32 scalars, one for each dynamic
|
||||
learning rate tag: see the comments in
|
||||
//third_party/tensorflow/core/protobuf/tpu/
|
||||
optimization_parameters.proto.
|
||||
Multiple tables can share the same dynamic learning rate tag as
|
||||
specified in the configuration. If the learning rates for all tables
|
||||
are constant, this list should be empty.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A SendTPUEmbeddingGradients operation.
|
||||
"""
|
||||
if learning_rates is None:
|
||||
learning_rates = []
|
||||
return gen_tpu_ops.send_tpu_embedding_gradients(
|
||||
inputs=inputs, learning_rates=learning_rates, config=config, name=name)
|
||||
|
||||
send_tpu_embedding_gradients.__doc__ = (
|
||||
gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_integer_batch(batch,
|
||||
device_ordinal,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
batch: A list of 1D tensors, one for each embedding table, containing the
|
||||
indices into the tables.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingIntegerBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
|
||||
batch=batch,
|
||||
device_ordinal=device_ordinal,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_integer_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_sparse_batch(sample_indices,
|
||||
embedding_indices,
|
||||
aggregation_weights,
|
||||
device_ordinal,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||
and feature to which the corresponding embedding_indices and
|
||||
aggregation_weights values belong. sample_indices[i] must equal b * nf +
|
||||
f, where nf is the number of features from the corresponding table, f is
|
||||
in [0, nf), and b is in [0, batch size).
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables.
|
||||
aggregation_weights: A list of rank 1 Tensors containing per sample --
|
||||
i.e. per (training example, feature) -- aggregation weights.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
|
||||
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
|
||||
is to use 'sum' for all tables (optional).
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingSparseBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
|
||||
sample_indices=sample_indices,
|
||||
embedding_indices=embedding_indices,
|
||||
aggregation_weights=aggregation_weights,
|
||||
device_ordinal=device_ordinal,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_sparse_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
embedding_indices,
|
||||
aggregation_weights,
|
||||
table_ids,
|
||||
device_ordinal,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
"""A placeholder op for enqueueing embedding IDs to the TPU.
|
||||
|
||||
Args:
|
||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||
to which the corresponding embedding_indices and aggregation_weights
|
||||
values belong. It corresponds to sp_ids.indices[:,0] in
|
||||
embedding_lookup_sparse().
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
|
||||
aggregation_weights: A list of rank 1 Tensors containing per training
|
||||
example aggregation weights. It corresponds to sp_weights.values in
|
||||
embedding_lookup_sparse().
|
||||
table_ids: A list of integers specifying the identifier of the embedding
|
||||
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
|
||||
lookup the corresponding input. The ith input is looked up using
|
||||
table_ids[i]. The size of the table_ids list must be equal to that of
|
||||
sample_indices, embedding_indices and aggregation_weights.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
|
||||
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
|
||||
is to use 'sum' for all tables (optional).
|
||||
mode_override: A string input that overrides the mode specified in the
|
||||
TPUEmbeddingConfiguration. Supported values are {'unspecified',
|
||||
'inference', 'training', 'backward_pass_only'}. When set to
|
||||
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
|
||||
otherwise mode_override is used (optional).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
An EnqueueTPUEmbeddingSparseTensorBatch operation.
|
||||
"""
|
||||
if mode_override is None:
|
||||
mode_override = "unspecified"
|
||||
return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
|
||||
sample_indices=sample_indices,
|
||||
embedding_indices=embedding_indices,
|
||||
aggregation_weights=aggregation_weights,
|
||||
table_ids=table_ids,
|
||||
device_ordinal=device_ordinal,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
name=name)
|
||||
|
||||
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
|
||||
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
|
||||
|
||||
else:
|
||||
# We have already built the appropriate libraries into the binary via CMake
|
||||
# if we have built contrib, so we don't need this
|
||||
pass
|
tensorflow/python/tpu/ops/tpu_ordinal_selector_op.py (new file, 20 lines)
@ -0,0 +1,20 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Operations to select TPU core to run."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
tensorflow/python/tpu/profiler/BUILD (new file, 27 lines)
@ -0,0 +1,27 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "profiler",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_profiler_analysis_pb2_grpc",
|
||||
"//tensorflow/core/profiler:profiler_analysis_proto_py",
|
||||
"//tensorflow/core/profiler:protos_all_py",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_profiler_analysis_pb2_grpc",
|
||||
srcs = ["tpu_profiler_analysis_pb2_grpc.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//tensorflow/core/profiler:profiler_analysis_proto_py"],
|
||||
)
|
tensorflow/python/tpu/profiler/__init__.py (new file, 31 lines)
@ -0,0 +1,31 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Classes for TPU trace events."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.core.profiler.trace_events_pb2 import *
|
||||
from tensorflow.core.profiler.profiler_analysis_pb2 import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent']
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
tensorflow/python/tpu/session_support.py (new file, 438 lines)
@ -0,0 +1,438 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Operations for handling session logging and shutdown notifications."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import time
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
_WATCHDOG = None
|
||||
|
||||
|
||||
class CoordinatorShutdownException(Exception):
|
||||
"""Raised when the coordinator needs to shutdown."""
|
||||
pass
|
||||
|
||||
|
||||
def _clone_session(session, graph=None):
|
||||
return session_lib.Session(
|
||||
target=session.sess_str,
|
||||
config=session._config, # pylint: disable=protected-access
|
||||
graph=graph if graph else session.graph)
|
||||
|
||||
|
||||
def _make_heartbeat_op(session, device, request_ph):
|
||||
"""Return a heartbeat op or None if heartbeats are not supported by device."""
|
||||
try:
|
||||
# Test if we can connect in an isolated graph + session
|
||||
with ops.Graph().as_default():
|
||||
with _clone_session(session) as temp_session:
|
||||
with ops.device(device):
|
||||
heartbeat_op = tpu_ops.worker_heartbeat('')
|
||||
options = config_pb2.RunOptions(timeout_in_ms=5000)
|
||||
temp_session.run(heartbeat_op, options=options)
|
||||
except errors.InvalidArgumentError as _:
|
||||
logging.warning('Error running heartbeat on %s', device)
|
||||
return None
|
||||
except errors.DeadlineExceededError as _:
|
||||
logging.warning('Timeout connecting to %s when testing heartbeat', device)
|
||||
return None
|
||||
|
||||
# If we successfully connected and pinged the worker, go ahead and construct
|
||||
# the operation.
|
||||
with ops.device(device):
|
||||
return tpu_ops.worker_heartbeat(request_ph)
|
||||
|
||||
|
||||
class WorkerHeartbeatManager(object):
|
||||
"""Manages the status/heartbeat monitor for a set of workers."""
|
||||
|
||||
def __init__(self, session, devices, heartbeat_ops, request_placeholder):
|
||||
"""Construct a new WorkerHeartbeatManager.
|
||||
|
||||
(Prefer using `WorkerHeartbeatManager.from_devices` when possible.)
|
||||
|
||||
Args:
|
||||
session: `tf.Session`, session to use for heartbeat operations.
|
||||
devices: `list[string]` Set of devices to connect to.
|
||||
heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
|
||||
request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
|
||||
the WorkerHeartbeatRequest protocol buffer.
|
||||
"""
|
||||
self._session = session
|
||||
self._devices = devices
|
||||
self._ops = heartbeat_ops
|
||||
self._request_placeholder = request_placeholder
|
||||
|
||||
@staticmethod
|
||||
def from_devices(session, devices):
|
||||
"""Construct a heartbeat manager for the given devices."""
|
||||
if not devices:
|
||||
logging.error('Trying to create heartbeat manager with no devices?')
|
||||
|
||||
logging.info('Creating heartbeat manager for %s', devices)
|
||||
request_placeholder = array_ops.placeholder(
|
||||
name='worker_heartbeat_request', dtype=dtypes.string)
|
||||
|
||||
heartbeat_ops = []
|
||||
kept_devices = []
|
||||
for device in devices:
|
||||
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
|
||||
if heartbeat_op is not None:
|
||||
kept_devices.append(device)
|
||||
heartbeat_ops.append(heartbeat_op)
|
||||
else:
|
||||
logging.warning('Heartbeat support not available for %s', device)
|
||||
|
||||
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
|
||||
request_placeholder)
|
||||
|
||||
def num_workers(self):
|
||||
return len(self._devices)
|
||||
|
||||
def configure(self, message):
|
||||
"""Configure heartbeat manager for all devices.
|
||||
|
||||
Args:
|
||||
message: `event_pb2.WorkerHeartbeatRequest`
|
||||
Returns: `None`
|
||||
"""
|
||||
logging.info('Configuring worker heartbeat: %s',
|
||||
text_format.MessageToString(message))
|
||||
self._session.run(self._ops,
|
||||
{self._request_placeholder: message.SerializeToString()})
|
||||
|
||||
def ping(self, request=None, timeout_in_ms=5000):
|
||||
"""Ping all workers, returning the parsed status results."""
|
||||
if request is None:
|
||||
request = event_pb2.WorkerHeartbeatRequest()
|
||||
|
||||
options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
|
||||
results = self._session.run(
|
||||
self._ops,
|
||||
feed_dict={self._request_placeholder: request.SerializeToString()},
|
||||
options=options)
|
||||
parsed_results = [
|
||||
event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
|
||||
for res_pb in results
|
||||
]
|
||||
logging.debug('Ping results: %s', parsed_results)
|
||||
return parsed_results
|
||||
|
||||
def lame_workers(self):
|
||||
"""Ping all workers, returning manager containing lame workers (or None)."""
|
||||
ping_results = self.ping()
|
||||
lame_workers = []
|
||||
|
||||
for ping_response, device, op in zip(ping_results, self._devices,
|
||||
self._ops):
|
||||
if ping_response.health_status != event_pb2.OK:
|
||||
lame_workers.append((device, op))
|
||||
|
||||
if not lame_workers:
|
||||
return None
|
||||
|
||||
bad_devices, bad_ops = zip(*lame_workers)
|
||||
return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
|
||||
self._request_placeholder)
|
||||
|
||||
def __repr__(self):
|
||||
return 'HeartbeatManager(%s)' % ','.join(self._devices)
|
||||
|
||||
def shutdown(self, timeout_ms=10000):
|
||||
"""Shutdown all workers after `shutdown_timeout_secs`."""
|
||||
logging.info('Shutting down %s.', self)
|
||||
req = event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
|
||||
self.configure(req)
|
||||
|
||||
# Wait for workers to shutdown. This isn't strictly required
|
||||
# but it avoids triggering multiple checkpoints with the same lame worker.
|
||||
logging.info('Waiting %dms for worker shutdown.', timeout_ms)
|
||||
time.sleep(timeout_ms / 1000)
|
||||
|
||||
|
||||
def all_worker_devices(session):
|
||||
"""Return a list of devices for each worker in the system."""
|
||||
devices = session.list_devices()
|
||||
return [
|
||||
device.name
|
||||
for device in devices
|
||||
if ':CPU:' in device.name and 'coordinator' not in device.name
|
||||
]
|
||||
|
||||
|
||||
class WatchdogManager(threading.Thread):
|
||||
"""Configures worker watchdog timer and handles periodic pings.
|
||||
|
||||
Usage:
|
||||
# Ping workers every minute, shutting down workers if they haven't received
|
||||
# a ping after 1 hour.
|
||||
watchdog_manager = WatchdogManager(
|
||||
ping_interval=60, shutdown_timeout=3600
|
||||
)
|
||||
|
||||
# Use as a context manager, resetting watchdog on context exit:
|
||||
with watchdog_manager:
|
||||
session.run(...)
|
||||
|
||||
# Or setup globally; watchdog will remain active until program exit.
|
||||
watchdog_manager.configure_and_run()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session,
|
||||
devices=None,
|
||||
ping_interval=60,
|
||||
shutdown_timeout=3600):
|
||||
"""Initialize a watchdog manager.
|
||||
|
||||
Args:
|
||||
session: Session connected to worker devices. A cloned session and graph
|
||||
will be created for managing worker pings.
|
||||
devices: Set of devices to monitor. If none, all workers will be
|
||||
monitored.
|
||||
ping_interval: Time, in seconds, between watchdog pings.
|
||||
shutdown_timeout: Time, in seconds, before watchdog timeout.
|
||||
"""
|
||||
threading.Thread.__init__(self)
|
||||
self.ping_interval = ping_interval
|
||||
self.shutdown_timeout = shutdown_timeout
|
||||
self.daemon = True
|
||||
self._config = session._config # pylint: disable=protected-access
|
||||
self._target = session.sess_str
|
||||
self._running = False
|
||||
self._devices = devices
|
||||
|
||||
self._graph = None
|
||||
self._session = None
|
||||
self._worker_manager = None
|
||||
|
||||
def _reset_manager(self):
|
||||
"""Reset the graph, session and worker manager."""
|
||||
self._graph = ops.Graph()
|
||||
self._session = session_lib.Session(
|
||||
target=self._target,
|
||||
graph=self._graph,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
if self._devices is None:
|
||||
self._devices = all_worker_devices(self._session)
|
||||
|
||||
with self._graph.as_default():
|
||||
self._worker_manager = WorkerHeartbeatManager.from_devices(
|
||||
self._session, self._devices)
|
||||
|
||||
self._worker_manager.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(
|
||||
timeout_ms=self.shutdown_timeout * 1000,),
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
|
||||
def configure_and_run(self):
|
||||
logging.info(
|
||||
'Enabling watchdog timer with %d second timeout '
|
||||
'and %d second ping interval.', self.shutdown_timeout,
|
||||
self.ping_interval)
|
||||
self._reset_manager()
|
||||
self._running = True
|
||||
self.start()
|
||||
|
||||
def stop(self):
|
||||
logging.info('Stopping worker watchdog.')
|
||||
self._worker_manager.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
|
||||
shutdown_mode=event_pb2.NOT_CONFIGURED))
|
||||
self._running = False
|
||||
self.join()
|
||||
|
||||
def __enter__(self):
|
||||
self.configure_and_run()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def run(self):
|
||||
# Don't fetch logs or adjust timing: just ping the watchdog.
|
||||
#
|
||||
# If we hit an exception, reset our session as it is likely broken.
|
||||
while self._running:
|
||||
try:
|
||||
self._worker_manager.ping(request=None)
|
||||
time.sleep(self.ping_interval)
|
||||
except errors.OpError as e:
|
||||
# Catch any TF errors that occur so we don't stop sending heartbeats
|
||||
logging.debug('Caught error while sending heartbeat: %s', e)
|
||||
self._reset_manager()
|
||||
|
||||
|
||||
def start_worker_watchdog(session,
|
||||
devices=None,
|
||||
ping_interval=60,
|
||||
shutdown_timeout=3600):
|
||||
"""Start global worker watchdog to shutdown workers on coordinator exit."""
|
||||
global _WATCHDOG
|
||||
if _WATCHDOG is None:
|
||||
# Ensure we can send a few pings before we timeout!
|
||||
ping_interval = min(shutdown_timeout / 10., ping_interval)
|
||||
_WATCHDOG = WatchdogManager(session, devices, ping_interval,
|
||||
shutdown_timeout)
|
||||
_WATCHDOG.configure_and_run()
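A hedged sketch, not part of this change: enabling the global watchdog right after the training session is created, so workers shut themselves down if the coordinator stops pinging (the master address is hypothetical).

sess = session_lib.Session('grpc://tpu-worker:8470')
start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)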
|
||||
|
||||
|
||||
class GracefulShutdownHook(session_run_hook.SessionRunHook):
|
||||
"""Session hook that watches for shutdown events.
|
||||
|
||||
If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a
|
||||
SystemShutdown exception is raised to terminate the main session. If `saver`
|
||||
is None the `SAVERS` collection will be read to find a saver.
|
||||
|
||||
`on_shutdown_hooks` is an optional list of functions that should be called
|
||||
after checkpointing. The function is called with (`run_context`,
|
||||
`all_workers`, `lame_workers`).
|
||||
|
||||
If `heartbeat_group` is not specified, it will default to all CPU workers
|
||||
in the system.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
|
||||
self._saver = saver
|
||||
self._checkpoint_prefix = checkpoint_prefix
|
||||
self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []
|
||||
|
||||
# Worker heartbeats are managed independently of the main training graph.
|
||||
self._graph = ops.Graph()
|
||||
self._workers = None
|
||||
self._session = None
|
||||
self._heartbeat_supported = False
|
||||
|
||||
def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
|
||||
# N.B. We have to pull the global step here to avoid it being unavailable
|
||||
# at checkpoint time; the graph has been frozen at that point.
|
||||
if training_util.get_global_step() is None and self.saver() is not None:
|
||||
raise ValueError(
|
||||
'Saver defined but no global step. Run `get_or_create_global_step()`'
|
||||
' in your model definition to allow checkpointing.')
|
||||
|
||||
with self._graph.as_default():
|
||||
logging.info('Installing graceful shutdown hook.')
|
||||
self._session = _clone_session(training_session, self._graph)
|
||||
self._workers = WorkerHeartbeatManager.from_devices(
|
||||
self._session, all_worker_devices(self._session))
|
||||
self._heartbeat_supported = self._workers.num_workers() > 0
|
||||
if self._heartbeat_supported:
|
||||
self._workers.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
else:
|
||||
logging.warn(
|
||||
'No workers support heartbeats. Failure handling will be disabled.')
|
||||
|
||||
def saver(self):
|
||||
if self._saver:
|
||||
return self._saver
|
||||
|
||||
savers = ops.get_collection(ops.GraphKeys.SAVERS)
|
||||
if not savers:
|
||||
return None
|
||||
|
||||
if not isinstance(savers, list):
|
||||
return savers
|
||||
|
||||
if len(savers) > 1:
|
||||
logging.error(
|
||||
'Multiple savers in the SAVERS collection. On-demand checkpointing '
|
||||
'will be disabled. Pass an explicit `saver` to the constructor to '
|
||||
'override this behavior.')
|
||||
return None
|
||||
|
||||
return savers[0]
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
del run_values
|
||||
|
||||
if not self._heartbeat_supported:
|
||||
return
|
||||
|
||||
lame_workers = self._workers.lame_workers()
|
||||
if lame_workers:
|
||||
logging.info('ShutdownHook: lame workers found: %s', lame_workers)
|
||||
|
||||
if self.saver():
|
||||
logging.info('ShutdownHook: saving checkpoint to %s',
|
||||
self._checkpoint_prefix)
|
||||
self.saver().save(
|
||||
run_context.session,
|
||||
self._checkpoint_prefix,
|
||||
global_step=training_util.get_global_step(),
|
||||
write_state=True,
|
||||
)
|
||||
else:
|
||||
logging.info('ShutdownHook: no Saver defined.')
|
||||
|
||||
for fn in self._on_shutdown_hooks:
|
||||
fn(run_context, self._workers, lame_workers)
|
||||
|
||||
|
||||
class RestartComputation(object):
|
||||
"""Restart the entire computation.
|
||||
|
||||
This hook shuts down all workers and returns control to the top-level by
|
||||
throwing a CoordinatorShutdownException.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout_ms=10000):
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
def __call__(self, run_context, all_workers, lame_workers):
|
||||
del run_context, lame_workers
|
||||
all_workers.shutdown(timeout_ms=self.timeout_ms)
|
||||
|
||||
logging.info('Terminating coordinator.')
|
||||
raise CoordinatorShutdownException()
|
||||
|
||||
|
||||
class ShutdownLameWorkers(object):
|
||||
"""Shutdown lamed workers.
|
||||
|
||||
Processing will continue normally (typically by waiting for the down
|
||||
workers to be restarted).
|
||||
"""
|
||||
|
||||
def __init__(self, timeout_ms=10000):
|
||||
self.timeout_in_ms = timeout_ms
|
||||
|
||||
def __call__(self, run_context, all_workers, lame_workers):
|
||||
lame_workers.shutdown(timeout_ms=self.timeout_in_ms)
|
tensorflow/python/tpu/tensor_tracer.py (new file, 1638 lines)
File diff suppressed because it is too large
tensorflow/python/tpu/topology.py (new file, 220 lines)
@ -0,0 +1,220 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
"""Defines the `Topology` class, that describes a TPU fabric topology."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.protobuf.tpu import topology_pb2
|
||||
|
||||
|
||||
def _tpu_device_name(job, task, device):
|
||||
"""Returns the device name for the TPU `device` on `task` of `job`."""
|
||||
if job is None:
|
||||
return "/task:%d/device:TPU:%d" % (task, device)
|
||||
else:
|
||||
return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
|
||||
|
||||
|
||||
def _tpu_host_device_name(job, task):
|
||||
"""Returns the device name for the CPU device on `task` of `job`."""
|
||||
if job is None:
|
||||
return "/task:%d/device:CPU:0" % task
|
||||
else:
|
||||
return "/job:%s/task:%d/device:CPU:0" % (job, task)
|
||||
|
||||
|
||||
class Topology(object):
|
||||
"""Describes a set of TPU devices.
|
||||
|
||||
Represents both the shape of the physical mesh, and the mapping from
TensorFlow TPU devices to physical mesh coordinates.
|
||||
"""
|
||||
|
||||
def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
|
||||
"""Builds a Topology object.
|
||||
|
||||
If `serialized` is not `None`, the topology is parsed from `serialized` and
|
||||
the other arguments are ignored. Otherwise, the topology is computed from
|
||||
`mesh_shape` and `device_coordinates`.
|
||||
|
||||
Args:
|
||||
serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
|
||||
serialized proto is parsed to discover the topology.
|
||||
mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
|
||||
the shape of the TPU topology, in number of cores. Ignored if
|
||||
`serialized` is not `None`.
|
||||
device_coordinates: A rank 3 numpy array that describes the mapping from
|
||||
TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
|
||||
if `serialized` is not `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `serialized` does not describe a well-formed topology.
|
||||
ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
|
||||
of 3 positive integers.
|
||||
ValueError: If `serialized` is `None` and `device_coordinates` is not a
|
||||
rank 3 numpy int32 array that describes a valid coordinate mapping.
|
||||
"""
|
||||
|
||||
self._serialized = serialized
|
||||
|
||||
if serialized:
|
||||
self._parse_topology(serialized)
|
||||
else:
|
||||
self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
|
||||
self._device_coordinates = np.asarray(device_coordinates, np.int32)
|
||||
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
|
||||
raise ValueError("`mesh_shape` must be a sequence of 3 positive "
|
||||
"entries; got {}".format(self._mesh_shape))
|
||||
|
||||
if (len(self._device_coordinates.shape) != 3 or
|
||||
self._device_coordinates.shape[2] != len(self._mesh_shape)):
|
||||
raise ValueError("`device_coordinates` must be a rank 3 int32 array "
|
||||
"with minor dimension equal to the mesh shape rank")
|
||||
|
||||
self._topology_tasks, self._topology_devices = self._invert_topology()
|
||||
|
||||
def _parse_topology(self, serialized):
|
||||
"""Parses a serialized `TopologyProto` into `self`."""
|
||||
proto = topology_pb2.TopologyProto()
|
||||
proto.ParseFromString(serialized)
|
||||
|
||||
self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
|
||||
if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1):
|
||||
raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
|
||||
"entries; got {}".format(self._mesh_shape))
|
||||
|
||||
if proto.num_tasks < 0:
|
||||
raise ValueError("`num_tasks` must be >= 0; got {}".format(
|
||||
proto.num_tasks))
|
||||
if proto.num_tpu_devices_per_task < 0:
|
||||
raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
|
||||
proto.num_tpu_devices_per_task))
|
||||
|
||||
expected_coordinates_size = (
|
||||
proto.num_tasks * proto.num_tpu_devices_per_task * len(
|
||||
proto.mesh_shape))
|
||||
if len(proto.device_coordinates) != expected_coordinates_size:
|
||||
raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
|
||||
"num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
|
||||
"got shape {}".format(proto.num_tasks,
|
||||
proto.num_tpu_devices_per_task,
|
||||
proto.mesh_shape,
|
||||
len(proto.device_coordinates)))
|
||||
|
||||
coords = np.array(proto.device_coordinates, dtype=np.int32)
|
||||
if any(coords < 0):
|
||||
raise ValueError("`device_coordinates` must be >= 0")
|
||||
coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
|
||||
len(proto.mesh_shape)))
|
||||
self._device_coordinates = coords
|
||||
|
||||
def _invert_topology(self):
|
||||
"""Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps."""
|
||||
tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
|
||||
devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
|
||||
for task in xrange(self.device_coordinates.shape[0]):
|
||||
for device in xrange(self.device_coordinates.shape[1]):
|
||||
x, y, z = self.device_coordinates[task, device, :]
|
||||
tasks[x, y, z] = task
|
||||
devices[x, y, z] = device
|
||||
return tasks, devices
|
||||
|
||||
@property
|
||||
def mesh_shape(self):
|
||||
"""A rank 1 int32 array describing the shape of the TPU topology."""
|
||||
return self._mesh_shape
|
||||
|
||||
@property
|
||||
def mesh_rank(self):
|
||||
"""Returns the number of dimensions in the mesh."""
|
||||
return len(self._mesh_shape)
|
||||
|
||||
@property
|
||||
def device_coordinates(self):
|
||||
"""Describes the mapping from TPU devices to topology coordinates.
|
||||
|
||||
Returns:
|
||||
A rank 3 int32 array with shape `[tasks, devices, axis]`.
|
||||
`tasks` is the number of tasks in the TPU cluster, `devices` is the number
|
||||
of TPU devices per task, and `axis` is the number of axes in the TPU
|
||||
cluster topology. Each entry gives the `axis`-th coordinate in the
|
||||
topology of a task/device pair. TPU topologies are 3-dimensional, with
|
||||
dimensions `(x, y, core number)`.
|
||||
"""
|
||||
return self._device_coordinates
|
||||
|
||||
def task_ordinal_at_coordinates(self, device_coordinates):
|
||||
"""Returns the TensorFlow task number attached to `device_coordinates`.
|
||||
|
||||
Args:
|
||||
device_coordinates: An integer sequence describing a device's physical
|
||||
coordinates in the TPU fabric.
|
||||
|
||||
Returns:
|
||||
Returns the TensorFlow task number that contains the TPU device with those
|
||||
physical coordinates.
|
||||
"""
|
||||
return self._topology_tasks[tuple(device_coordinates)]
|
||||
|
||||
def tpu_device_ordinal_at_coordinates(self, device_coordinates):
|
||||
"""Returns the TensorFlow device number at `device_coordinates`.
|
||||
|
||||
Args:
|
||||
device_coordinates: An integer sequence describing a device's physical
|
||||
coordinates in the TPU fabric.
|
||||
|
||||
Returns:
|
||||
Returns the TensorFlow device number within the task that is attached
to the device with those physical coordinates.
|
||||
"""
|
||||
return self._topology_devices[tuple(device_coordinates)]
|
||||
|
||||
def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
|
||||
"""Returns the CPU device attached to a logical core."""
|
||||
return _tpu_host_device_name(
|
||||
job, self._topology_tasks[tuple(device_coordinates)])
|
||||
|
||||
def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
|
||||
"""Returns the name of the TPU device assigned to a logical core."""
|
||||
return _tpu_device_name(job,
|
||||
self._topology_tasks[tuple(device_coordinates)],
|
||||
self._topology_devices[tuple(device_coordinates)])
|
||||
|
||||
@property
|
||||
def num_tasks(self):
|
||||
"""Returns the number of TensorFlow tasks in the TPU slice."""
|
||||
return self._device_coordinates.shape[0]
|
||||
|
||||
@property
|
||||
def num_tpus_per_task(self):
|
||||
"""Returns the number of TPU devices per task in the TPU slice."""
|
||||
return self._device_coordinates.shape[1]
|
||||
|
||||
def serialized(self):
|
||||
"""Returns the serialized form of the topology."""
|
||||
if self._serialized is None:
|
||||
proto = topology_pb2.TopologyProto()
|
||||
proto.mesh_shape[:] = list(self._mesh_shape)
|
||||
proto.num_tasks = self._device_coordinates.shape[0]
|
||||
proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
|
||||
proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
|
||||
self._serialized = proto.SerializeToString()
|
||||
|
||||
return self._serialized
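A minimal usage sketch (not part of the original file; the mesh shape and coordinates are made-up values): building a Topology for a hypothetical single-task slice with two cores and querying the inverse maps. It reuses the numpy import already present in this module.

# One task, two TPU devices, three coordinate axes (x, y, core number).
coords = np.array([[[0, 0, 0], [0, 0, 1]]], dtype=np.int32)
topo = Topology(mesh_shape=[1, 1, 2], device_coordinates=coords)

topo.num_tasks                                      # 1
topo.num_tpus_per_task                              # 2
topo.task_ordinal_at_coordinates([0, 0, 1])         # 0
topo.tpu_device_ordinal_at_coordinates([0, 0, 1])   # 1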
|
@ -19,9 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import topology
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import topology
|
||||
|
||||
|
||||
class TopologyTest(test.TestCase):
|
1576 tensorflow/python/tpu/tpu.py (new file; diff suppressed because it is too large)
276 tensorflow/python/tpu/tpu_config.py (new file)
@ -0,0 +1,276 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""A RunConfig subclass with TPU support."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import util as util_lib
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
|
||||
_SERVICE_KEY = run_config_lib._SERVICE_KEY
|
||||
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class InputPipelineConfig(object):
|
||||
r"""Please see the definition of these values in TPUConfig."""
|
||||
PER_SHARD_V1 = 1
|
||||
PER_HOST_V1 = 2
|
||||
PER_HOST_V2 = 3
|
||||
BROADCAST = 4
|
||||
|
||||
|
||||
class TPUConfig(
|
||||
collections.namedtuple('TPUConfig', [
|
||||
'iterations_per_loop',
|
||||
'num_shards',
|
||||
'num_cores_per_replica',
|
||||
'per_host_input_for_training',
|
||||
'tpu_job_name',
|
||||
'initial_infeed_sleep_secs',
|
||||
'input_partition_dims',
|
||||
])):
|
||||
r"""TPU related configuration required by `TPUEstimator`.
|
||||
|
||||
Args:
|
||||
iterations_per_loop: This is the number of train steps running in TPU
|
||||
system before returning to CPU host for each `Session.run`. This means
|
||||
global step is increased `iterations_per_loop` times in one `Session.run`.
|
||||
It is recommended to set this to the number of global steps until the next checkpoint.
|
||||
num_shards: (Deprecated, ignored by TPUEstimator).
|
||||
The number of model replicas in the system. For non-model-parallelism
|
||||
case, this number equals the total number of TPU cores. For
|
||||
model-parallelism, the total number of TPU cores equals
|
||||
num_cores_per_replica * num_shards.
|
||||
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
|
||||
An integer which describes the number of TPU cores per model replica. This
|
||||
is required by model-parallelism which enables partitioning
|
||||
the model to multiple cores. Currently num_cores_per_replica must be
|
||||
1, 2, 4, 8, or 16.
|
||||
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
|
||||
`input_fn` is invoked once on each host. With the per-core input pipeline
|
||||
configuration, it is invoked once for each core.
|
||||
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
|
||||
the batch size for each shard is `train_batch_size` // #hosts in the
|
||||
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
|
||||
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
|
||||
invoked once on host 0 and the tensors are broadcasted to all other
|
||||
replicas. The batch size equals `train_batch_size`. With the per-core
|
||||
input pipeline configuration, the shard batch size is also
|
||||
`train_batch_size` // #cores.
|
||||
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
|
||||
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
|
||||
within TPUEstimator, however when using ClusterSpec propagation in more
|
||||
esoteric cluster configurations, you may need to specify the job name as a
|
||||
string.
|
||||
initial_infeed_sleep_secs: The number of seconds the infeed thread should
|
||||
wait before enqueueing the first batch. This helps avoid timeouts for
|
||||
models that require a long compilation time.
|
||||
input_partition_dims: A nested list to describe the partition dims
|
||||
for all the tensors from input_fn(). The structure of
|
||||
input_partition_dims must match the structure of `features` and
|
||||
`labels` from input_fn(). The total number of partitions must match
|
||||
`num_cores_per_replica`. For example, if input_fn() returns two tensors:
|
||||
images with shape [N, H, W, C] and labels [N].
|
||||
input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4
|
||||
pieces and feed into 4 TPU cores. labels tensor are directly broadcasted
|
||||
to all the TPU cores since the partition dims is `None`.
|
||||
Current limitations: This feature is only supported with the PER_HOST_V2
|
||||
input mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
iterations_per_loop=2,
|
||||
num_shards=None,
|
||||
num_cores_per_replica=None,
|
||||
per_host_input_for_training=True,
|
||||
tpu_job_name=None,
|
||||
initial_infeed_sleep_secs=None,
|
||||
input_partition_dims=None):
|
||||
|
||||
# Check iterations_per_loop.
|
||||
util_lib.check_positive_integer(iterations_per_loop,
|
||||
'TPUConfig iterations_per_loop')
|
||||
|
||||
# Check num_shards.
|
||||
if num_shards is not None:
|
||||
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
|
||||
|
||||
if input_partition_dims is not None:
|
||||
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
|
||||
raise ValueError(
|
||||
'input_partition_dims must be a list/tuple with one or two'
|
||||
' elements.')
|
||||
|
||||
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
|
||||
raise ValueError(
|
||||
'input_partition_dims is only supported in PER_HOST_V2 mode.')
|
||||
|
||||
if num_cores_per_replica is None:
|
||||
raise ValueError(
|
||||
'input_partition_dims requires setting num_cores_per_replica.')
|
||||
|
||||
# Check num_cores_per_replica
|
||||
if num_cores_per_replica is not None:
|
||||
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
|
||||
raise ValueError(
|
||||
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
|
||||
str(num_cores_per_replica)))
|
||||
|
||||
# per_host_input_for_training may be True, False, or integer in [1..3].
|
||||
# Map legacy values (True, False) to numeric values.
|
||||
if per_host_input_for_training is False:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
|
||||
elif per_host_input_for_training is True:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
|
||||
|
||||
# Check initial_infeed_sleep_secs.
|
||||
if initial_infeed_sleep_secs:
|
||||
util_lib.check_positive_integer(initial_infeed_sleep_secs,
|
||||
'TPUConfig initial_infeed_sleep_secs')
|
||||
|
||||
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
|
||||
|
||||
return super(TPUConfig, cls).__new__(
|
||||
cls,
|
||||
iterations_per_loop=iterations_per_loop,
|
||||
num_shards=num_shards,
|
||||
num_cores_per_replica=num_cores_per_replica,
|
||||
per_host_input_for_training=per_host_input_for_training,
|
||||
tpu_job_name=tpu_job_name,
|
||||
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
|
||||
input_partition_dims=input_partition_dims)
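A hedged configuration sketch (not part of the original file; all values are illustrative): a TPUConfig enabling 4-way model parallelism with the image/label partitioning described in the docstring above, using the class and InputPipelineConfig defined in this module.

config = TPUConfig(
    iterations_per_loop=100,
    num_cores_per_replica=4,
    # input_partition_dims is only accepted together with PER_HOST_V2.
    per_host_input_for_training=InputPipelineConfig.PER_HOST_V2,
    # Split [N, H, W, C] images across 4 cores; broadcast the labels.
    input_partition_dims=[[1, 2, 2, 1], None])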
|
||||
|
||||
|
||||
class RunConfig(run_config_lib.RunConfig):
|
||||
"""RunConfig with TPU support."""
|
||||
|
||||
def __init__(self,
|
||||
tpu_config=None,
|
||||
evaluation_master=None,
|
||||
master=None,
|
||||
cluster=None,
|
||||
**kwargs):
|
||||
"""Constructs a RunConfig.
|
||||
|
||||
Args:
|
||||
tpu_config: the TPUConfig that specifies TPU-specific configuration.
|
||||
evaluation_master: a string. The address of the master to use for eval.
|
||||
Defaults to master if not set.
|
||||
master: a string. The address of the master to use for training.
|
||||
cluster: a ClusterResolver
|
||||
**kwargs: keyword config parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: if cluster is not None and the provided session_config has a
|
||||
cluster_def already.
|
||||
"""
|
||||
super(RunConfig, self).__init__(**kwargs)
|
||||
self._tpu_config = tpu_config or TPUConfig()
|
||||
self._cluster = cluster
|
||||
|
||||
# If user sets master and/or evaluation_master explicitly, including empty
|
||||
# string '', take it. Otherwise, take the values set by parent class.
|
||||
if master is not None:
|
||||
if cluster is not None:
|
||||
raise ValueError('Both master and cluster are set.')
|
||||
self._master = master
|
||||
else:
|
||||
if cluster:
|
||||
self._master = cluster.master()
|
||||
|
||||
if evaluation_master is not None:
|
||||
self._evaluation_master = evaluation_master
|
||||
elif (not self._evaluation_master and
|
||||
self.task_type != run_config_lib.TaskType.EVALUATOR):
|
||||
# If the task type is EVALUATOR, it means some cluster manager sets the
|
||||
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
|
||||
#
|
||||
# Otherwise, it means user executes the code without external cluster
|
||||
# manager. For that, we optimize the user experience by setting
|
||||
# evaluation_master to master, unless user overwrites it.
|
||||
self._evaluation_master = self._master
|
||||
|
||||
# Set the ClusterSpec to use
|
||||
if cluster:
|
||||
self._cluster_spec = cluster.cluster_spec()
|
||||
|
||||
# Merge the cluster_def into the ConfigProto.
|
||||
if self._session_config is None: # pylint: disable=access-member-before-definition
|
||||
self._session_config = config_pb2.ConfigProto(
|
||||
allow_soft_placement=True, isolate_session_state=True)
|
||||
if self._session_config.HasField('cluster_def'):
|
||||
raise ValueError(
|
||||
'You cannot provide a ClusterResolver and '
|
||||
'session_config.cluster_def.')
|
||||
if self._cluster_spec:
|
||||
self._session_config.cluster_def.CopyFrom(
|
||||
self._cluster_spec.as_cluster_def())
|
||||
|
||||
def _maybe_overwrite_session_config_for_distributed_training(self):
|
||||
# Overrides the parent class session_config overwrite for between-graph. TPU
|
||||
# runs with in-graph, which should not have device filter. Doing nothing
|
||||
# ("pass") basically disables it.
|
||||
pass
|
||||
|
||||
@property
|
||||
def evaluation_master(self):
|
||||
return self._evaluation_master
|
||||
|
||||
@property
|
||||
def master(self):
|
||||
return self._master
|
||||
|
||||
@property
|
||||
def tpu_config(self):
|
||||
return self._tpu_config
|
||||
|
||||
@property
|
||||
def cluster(self):
|
||||
return self._cluster
|
||||
|
||||
def replace(self, **kwargs):
|
||||
if 'tpu_config' not in kwargs:
|
||||
return super(RunConfig, self).replace(**kwargs)
|
||||
|
||||
tpu_config = kwargs.pop('tpu_config')
|
||||
new_instance = super(RunConfig, self).replace(**kwargs)
|
||||
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
|
||||
return new_instance
|
||||
|
||||
|
||||
def _get_tpu_job_name_from_tf_config():
|
||||
"""Extracts the TPU job name from TF_CONFIG env variable."""
|
||||
# TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster
|
||||
# spec propagation.
|
||||
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
|
||||
if tpu_job_name:
|
||||
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
|
||||
return tpu_job_name
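A hedged sketch (not part of the original file) of the TF_CONFIG shape this helper reads; the job name is a made-up value, and json/os are already imported in this module.

os.environ[_TF_CONFIG_ENV] = json.dumps(
    {_SERVICE_KEY: {_TPU_WORKER_JOB_NAME: 'my_tpu_worker'}})
# A TPUConfig constructed afterwards with tpu_job_name=None would then pick up
# 'my_tpu_worker' via _get_tpu_job_name_from_tf_config().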
|
@ -20,10 +20,10 @@ from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_config as tpu_config_lib
|
||||
|
||||
|
||||
def _set_tf_config_env_variable(tf_config):
|
763 tensorflow/python/tpu/tpu_context.py (new file)
@ -0,0 +1,763 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""TPU system metadata and associated tooling."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import _tpu_estimator_embedding
|
||||
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
|
||||
from tensorflow.python.tpu import tpu_config
|
||||
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
|
||||
|
||||
|
||||
_DEFAULT_JOB_NAME = 'tpu_worker'
|
||||
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
|
||||
_LOCAL_MASTERS = ('', 'local')
|
||||
_NUM_CORES_TO_COMPUTATION_SHAPE = {
|
||||
1: [1, 1, 1],
|
||||
2: [1, 1, 2],
|
||||
4: [1, 2, 2],
|
||||
8: [2, 2, 2],
|
||||
16: [4, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
class TPUContext(object):
|
||||
"""A context that holds the current configuration of the TPU computation."""
|
||||
|
||||
def __init__(self,
|
||||
internal_ctx,
|
||||
input_device=None,
|
||||
invocation_index=None,
|
||||
call_from_input_fn=True):
|
||||
self._internal_ctx = internal_ctx
|
||||
self._input_device = input_device
|
||||
self._invocation_index = invocation_index
|
||||
self._call_from_input_fn = call_from_input_fn
|
||||
|
||||
def current_input_fn_deployment(self):
|
||||
"""The configuration of the current input_fn invocation.
|
||||
|
||||
The configuration depends on `TPUConfig.per_host_input_for_training`. See
|
||||
`TPUConfig` for details.
|
||||
|
||||
This is only set in the params dict of input_fn.
|
||||
|
||||
Returns:
|
||||
A tuple of
|
||||
1. Device spec string: String, is the current CPU host where the
|
||||
input_fn is invoked.
|
||||
2. Current invocation index: Int, 0-based index of the input_fn
|
||||
invocation. See next item for details.
|
||||
3. Total invocation count: Int, the total number of times to invoke the
|
||||
input_fn on all CPU hosts. Each invocation will be passed with a new
|
||||
`TPUContext` instance with current invocation index set properly.
|
||||
4. Total number of replicas consumed by current_invocation: Int, the
|
||||
number of replicas fed by the data returned by current input_fn. For
|
||||
example, for per_core input pipeline deployment
|
||||
and non-model-parallelism, total invocation count is equal to
|
||||
the number of cores in the system and num replicas consumed by
|
||||
current invocation is 1. For per-host v2 input pipeline deployment,
|
||||
total invocation count is equal to the number of hosts in the system
|
||||
and num replicas consumed by current invocation is equal to number of
|
||||
cores per host.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If this method must not be called from input_fn.
|
||||
"""
|
||||
if not self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' model_fn.')
|
||||
|
||||
if self._internal_ctx.is_input_sharded_per_core():
|
||||
total_invocation_count = (self._internal_ctx.num_hosts
|
||||
* self._internal_ctx.num_of_replicas_per_host)
|
||||
replicas_consumed = 1
|
||||
elif self._internal_ctx.is_input_broadcast_with_iterators():
|
||||
total_invocation_count = 1
|
||||
replicas_consumed = self._internal_ctx.num_replicas
|
||||
else:
|
||||
total_invocation_count = self._internal_ctx.num_hosts
|
||||
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
|
||||
return (self._input_device, self._invocation_index,
|
||||
total_invocation_count, replicas_consumed)
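A minimal sketch (not part of the original file) of how an input_fn might consume this tuple to shard its dataset per invocation. It assumes TPUEstimator passes this TPUContext under params['context'] and uses a hypothetical file pattern.

import tensorflow as tf

def input_fn(params):
  ctx = params['context']  # assumed key, provided by TPUEstimator
  _, invocation_index, total_invocations, _ = ctx.current_input_fn_deployment()
  dataset = tf.data.Dataset.list_files('/tmp/data-*', shuffle=False)
  # Give each invocation a disjoint slice of the input files.
  dataset = dataset.shard(total_invocations, invocation_index)
  # ... parse, shuffle, and batch the files here ...
  return dataset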
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
"""The total number of replicas.
|
||||
|
||||
For non-model-parallelism, num_replicas should be the total num of TPU
|
||||
cores in the system.
|
||||
|
||||
Returns:
|
||||
The number of replicas.
|
||||
"""
|
||||
return self._internal_ctx.num_replicas
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
"""The number of hosts for the TPU system."""
|
||||
return self._internal_ctx.num_hosts
|
||||
|
||||
@property
|
||||
def current_host(self):
|
||||
"""The current host index for the TPU system."""
|
||||
return self._invocation_index
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""The number of replicas for each host."""
|
||||
if self._internal_ctx.model_parallelism_enabled:
|
||||
raise ValueError(
|
||||
'num_of_replicas_per_host is not supported for model_parallelism')
|
||||
return self._internal_ctx.num_of_replicas_per_host
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
"""Returns device_assignment object."""
|
||||
if self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' input_fn.')
|
||||
return self._internal_ctx.device_assignment
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replicate for non-model-parallelism.
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
# Note that: For the non-model parallelism, the mapping could be
|
||||
# a random permutation. The order should not matter in most cases
|
||||
# as long as the model is replicated to all cores in the system.
|
||||
return self._internal_ctx.device_for_replica(replica_id)
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function.
|
||||
|
||||
The place function takes host_id as the input and returns the TF device
|
||||
for the corresponding host.
|
||||
"""
|
||||
|
||||
def _placement_function(host_id):
|
||||
"""Return the host device given host_id."""
|
||||
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
|
||||
class _InternalTPUContext(object):
|
||||
"""A context holds immutable states of TPU computation.
|
||||
|
||||
This immutable object holds TPUEstimator config, train/eval batch size, and
|
||||
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
|
||||
provides utility functions, based on the current state, to determine other
|
||||
information commonly required by TPU computation, such as TPU device names,
|
||||
TPU hosts, shard batch size, etc.
|
||||
|
||||
if eval_on_tpu is False, then execution of eval on TPU is disabled.
|
||||
if eval_on_tpu is True, but use_tpu is False, a warning is issued,
|
||||
and TPU execution is disabled for all modes.
|
||||
|
||||
N.B. As `mode` is not immutable state in Estimator, but essential to
|
||||
distinguish between TPU training and evaluation, a common usage for
|
||||
_InternalTPUContext with `mode` is as follows:
|
||||
```
|
||||
with _ctx.with_mode(mode) as ctx:
|
||||
if ctx.is_running_on_cpu():
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
train_batch_size,
|
||||
eval_batch_size,
|
||||
predict_batch_size,
|
||||
use_tpu,
|
||||
eval_on_tpu=True,
|
||||
embedding_config_spec=None):
|
||||
self._config = config
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._predict_batch_size = predict_batch_size
|
||||
self._use_tpu = use_tpu
|
||||
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
|
||||
if not use_tpu and eval_on_tpu:
|
||||
logging.warning('eval_on_tpu ignored because use_tpu is False.')
|
||||
|
||||
self._eval_on_tpu = eval_on_tpu
|
||||
self._model_parallelism_enabled = (
|
||||
use_tpu and config.tpu_config.num_cores_per_replica)
|
||||
self._mode = None
|
||||
num_cores_per_replica = config.tpu_config.num_cores_per_replica
|
||||
if self._model_parallelism_enabled:
|
||||
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
|
||||
num_cores_per_replica]
|
||||
else:
|
||||
self._computation_shape = None
|
||||
self._lazy_tpu_system_metadata_dict = {} # key by master address
|
||||
self._lazy_device_assignment_dict = {} # key by master address
|
||||
self._lazy_validation_dict = {} # key by ModeKeys
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._lazy_embedding_config_dict = {} # key by master address
|
||||
|
||||
def _assert_mode(self):
|
||||
if self._mode is None:
|
||||
raise RuntimeError(
|
||||
'`mode` needs to be set via contextmanager `with_mode`.')
|
||||
return self._mode
|
||||
|
||||
@contextmanager
|
||||
def with_mode(self, mode):
|
||||
# NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries,
|
||||
# such as _lazy_tpu_system_metadata_dict between new copy and the original
|
||||
# one. Note that all lazy states stored in properties _lazy_foo are sort of
|
||||
# immutable as they should be same for the process lifetime.
|
||||
new_ctx = copy.copy(self)
|
||||
new_ctx._mode = mode # pylint: disable=protected-access
|
||||
yield new_ctx
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._assert_mode()
|
||||
|
||||
def _get_master_address(self):
|
||||
mode = self._assert_mode()
|
||||
config = self._config
|
||||
master = (
|
||||
config.master
|
||||
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
|
||||
return master
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
cluster_def = None
|
||||
if (self._config.session_config and
|
||||
self._config.session_config.cluster_def.job):
|
||||
cluster_def = self._config.session_config.cluster_def
|
||||
|
||||
# pylint: disable=protected-access
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._query_tpu_system_metadata(
|
||||
master,
|
||||
cluster_def=cluster_def,
|
||||
query_topology=self.model_parallelism_enabled))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
def _get_device_assignment(self):
|
||||
"""Gets the (maybe cached) TPU device assignment."""
|
||||
master = self._get_master_address()
|
||||
device_assignment = self._lazy_device_assignment_dict.get(master)
|
||||
if device_assignment is not None:
|
||||
return device_assignment
|
||||
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
|
||||
device_assignment = tpu_device_assignment.device_assignment(
|
||||
tpu_system_metadata.topology,
|
||||
computation_shape=self._computation_shape,
|
||||
num_replicas=self.num_replicas)
|
||||
|
||||
logging.info('num_cores_per_replica: %s',
|
||||
str(self._config.tpu_config.num_cores_per_replica))
|
||||
logging.info('computation_shape: %s', str(self._computation_shape))
|
||||
logging.info('num_replicas: %d', self.num_replicas)
|
||||
logging.info('device_assignment.topology.device_coordinates: %s',
|
||||
str(device_assignment.topology.device_coordinates))
|
||||
logging.info('device_assignment.core_assignment: %s',
|
||||
str(device_assignment.core_assignment))
|
||||
|
||||
self._lazy_device_assignment_dict[master] = device_assignment
|
||||
return device_assignment
|
||||
|
||||
@property
|
||||
def embedding_config(self):
|
||||
"""Returns the embedding config based on current mode."""
|
||||
master = self._get_master_address()
|
||||
if master in self._lazy_embedding_config_dict:
|
||||
embedding_config = self._lazy_embedding_config_dict[master]
|
||||
else:
|
||||
embedding_config = None
|
||||
if self._use_tpu and self._embedding_config_spec:
|
||||
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
|
||||
self._embedding_config_spec, self._train_batch_size,
|
||||
self._eval_batch_size, self.num_hosts, self.num_cores, master)
|
||||
if not embedding_config.has_embedding_tables():
|
||||
embedding_config = None
|
||||
self._lazy_embedding_config_dict[master] = embedding_config
|
||||
|
||||
if embedding_config is not None:
|
||||
mode = self._assert_mode()
|
||||
# Dynamically attach tpu_embedding based on mode. With
|
||||
# this, we could keep embedding_config immutable but call site always
|
||||
# accesses the unified API '.tpu_embedding'.
|
||||
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
|
||||
return embedding_config
|
||||
|
||||
@property
|
||||
def model_parallelism_enabled(self):
|
||||
return self._model_parallelism_enabled
|
||||
|
||||
@property
|
||||
def input_partition_dims(self):
|
||||
return self._config.tpu_config.input_partition_dims
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
return (self._get_device_assignment()
|
||||
if self._model_parallelism_enabled else None)
|
||||
|
||||
@property
|
||||
def num_of_cores_per_host(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_cores(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_cores
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""Return the number of replicas per host."""
|
||||
if self.model_parallelism_enabled:
|
||||
return self.num_replicas // self.num_hosts
|
||||
else:
|
||||
return self.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
num_cores_in_system = self.num_cores
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
if num_cores_per_replica > num_cores_in_system:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the total num of '
|
||||
'TPU cores in the system. num_cores_per_replica: {}, num cores '
|
||||
'in the system: {}'.format(num_cores_per_replica,
|
||||
num_cores_in_system))
|
||||
|
||||
if num_cores_in_system % num_cores_per_replica != 0:
|
||||
raise RuntimeError(
|
||||
'The num of cores in the system ({}) is not divisible by the num '
|
||||
'of cores ({}) required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
|
||||
num_cores_in_system, num_cores_per_replica))
|
||||
|
||||
return num_cores_in_system // num_cores_per_replica
|
||||
else:
|
||||
return num_cores_in_system
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_hosts
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
def is_input_sharded_per_core(self):
|
||||
"""Return true if input_fn is invoked per-core (other than per-host)."""
|
||||
mode = self._assert_mode()
|
||||
return (mode == model_fn_lib.ModeKeys.TRAIN and
|
||||
(self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_SHARD_V1))
|
||||
|
||||
def is_input_per_host_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the per-host v2 config."""
|
||||
return (self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_HOST_V2)
|
||||
|
||||
def is_input_broadcast_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the full_replicae config."""
|
||||
return (self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.BROADCAST)
|
||||
|
||||
def is_running_on_cpu(self, is_export_mode=False):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU.
|
||||
|
||||
This API also validates user provided configuration, such as batch size,
|
||||
according the lazy initialized TPU system metadata.
|
||||
|
||||
Args:
|
||||
is_export_mode: Indicates whether the current mode is for exporting the
|
||||
model, when mode == PREDICT. Only with this bool can we tell whether
the user is calling Estimator.predict or Estimator.export_savedmodel,
which run on TPU and CPU respectively. The parent class Estimator does
not distinguish these two.
|
||||
|
||||
Returns:
|
||||
bool, whether current input_fn or model_fn should be running on CPU.
|
||||
|
||||
Raises:
|
||||
ValueError: any configuration is invalid.
|
||||
"""
|
||||
|
||||
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
|
||||
if not is_running_on_cpu:
|
||||
self._validate_tpu_configuration()
|
||||
return is_running_on_cpu
|
||||
|
||||
def _is_running_on_cpu(self, is_export_mode):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
|
||||
mode = self._assert_mode()
|
||||
|
||||
if not self._use_tpu:
|
||||
return True
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
|
||||
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
||||
return True
|
||||
|
||||
if is_export_mode:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def global_batch_size(self):
|
||||
mode = self._assert_mode()
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
return self._train_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return self._eval_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return self._predict_batch_size
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def batch_size_for_input_fn(self):
|
||||
"""Returns the shard batch size for `input_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU
|
||||
if self.is_input_sharded_per_core() or (
|
||||
self.is_input_per_host_with_iterators()):
|
||||
return global_batch_size // self.num_replicas
|
||||
else:
|
||||
return global_batch_size // self.num_hosts
|
||||
|
||||
@property
|
||||
def batch_size_for_model_fn(self):
|
||||
"""Returns the shard batch size for `model_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU. always sharded per shard.
|
||||
return global_batch_size // self.num_replicas
|
||||
|
||||
@property
|
||||
def master_job(self):
|
||||
"""Returns the job name to use to place TPU computations on.
|
||||
|
||||
Returns:
|
||||
A string containing the job name, or None if no job should be specified.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user needs to specify a tpu_job_name, because we are
|
||||
unable to infer the job name automatically, or if the user-specified job
|
||||
names are inappropriate.
|
||||
"""
|
||||
run_config = self._config
|
||||
# If the user specifies the tpu_job_name, use that.
|
||||
if run_config.tpu_config.tpu_job_name:
|
||||
return run_config.tpu_config.tpu_job_name
|
||||
|
||||
# The tpu job is determined by the run_config. Right now, this method is
|
||||
# required as tpu_config is not part of the RunConfig.
|
||||
mode = self._assert_mode()
|
||||
master = (
|
||||
run_config.evaluation_master
|
||||
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
|
||||
if master in _LOCAL_MASTERS:
|
||||
return None
|
||||
|
||||
if (not run_config.session_config or
|
||||
not run_config.session_config.cluster_def.job):
|
||||
return _DEFAULT_JOB_NAME
|
||||
cluster_def = run_config.session_config.cluster_def
|
||||
job_names = set([job.name for job in cluster_def.job])
|
||||
if _DEFAULT_JOB_NAME in job_names:
|
||||
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
|
||||
raise ValueError('Currently, tpu_worker is not an allowed job name.')
|
||||
if len(job_names) == 1:
|
||||
return cluster_def.job[0].name
|
||||
if len(job_names) == 2:
|
||||
if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
|
||||
job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
|
||||
return job_names.pop()
|
||||
# TODO(b/67716447): Include more sophisticated heuristics.
|
||||
raise ValueError(
|
||||
'Could not infer TPU job name. Please specify a tpu_job_name as part '
|
||||
'of your TPUConfig.')
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function."""
|
||||
|
||||
master = self.master_job
|
||||
|
||||
def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
|
||||
"""Return the host device given replica_id or host_id."""
|
||||
assert _sentinal is None
|
||||
if replica_id is not None and host_id is not None:
|
||||
raise RuntimeError(
|
||||
'replica_id and host_id can have only one non-None value.')
|
||||
|
||||
if master is None:
|
||||
return '/replica:0/task:0/device:CPU:0'
|
||||
else:
|
||||
if replica_id is not None:
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.host_device(
|
||||
replica=replica_id, job=master)
|
||||
else:
|
||||
host_id = replica_id / self.num_of_cores_per_host
|
||||
|
||||
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
@property
|
||||
def tpu_device_placement_function(self):
|
||||
"""Returns a TPU device placement Fn."""
|
||||
master = self.master_job
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
def _placement_function(i):
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.tpu_device(replica=i, job=master)
|
||||
else:
|
||||
num_of_cores_per_host = self.num_of_cores_per_host
|
||||
host_id = i / num_of_cores_per_host
|
||||
ordinal_id = i % num_of_cores_per_host
|
||||
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
def tpu_ordinal_function(self, host_id):
|
||||
"""Returns the TPU ordinal fn."""
|
||||
|
||||
def _tpu_ordinal_function(shard_index_in_host):
|
||||
"""Return the TPU ordinal associated with a shard.
|
||||
|
||||
Required because the enqueue ops are placed on CPU.
|
||||
|
||||
Args:
|
||||
shard_index_in_host: the shard index
|
||||
|
||||
Returns:
|
||||
The ordinal of the TPU device the shard's infeed should be placed on.
|
||||
"""
|
||||
if self.model_parallelism_enabled:
|
||||
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
|
||||
replica = self.device_assignment.lookup_replicas(host_id,
|
||||
0)[shard_index_in_host]
|
||||
return self.device_assignment.tpu_ordinal(replica=replica)
|
||||
else:
|
||||
return shard_index_in_host % self.num_of_cores_per_host
|
||||
|
||||
return _tpu_ordinal_function
|
||||
|
||||
def _validate_tpu_configuration(self):
|
||||
"""Validates the configuration based on the TPU system metadata."""
|
||||
mode = self._assert_mode()
|
||||
if self._lazy_validation_dict.get(mode):
|
||||
return
|
||||
|
||||
# All following information is obtained from TPU system metadata.
|
||||
num_cores = self.num_cores
|
||||
num_replicas = self.num_replicas
|
||||
num_hosts = self.num_hosts
|
||||
|
||||
if not num_cores:
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
raise RuntimeError(
|
||||
'Cannot find any TPU cores in the system. Please double check '
|
||||
'Tensorflow master address and TPU worker(s). Available devices '
|
||||
'are {}.'.format(tpu_system_metadata.devices))
|
||||
|
||||
if self._config.tpu_config.num_shards:
|
||||
user_provided_num_replicas = self._config.tpu_config.num_shards
|
||||
if user_provided_num_replicas != num_replicas:
|
||||
message = (
|
||||
'TPUConfig.num_shards is not set correctly. According to TPU '
|
||||
'system metadata for Tensorflow master ({}): num_replicas should '
|
||||
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
|
||||
'be the total num of TPU cores in the system. For '
|
||||
'model-parallelism, the total number of TPU cores should be '
|
||||
'num_cores_per_replica * num_replicas. Please set it '
|
||||
'accordingly or leave it as `None`'.format(
|
||||
self._get_master_address(), num_replicas,
|
||||
user_provided_num_replicas))
|
||||
|
||||
raise ValueError(message)
|
||||
|
||||
if self._config.tpu_config.num_cores_per_replica:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
|
||||
if num_cores_per_replica > num_cores_per_host:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the '
|
||||
'num_cores_per_host. num_cores_per_replica: {}, '
|
||||
'num_cores_per_host: {}'.format(num_cores_per_replica,
|
||||
num_cores_per_host))
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
if (self._train_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'train batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._train_batch_size, num_replicas))
|
||||
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
if self._eval_batch_size is None:
|
||||
raise ValueError(
|
||||
'eval_batch_size in TPUEstimator constructor cannot be `None` '
|
||||
'if .evaluate is running on TPU.')
|
||||
if (self._eval_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'eval batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._eval_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.evaluate should be running on single TPU'
|
||||
' instead of a Pod.')
|
||||
else:
|
||||
assert mode == model_fn_lib.ModeKeys.PREDICT
|
||||
if self._predict_batch_size is None:
|
||||
raise ValueError(
|
||||
'predict_batch_size in TPUEstimator constructor should not be '
|
||||
'`None` if .predict is running on TPU.')
|
||||
if (self._predict_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'predict batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._predict_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.predict should be running on single TPU worker. '
|
||||
'got {}.'.format(num_hosts))
|
||||
|
||||
# Record the state "validated" into lazy dictionary.
|
||||
self._lazy_validation_dict[mode] = True
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replicate for non-model-parallelism.
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
master = self.master_job
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
return (self.device_assignment.host_device(
|
||||
replica=replica_id, job=master),
|
||||
self.device_assignment.tpu_ordinal(replica=replica_id))
|
||||
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
num_of_replicas_per_host = self.num_of_replicas_per_host
|
||||
host_id = replica_id / num_of_replicas_per_host
|
||||
ordinal_id = replica_id % num_of_replicas_per_host
|
||||
|
||||
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
|
||||
return (host_device, ordinal_id)
|
||||
|
||||
|
||||
class _OneCoreTPUContext(_InternalTPUContext):
|
||||
"""Special _InternalTPUContext for one core usage."""
|
||||
|
||||
def __init__(self, config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu):
|
||||
|
||||
super(_OneCoreTPUContext, self).__init__(
|
||||
config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
|
||||
num_cores=1,
|
||||
num_hosts=1,
|
||||
num_of_cores_per_host=1,
|
||||
topology=None,
|
||||
devices=[]))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
|
||||
def _get_tpu_context(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec):
|
||||
"""Returns an instance of `_InternalTPUContext`."""
|
||||
|
||||
if (config.tpu_config.num_shards == 1 and
|
||||
config.tpu_config.num_cores_per_replica is None):
|
||||
if embedding_config_spec is not None:
|
||||
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
|
||||
'when embedding_config_spec is not None.')
|
||||
logging.warning(
|
||||
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
|
||||
'Please fix as soon as possible (leaving num_shards as None.)')
|
||||
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec)
|
1105 tensorflow/python/tpu/tpu_embedding.py (new file; diff suppressed because it is too large)
153 tensorflow/python/tpu/tpu_embedding_gradient.py (new file)
@ -0,0 +1,153 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Optional helper for gradient handling."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
|
||||
def get_gradients_through_compute_gradients(optimizer, loss, activations):
|
||||
"""Compute gradients to send to TPU embedding.
|
||||
|
||||
Args:
|
||||
optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
|
||||
Used to call compute_gradients().
|
||||
loss: a Tensor to call optimizer.compute_gradients() on.
|
||||
activations: an OrderedDict mapping feature_name to Tensors of activations.
|
||||
|
||||
Returns:
|
||||
An OrderedDict mapping from feature name Strings to Tensors of gradients of
|
||||
the loss wrt the activations of the features.
|
||||
"""
|
||||
activation_list = activations.values()
|
||||
grads_and_vars = optimizer.compute_gradients(loss, activation_list)
|
||||
grads = [grad for grad, _ in grads_and_vars]
|
||||
feature_to_gradient_dict = collections.OrderedDict(
|
||||
zip(activations.keys(), grads))
|
||||
return feature_to_gradient_dict
|
||||
|
||||
|
||||
def create_dummy_table_variables(tpu_embedding):
|
||||
"""Create dummy embedding table variables.
|
||||
|
||||
The sole purpose of these dummy variables is to trigger gradient
calculation wrt them so that the gradients wrt activation can be captured
|
||||
and later sent to TPU embedding.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, dummy table variables will be created for use
|
||||
with tpu_embedding.
|
||||
|
||||
Returns:
|
||||
A tuple of dummy variables and their initializer.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if collection to store gradients already exists and is not
|
||||
empty.
|
||||
"""
|
||||
dummy_table_variables = collections.OrderedDict()
|
||||
for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
|
||||
dummy_table_variables[table] = (
|
||||
# Explicitly specifying collections prevents this variable from
|
||||
# being added to the GLOBAL_VARIABLES collection, so that Saver()
|
||||
# ignores it.
|
||||
# But the TensorFlow optimizer creates slot variables for these dummy
# variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
|
||||
# which will be in GLOBAL_VARIABLES collection,
|
||||
variable_scope.get_variable(
|
||||
'tpu_embedding_dummy_table_variable_{}'.format(table),
|
||||
dtype=dtypes.float32,
|
||||
shape=[1],
|
||||
use_resource=True,
|
||||
trainable=True,
|
||||
collections=['tpu_embedding_dummy_table_variables']))
|
||||
|
||||
g = ops.get_default_graph()
|
||||
table_gradients = g.get_collection_ref(
|
||||
'tpu_embedding_gradients_table_{}'.format(table_id))
|
||||
if table_gradients:
|
||||
raise RuntimeError(
|
||||
'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
|
||||
table_gradients.extend(
|
||||
[None] * len(tpu_embedding.table_to_features_dict[table]))
|
||||
|
||||
return (dummy_table_variables,
|
||||
variables.variables_initializer(
|
||||
dummy_table_variables.values(),
|
||||
name='tpu_embedding_dummy_table_variables_init'))
|
||||
|
||||
|
||||
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
|
||||
dummy_table_variables):
|
||||
"""Have activations depend on dummy table variables for gradient intercept.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
|
||||
tpu_embedding.
|
||||
activations: An OrderedDict of feature name String to activation tensors.
|
||||
dummy_table_variables: An OrderedDict of table name String to dummy table
|
||||
variables.
|
||||
|
||||
Returns:
|
||||
An OrderedDict of feature name String to activation tensors, which can be
|
||||
used just as the activations input.
|
||||
"""
|
||||
new_activations = collections.OrderedDict()
|
||||
for feature in activations:
|
||||
table = tpu_embedding.feature_to_table_dict[feature]
|
||||
new_activations[feature] = tpu_ops.tpu_embedding_activations(
|
||||
dummy_table_variables[table],
|
||||
activations[feature],
|
||||
table_id=tpu_embedding.table_to_config_dict.keys().index(table),
|
||||
lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
|
||||
return new_activations
|
||||
|
||||
|
||||
def get_gradients_through_dummy_table_variables(tpu_embedding):
|
||||
"""Get gradients wrt the activations of each feature.
|
||||
|
||||
Args:
|
||||
tpu_embedding: TPUEmbedding, create dummy table variable to be used with
|
||||
tpu_embedding.
|
||||
|
||||
Returns:
|
||||
An OrderedDict mapping feature name to gradient.
|
||||
|
||||
Raises:
|
||||
ValueError: if some gradients are not defined.
|
||||
"""
|
||||
g = ops.get_default_graph()
|
||||
feature_to_gradient_dict = collections.OrderedDict()
|
||||
for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
|
||||
table_gradients = g.get_collection(
|
||||
'tpu_embedding_gradients_table_{}'.format(table_id))
|
||||
if any(gradient is None for gradient in table_gradients):
|
||||
raise ValueError(
|
||||
'Table {} with id {} has undefined gradients: this is probably '
|
||||
'because the model asked TPUEmbedding to compute activations that '
|
||||
'were not used.'.format(table, table_id))
|
||||
for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
|
||||
table_gradients):
|
||||
feature_to_gradient_dict[feature] = gradient
|
||||
return feature_to_gradient_dict
|
3760 tensorflow/python/tpu/tpu_estimator.py (new file; diff suppressed because it is too large)
@ -20,12 +20,12 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_estimator
|
||||
|
||||
|
||||
def make_input_fn(num_samples):
|
919 tensorflow/python/tpu/tpu_feed.py (new file)
@ -0,0 +1,919 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Helper library for handling infeed between hosts and TPUs.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_sharding
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def partition_or_replicate_on_host(tensor, dims):
|
||||
"""Partitions or replicates the input tensor.
|
||||
|
||||
The ops inside this function are placed on the host side.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor which will be partitioned or replicated.
|
||||
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Returns:
|
||||
An iterator of `Tensor`s or a list of partitioned tensors.
|
||||
"""
|
||||
if dims is None:
|
||||
return itertools.repeat(tensor)
|
||||
dims = np.array(dims)
|
||||
output = [tensor]
|
||||
shape_list = np.array(tensor.shape.as_list())
|
||||
quotients, remainders = np.divmod(shape_list, dims)
|
||||
for axis, (quotient, remainder, dim, original_size) in enumerate(
|
||||
zip(quotients, remainders, dims, shape_list)):
|
||||
if dim <= 1:
|
||||
continue
|
||||
if remainder > 0:
|
||||
# For each dimension, when it cannot be evenly partitioned, XLA assumes
|
||||
# tensors are partitioned in a greedy manner by using
|
||||
# ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
|
||||
# are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
|
||||
# [[(3, 4), (3, 4), (2, 4), (2, 2)],
|
||||
# [(2, 4), (2, 4), (2, 4), (2, 2)]]
|
||||
ceil_ratio = quotient + 1
|
||||
num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
|
||||
num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
|
||||
if len(num_or_size_splits) < dim:
|
||||
num_or_size_splits += [0] * (dim - len(num_or_size_splits))
|
||||
new_output = []
|
||||
for x in output:
|
||||
new_output.append(
|
||||
array_ops.split(
|
||||
x, num_or_size_splits=num_or_size_splits, axis=axis))
|
||||
output = new_output
|
||||
else:
|
||||
output = [array_ops.split(x, dim, axis=axis) for x in output]
|
||||
output = nest.flatten(output)
|
||||
return output
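# Editor's note: an illustrative, standalone NumPy sketch (not part of this file)
# that reproduces the greedy split sizes described in the comment above for a
# (5, 14) tensor partitioned with dims (2, 4).
import numpy as np

shape, dims = np.array([5, 14]), np.array([2, 4])
quotients, remainders = np.divmod(shape, dims)
for axis, (quotient, remainder, dim, size) in enumerate(
    zip(quotients, remainders, dims, shape)):
  if remainder > 0:
    ceil_ratio = quotient + 1
    num_full_slots, left_over = np.divmod(size, ceil_ratio)
    splits = [ceil_ratio] * num_full_slots + [left_over]
    splits += [0] * (dim - len(splits))
  else:
    splits = [quotient] * dim
  print(axis, splits)  # axis 0 -> [3, 2], axis 1 -> [4, 4, 4, 2]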
|
||||
|
||||
|
||||
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
|
||||
"""Tags appropriate XLA sharding attribute to the dequeued tensor.
|
||||
|
||||
Args:
|
||||
tensor: The dequeued tensor on TPU.
|
||||
dims: A list of integers describing how the tensor is partitioned.
|
||||
|
||||
Returns:
|
||||
The same tensor with the xla_sharding attribute.
|
||||
"""
|
||||
if dims is None:
|
||||
return xla_sharding.replicate(tensor)
|
||||
elif np.prod(dims) == 1:
|
||||
return xla_sharding.assign_device(tensor, 0)
|
||||
else:
|
||||
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
|
||||
return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
|
||||
|
||||
|
||||
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
|
||||
"""Tags appropriate XLA sharding attribute to the dequeued tensors.
|
||||
|
||||
Args:
|
||||
dequeues: A list of dequeued tensors on TPU.
|
||||
dims: A list of integers describing how the tensor is partitioned.
|
||||
|
||||
Returns:
|
||||
The same dequeues with appropriate xla_sharding attribute.
|
||||
"""
|
||||
nest.assert_shallow_structure(dequeues, dims)
|
||||
return nest.map_structure_up_to(
|
||||
dequeues, _tag_sharding_attribute_for_dequeued_tensor, dequeues, dims)
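# Editor's note: a minimal NumPy illustration (not part of this file) of the tile
# assignment built in the final branch above for dims = [2, 4]: logical cores 0..7
# are laid out as a 2x4 grid.
import numpy as np

dims = [2, 4]
print(np.arange(np.prod(dims)).reshape(dims))
# [[0 1 2 3]
#  [4 5 6 7]]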
|
||||
|
||||
|
||||
class InfeedQueue(object):
|
||||
"""A helper object to build a device infeed queue.
|
||||
|
||||
The InfeedQueue builds the host-side and device-side Ops to enqueue and
|
||||
dequeue elements, respectively, and ensures that their types and
|
||||
shapes match.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
number_of_tuple_elements=None,
|
||||
tuple_types=None,
|
||||
tuple_shapes=None,
|
||||
shard_dimensions=None,
|
||||
name=None):
|
||||
"""Creates a new InfeedQueue with the given configuration.
|
||||
|
||||
The configuration need not be fully specified at creation since it
|
||||
can be modified subsequently by methods that set the values
|
||||
explicitly or infer them from the shapes of inputs.
|
||||
|
||||
Args:
|
||||
number_of_tuple_elements: the number of Tensors fed atomically through the
|
||||
queue, must be present unless it can be inferred from other arguments.
|
||||
tuple_types: if not None, a list of types of the elements of the queue.
|
||||
tuple_shapes: if not None, a list of shapes of the elements of the queue.
|
||||
shard_dimensions: if not None, a list of dimensions on which the
|
||||
elements of the queue should be sharded during automatic
|
||||
parallelization.
|
||||
name: the name of the queue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_tuple_elements <= 0; or
|
||||
number_of_tuple_elements, tuple_types, tuple_shapes, and
|
||||
shard_dimensions are all None; or the length of tuple_types,
|
||||
tuple_shapes, or shard_dimensions is not equal to
|
||||
number_of_tuple_elements; or any element of shard_dimensions
|
||||
can't be converted to a Dimension.
|
||||
TypeError: if any element of tuple_types or tuple_shapes can't
|
||||
be converted to a dtype or TensorShape, respectively.
|
||||
"""
|
||||
self._frozen = False
|
||||
self._generated_enqueue_ops = False
|
||||
self._generated_dequeue_op = False
|
||||
self._name = "InfeedQueue" if name is None else name
|
||||
if number_of_tuple_elements is None:
|
||||
if tuple_types is not None:
|
||||
number_of_tuple_elements = len(tuple_types)
|
||||
elif tuple_shapes is not None:
|
||||
number_of_tuple_elements = len(tuple_shapes)
|
||||
elif shard_dimensions is not None:
|
||||
number_of_tuple_elements = len(shard_dimensions)
|
||||
else:
|
||||
raise ValueError(
|
||||
"number of tuple elements cannot be inferred from InfeedQueue "
|
||||
"constructor")
|
||||
if number_of_tuple_elements <= 0:
|
||||
raise ValueError("number_of_tuple_elements %d must be > 0" %
|
||||
number_of_tuple_elements)
|
||||
# Make an empty sharding policy for each tuple element.
|
||||
self._sharding_policies = [
|
||||
tpu_sharding.ShardingPolicy()
|
||||
for _ in xrange(number_of_tuple_elements)
|
||||
]
|
||||
if tuple_types is not None:
|
||||
self.set_tuple_types(tuple_types)
|
||||
else:
|
||||
self._tuple_types = None
|
||||
if tuple_shapes is not None:
|
||||
self.set_tuple_shapes(tuple_shapes)
|
||||
else:
|
||||
self._tuple_shapes = None
|
||||
if shard_dimensions is not None:
|
||||
self.set_shard_dimensions(shard_dimensions)
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
"""Checks that the configuration is self-consistent.
|
||||
|
||||
Raises:
|
||||
ValueError: if the shapes and sharding policies don't match.
|
||||
"""
|
||||
if self.tuple_shapes is not None:
|
||||
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
|
||||
# Raise an error if the policy is incompatible with the shape.
|
||||
_ = policy.get_sharded_shape(shape)
|
||||
|
||||
@property
|
||||
def number_of_tuple_elements(self):
|
||||
"""Returns the number of InfeedQueue tuple elements."""
|
||||
return len(self._sharding_policies)
|
||||
|
||||
@property
|
||||
def tuple_types(self):
|
||||
"""Returns the types of the InfeedQueue tuple elements."""
|
||||
return self._tuple_types
|
||||
|
||||
def set_tuple_types(self, tuple_types):
|
||||
"""Sets the type of each element of the queue.
|
||||
|
||||
tuple_types must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a dtype.
|
||||
|
||||
Args:
|
||||
tuple_types: the types of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_types is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_types cannot be converted to a
|
||||
dtype.
|
||||
"""
|
||||
if len(tuple_types) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_types is %s, but must be a list of length %d" %
|
||||
(str(tuple_types), self.number_of_tuple_elements))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_types, tuple_types):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible type. Frozen types are %s, updated types are %s" % (
|
||||
str(self._tuple_types), str(tuple_types)))
|
||||
else:
|
||||
try:
|
||||
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
|
||||
except (TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_types is %s, but must be a list of elements each "
|
||||
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
|
||||
|
||||
@property
|
||||
def tuple_shapes(self):
|
||||
"""Returns the shapes of the InfeedQueue tuple elements."""
|
||||
return self._tuple_shapes
|
||||
|
||||
def set_tuple_shapes(self, tuple_shapes):
|
||||
"""Sets the shape of each element of the queue.
|
||||
|
||||
tuple_shapes must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a TensorShape.
|
||||
|
||||
Args:
|
||||
tuple_shapes: the shapes of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_shapes is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_shapes cannot be converted to
|
||||
a TensorShape.
|
||||
"""
|
||||
if len(tuple_shapes) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
|
||||
(str(tuple_shapes), self.number_of_tuple_elements))
|
||||
try:
|
||||
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_shapes is %s, but must be a list of elements each "
|
||||
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
|
||||
str(e)))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
|
||||
% (str(self._tuple_shapes), str(tuple_shapes)))
|
||||
else:
|
||||
self._tuple_shapes = tuple_shapes
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def sharding_policies(self):
|
||||
"""Returns the sharding policies of the InfeedQueue tuple elements."""
|
||||
return self._sharding_policies
|
||||
|
||||
@property
|
||||
def shard_dimensions(self):
|
||||
"""Gets the shard dimension of each tuple element.
|
||||
|
||||
Returns:
|
||||
A list of length number_of_tuple_elements, where each list entry
|
||||
is the shard dimension of that tuple element or None if the
|
||||
shard dimension has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return [policy.shard_dimension for policy in self._sharding_policies]
|
||||
|
||||
def set_shard_dimensions(self, shard_dimensions):
|
||||
"""Sets the shard_dimension of each element of the queue.
|
||||
|
||||
shard_dimensions must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a Dimension compatible with self.tuple_shapes.
|
||||
|
||||
Args:
|
||||
shard_dimensions: the dimensions of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if shard_dimensions is not of length
|
||||
self.number_of_tuple_elements; or an element of
|
||||
shard_dimensions cannot be converted to a Dimension; or an
|
||||
element of shard_dimensions is a Dimension that is out of
|
||||
range for the corresponding tuple element shape.
|
||||
"""
|
||||
if len(shard_dimensions) != self.number_of_tuple_elements:
|
||||
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
|
||||
% (str(shard_dimensions),
|
||||
self.number_of_tuple_elements))
|
||||
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
|
||||
policy.set_shard_dimension(dimension)
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
"""Gets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Returns:
|
||||
Number of shards or None if the number of shards has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return self._sharding_policies[0].number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
"""Sets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Args:
|
||||
number_of_shards: number of ways to shard the InfeedQueue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_shards is not > 0; or the policies have
|
||||
been frozen and number_of_shards was already set to something
|
||||
else.
|
||||
"""
|
||||
for policy in self._sharding_policies:
|
||||
policy.set_number_of_shards(number_of_shards)
|
||||
self._validate()
|
||||
|
||||
def set_configuration_from_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of Tensors whose types and shapes are used
|
||||
to set the queue configuration.
|
||||
|
||||
Args:
|
||||
input_tensors: list of Tensors of the same types and shapes as
|
||||
the desired queue Tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: if input_tensors is not a list of length
|
||||
self.number_of_tuple_elements
|
||||
"""
|
||||
if len(input_tensors) != self.number_of_tuple_elements:
|
||||
raise ValueError("input_tensors is %s, but should be a list of %d Tensors"
|
||||
% (str(input_tensors), self.number_of_tuple_elements))
|
||||
self.set_tuple_shapes([t.shape for t in input_tensors])
|
||||
self.set_tuple_types([t.dtype for t in input_tensors])
|
||||
|
||||
def set_configuration_from_sharded_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of lists of Tensors whose types and shapes are used
|
||||
to set the queue configuration. The length of the outer list is the number
|
||||
of shards required, and each inner list is the tuple of Tensors to use to
|
||||
determine the types and shapes of the corresponding shard. This method
|
||||
depends on the shard dimension, and calling it freezes the shard policy.
|
||||
|
||||
Args:
|
||||
input_tensors: list of lists of Tensors. The outer list length corresponds
|
||||
to the desired number of shards, and each inner list is the size
|
||||
and shape of the desired configuration of the corresponding shard.
|
||||
|
||||
Raises:
|
||||
ValueError: if any inner list is not a list of length
|
||||
self.number_of_tuple_elements; or the inner lists do not combine to
|
||||
form a consistent unsharded shape.
|
||||
TypeError: if the types of the Tensors in the inner lists do not match.
|
||||
"""
|
||||
if not self._frozen:
|
||||
# Unset the tuple shapes in case the configuration becomes
|
||||
# transiently inconsistent.
|
||||
self._tuple_shapes = None
|
||||
number_of_shards = len(input_tensors)
|
||||
self.set_number_of_shards(number_of_shards)
|
||||
for t in input_tensors:
|
||||
if len(t) != self.number_of_tuple_elements:
|
||||
raise ValueError(
|
||||
"input_tensors is %s but must be a list of lists, where each inner"
|
||||
" list has length number_of_tuple_elements=%d" % (
|
||||
str(input_tensors), self.number_of_tuple_elements))
|
||||
# Transpose the inputs to make a list of shard shapes for each tuple
|
||||
# element.
|
||||
sharded_shapes = [[t[i].shape for t in input_tensors]
|
||||
for i in xrange(self.number_of_tuple_elements)]
|
||||
# For each tuple, get the unsharded shape using that tuple's policy.
|
||||
unsharded_shapes = [
|
||||
policy.get_unsharded_shape(s)
|
||||
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
|
||||
]
|
||||
self.set_tuple_shapes(unsharded_shapes)
|
||||
for i in xrange(1, self.number_of_shards):
|
||||
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
|
||||
if t1.dtype != t2.dtype:
|
||||
raise TypeError(
|
||||
"types of the tuple elements of input_tensors %s are not "
|
||||
"consistent" % str(input_tensors))
|
||||
self.set_tuple_types([t.dtype for t in input_tensors[0]])
|
||||
|
||||
def freeze(self):
|
||||
"""Freezes the InfeedQueue so it can no longer be modified.
|
||||
|
||||
The configuration is implicitly frozen before any host-side or
|
||||
device-side Ops are generated. The configuration cannot be frozen
|
||||
until the types and shapes of the tuple elements have been set.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set.
|
||||
"""
|
||||
self._frozen = True
|
||||
if self._tuple_types is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple types.")
|
||||
if self._tuple_shapes is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for shape in self._tuple_shapes:
|
||||
if shape.dims is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for policy in self._sharding_policies:
|
||||
policy.freeze()
|
||||
self._validate()
|
||||
|
||||
def generate_dequeue_op(self, tpu_device=0):
|
||||
"""Generates the device-side Op to dequeue a tuple from the queue.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen, which will raise errors if the shapes and types have not
|
||||
been fully specified.
|
||||
|
||||
Args:
|
||||
tpu_device: The TPU device ordinal where the infeed instruction should be
|
||||
placed. If None, no explicit placement will be performed, and it is up
|
||||
to the user to call this API from within a proper TPU device scope.
|
||||
The XLA code will fail if the TPU dequeue instruction is not bound to
|
||||
any device.
|
||||
|
||||
Returns:
|
||||
A list of Outputs corresponding to a shard of infeed dequeued
|
||||
into XLA, suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set; or if a dequeue op has already been generated.
|
||||
"""
|
||||
self.freeze()
|
||||
if self._generated_dequeue_op:
|
||||
raise ValueError("Can't generate two dequeue Ops from the same queue")
|
||||
self._generated_dequeue_op = True
|
||||
full_name = "%s/dequeue" % self._name
|
||||
sharded_shapes = [
|
||||
policy.get_sharded_shape(shape)
|
||||
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
|
||||
]
|
||||
if tpu_device is not None:
|
||||
with ops.device(tpu.core(tpu_device)):
|
||||
return tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
else:
|
||||
return tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
|
||||
def _generate_enqueue_op(self,
|
||||
inputs,
|
||||
name_prefix,
|
||||
index,
|
||||
device=None,
|
||||
tpu_ordinal=-1):
|
||||
"""Generate a host-side Op to enqueue a tuple to the queue.
|
||||
|
||||
If device is None the inputs are all required to have the same
|
||||
device specification, and the enqueue Op is colocated with
|
||||
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors with the types and shapes of the tuple elements.
|
||||
name_prefix: the base name for the Op.
|
||||
index: the shard index, used to uniquify the Op name.
|
||||
device: device to place the Op on, or None if it should be
|
||||
colocated with the inputs.
|
||||
tpu_ordinal: ordinal of the TPU device on the host to use for
|
||||
infeed if device is a CPU device. Should be set to -1 if device
|
||||
is a TPU device.
|
||||
|
||||
Returns:
|
||||
An Op corresponding to a shard of infeed enqueued at the host,
|
||||
suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if device is None and inputs do not all have the
|
||||
same device specification.
|
||||
"""
|
||||
full_name = "%s/%d" % (name_prefix, index)
|
||||
shapes = [t.shape for t in inputs]
|
||||
if device is None:
|
||||
devices = [t.device for t in inputs]
|
||||
for i in xrange(1, self.number_of_tuple_elements):
|
||||
if devices[0] != devices[i]:
|
||||
raise ValueError(
|
||||
"input devices for shard %d are %s, but should all be the same" %
|
||||
(index, str(devices)))
|
||||
with ops.colocate_with(inputs[0]):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
else:
|
||||
with ops.device(device):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
|
||||
def generate_enqueue_ops(self,
|
||||
sharded_inputs,
|
||||
tpu_ordinal_function=None,
|
||||
placement_function=None):
|
||||
"""Generates the host-side Ops to enqueue the shards of a tuple.
|
||||
|
||||
sharded_inputs is a list, one for each shard, of lists of
|
||||
Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
|
||||
shard 0 of the queue. Returns the host-side Ops that must be run to
|
||||
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
|
||||
for shard i.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of sharded_inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
sharded_inputs: a list of lists of Tensors. The length of the outer list
|
||||
determines the number of shards. Each inner list indicates the types
|
||||
and shapes of the tuples in the corresponding shard.
|
||||
tpu_ordinal_function: if not None, a function that takes the
|
||||
shard index as input and returns the ordinal of the TPU device
|
||||
the shard's infeed should be placed on. tpu_ordinal_function must be
|
||||
set if the inputs are placed on CPU devices.
|
||||
placement_function: if not None, a function that takes the shard index as
|
||||
input and returns the host device on which the enqueue op should be
placed.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the shapes of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple; or if the elements of a tuple
|
||||
have different device constraints.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the types of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple.
|
||||
"""
|
||||
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = lambda index: -1
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(
|
||||
shard,
|
||||
name_prefix,
|
||||
index,
|
||||
tpu_ordinal=tpu_ordinal_function(index),
|
||||
device=placement_function(index) if placement_function else None)
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
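# Editor's note: a hypothetical usage sketch (not part of this file), assuming a
# TF 1.x graph and a reachable TPU worker; the shard tensors are placeholders
# invented for illustration. It pairs the host-side enqueue ops generated above
# with the device-side dequeue.
import tensorflow as tf
from tensorflow.python.tpu import tpu_feed

# One (features, labels) tuple per shard; two shards in this sketch.
shard0 = [tf.constant([[0.0, 1.0]]), tf.constant([0])]
shard1 = [tf.constant([[2.0, 3.0]]), tf.constant([1])]

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
# Host side: one enqueue op per shard, fed to TPU ordinal == shard index.
enqueue_ops = queue.generate_enqueue_ops(
    [shard0, shard1], tpu_ordinal_function=lambda index: index)
# Device side: dequeue the matching tuple inside the TPU computation.
features, labels = queue.generate_dequeue_op(tpu_device=0)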
|
||||
|
||||
# TODO(misard) Generalize this to the case of systems that don't
|
||||
# have 8 devices per host, and figure out what to do with
|
||||
# model-parallelism.
|
||||
def _default_placement_function(self, index):
|
||||
return "/task:%d/device:CPU:0" % (index / 8)
|
||||
|
||||
def _default_ordinal_function(self, index):
|
||||
return index % 8
|
||||
|
||||
# TODO(b/36470756) remove this from tutorials once we have a better story
|
||||
# for automatic placement of input pipelines.
|
||||
def split_inputs_and_generate_enqueue_ops(self,
|
||||
inputs,
|
||||
device_assignment=None,
|
||||
placement_function=None,
|
||||
tpu_ordinal_function=None):
|
||||
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
|
||||
|
||||
Generates the host-side Ops to enqueue a tuple.
|
||||
|
||||
This method performs poorly because it takes an entire input on a single
|
||||
host, splits it, and distributes it to all of the cores. It is present only
|
||||
to simplify tutorial examples.
|
||||
|
||||
inputs is a list of Tensors to use to feed the queue. Each input is split
|
||||
into self.number_of_shards shards. Returns an Op for each shard to enqueue
|
||||
the shard. The Op for shard i is placed on device placement_function(i).
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors which indicates the types and shapes of the
|
||||
queue tuple.
|
||||
device_assignment: if not `None`, a TPU `DeviceAssignment`. If
|
||||
device_assignment is not `None`, but `placement_function` and
|
||||
`ordinal_function` are None, then `device_assignment` will be used to
|
||||
place infeeds on the first k TPU shards, where k is the number of shards
|
||||
in the queue. If all three are `None`, then default placement and
|
||||
ordinal functions are used.
|
||||
placement_function: if not None, a function that takes the shard
|
||||
index as input and returns a device string indicating which
|
||||
device the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
tpu_ordinal_function: if not None, a function that takes the
|
||||
shard index as input and returns the ordinal of the TPU device
|
||||
the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
"""
|
||||
if device_assignment is None:
|
||||
if placement_function is None:
|
||||
placement_function = self._default_placement_function
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = self._default_ordinal_function
|
||||
else:
|
||||
|
||||
def _placement_function_from_map(index):
|
||||
return device_assignment.host_device(replica=index)
|
||||
|
||||
def _ordinal_function_from_map(index):
|
||||
return device_assignment.tpu_ordinal(replica=index)
|
||||
|
||||
if placement_function is None:
|
||||
placement_function = _placement_function_from_map
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = _ordinal_function_from_map
|
||||
self.set_configuration_from_input_tensors(inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
split_name_prefix = "%s/split" % self._name
|
||||
if self.number_of_shards == 1:
|
||||
transposed_sharded_inputs = [[inp] for inp in inputs]
|
||||
else:
|
||||
|
||||
def split_fn(inp, num_shards, axis, name):
|
||||
with ops.colocate_with(inp):
|
||||
return array_ops.split(inp, num_shards, axis=axis, name=name)
|
||||
|
||||
transposed_sharded_inputs = [
|
||||
split_fn(
|
||||
inp,
|
||||
self.number_of_shards,
|
||||
axis=policy.shard_dimension,
|
||||
name="%s/%d" % (split_name_prefix, index))
|
||||
for (inp, policy, index) in zip(inputs, self._sharding_policies,
|
||||
xrange(self.number_of_tuple_elements))
|
||||
]
|
||||
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
|
||||
for i in xrange(self.number_of_shards)]
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(
|
||||
shard,
|
||||
name_prefix,
|
||||
index,
|
||||
device=placement_function(index),
|
||||
tpu_ordinal=tpu_ordinal_function(index))
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
|
||||
|
||||
|
||||
class _PartitionedInfeedQueue(InfeedQueue):
|
||||
"""A helper object to build a device infeed queue with input partition.
|
||||
|
||||
Args:
|
||||
number_of_tuple_elements: the number of Tensors fed atomically through the
|
||||
queue, must be present unless it can be inferred from other arguments.
|
||||
device_assignment: A TPU `DeviceAssignment` which is used to place all the
|
||||
partitions to different TPU infeed queues.
|
||||
host_id: The id of the host machine.
|
||||
input_partition_dims: A nested list/tuple of integers. Each inner
|
||||
list/tuple describes how to partition the corresponding input tensor.
|
||||
tuple_types: If not None, a list of types of the elements of the queue.
|
||||
tuple_shapes: If not None, a list of shapes of the elements of the queue.
|
||||
name: The name of the queue.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
number_of_tuple_elements,
|
||||
device_assignment,
|
||||
host_id,
|
||||
input_partition_dims=None,
|
||||
tuple_types=None,
|
||||
tuple_shapes=None,
|
||||
name=None):
|
||||
super(_PartitionedInfeedQueue, self).__init__(
|
||||
number_of_tuple_elements=number_of_tuple_elements,
|
||||
tuple_types=tuple_types,
|
||||
tuple_shapes=None,
|
||||
shard_dimensions=None,
|
||||
name="PartitionedInfeedQueue" if name is None else name)
|
||||
self._input_partition_dims = input_partition_dims
|
||||
self._host_id = host_id
|
||||
self._device_assignment = device_assignment
|
||||
|
||||
def generate_dequeue_op(self, tpu_device=0):
|
||||
"""Generate TPU dequeue ops.
|
||||
|
||||
Args:
|
||||
tpu_device: The TPU device ordinal where the infeed instruction should be
|
||||
placed.
|
||||
|
||||
Returns:
|
||||
A list of Outputs corresponding to a partition of infeed dequeued
|
||||
into XLA, suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set; or if a dequeue op has already been generated.
|
||||
"""
|
||||
self.freeze()
|
||||
if self._generated_dequeue_op:
|
||||
raise ValueError("Can't generate two dequeue Ops from the same queue")
|
||||
self._generated_dequeue_op = True
|
||||
full_name = "%s/dequeue" % self._name
|
||||
sharded_shapes = [
|
||||
policy.get_sharded_shape(shape)
|
||||
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
|
||||
]
|
||||
with ops.device(tpu.core(tpu_device)):
|
||||
values = tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
return tag_sharding_attribute_for_dequeued_tensors(
|
||||
values, self._input_partition_dims)
|
||||
|
||||
def generate_enqueue_ops(self, per_host_sharded_inputs):
|
||||
"""Generates the host-side Ops to enqueue the partitioned inputs.
|
||||
|
||||
per_host_sharded_inputs is a list, one for each replica, of lists of
|
||||
Tensors. per_host_sharded_inputs[i] is the tuple of Tensors to use to feed
|
||||
replica i.
|
||||
per_host_sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].
|
||||
|
||||
For example, if per_host_sharded_inputs[i][j] is a 2-D Tensor:
|
||||
[[A, B, C, D],
|
||||
[E ,F, G, H]]
|
||||
self._input_partition_dims[j] is [2, 4].
|
||||
|
||||
per_host_sharded_inputs[i][j] will be partitioned and flattened into:
|
||||
[A, B, C, D, E, F, G, H] and fed into the logical core ids:
|
||||
[0, 1, 2, 3, 4, 5, 6, 7] respectively.
|
||||
|
||||
Args:
|
||||
per_host_sharded_inputs: a list of lists of Tensors. The length of the
|
||||
outer list determines the number of shards. Each inner list indicates
|
||||
the types and shapes of the tuples in the corresponding shard.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the shapes of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple; or if the elements of a tuple
|
||||
have different device constraints; or if the partition dims are invalid.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the types of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple.
|
||||
"""
|
||||
self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
|
||||
number_of_replicas_per_host = len(per_host_sharded_inputs)
|
||||
number_of_tuple_elements = len(per_host_sharded_inputs[0])
|
||||
|
||||
assert len(self._input_partition_dims) == number_of_tuple_elements
|
||||
per_host_enqueue_ops = []
|
||||
|
||||
for replica_index in range(number_of_replicas_per_host):
|
||||
flattened_inputs = per_host_sharded_inputs[replica_index]
|
||||
inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs,
|
||||
self._input_partition_dims)
|
||||
inputs_parted_iters = [
|
||||
iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
|
||||
for x, dims in zip(per_host_sharded_inputs[replica_index],
|
||||
inputs_part_dims_flat)
|
||||
]
|
||||
|
||||
for logical_core in xrange(self._device_assignment.num_cores_per_replica):
|
||||
# Place different partitions on different logical cores.
|
||||
replica_id = self._device_assignment.lookup_replicas(
|
||||
self._host_id, logical_core)[replica_index]
|
||||
ordinal = self._device_assignment.tpu_ordinal(
|
||||
replica=replica_id, logical_core=logical_core)
|
||||
infeed_inputs = []
|
||||
for it in inputs_parted_iters:
|
||||
input_for_device = next(it, None)
|
||||
if input_for_device is not None:
|
||||
infeed_inputs.append(input_for_device)
|
||||
|
||||
if infeed_inputs:
|
||||
per_host_enqueue_ops.append(
|
||||
tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=infeed_inputs,
|
||||
shapes=[x.shape for x in infeed_inputs],
|
||||
name="enqueue/replica_{0}/input_{1}".format(
|
||||
replica_index, logical_core),
|
||||
device_ordinal=ordinal))
|
||||
return per_host_enqueue_ops
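# Editor's note: an illustrative sketch (not part of this file) of the flattening
# order described in the docstring above, assuming TF 1.x graph mode and the
# tensorflow.python.tpu.tpu_feed module path introduced by this commit. A 2x4
# tensor partitioned with dims [2, 4] yields eight 1x1 pieces in [A..H] order.
import numpy as np
import tensorflow as tf
from tensorflow.python.tpu import tpu_feed

x = tf.constant(np.arange(8).reshape(2, 4))  # [[A, B, C, D], [E, F, G, H]]
parts = tpu_feed.partition_or_replicate_on_host(x, dims=[2, 4])
with tf.Session() as sess:
  print([int(p) for p in sess.run(parts)])  # [0, 1, 2, 3, 4, 5, 6, 7]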
|
||||
|
||||
def _check_input_partition_dims(self, tensor, dims):
|
||||
"""Checks that input partition dims are valid for the `Tensor`.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor for partitioning.
|
||||
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tensor can't be partitioned by dims or the
|
||||
num_cores_per_replica doesn't match the number of
|
||||
partitions (dims.prod()).
|
||||
"""
|
||||
# No partitioning specified, so don't perform further checks.
|
||||
if dims is None:
|
||||
return
|
||||
|
||||
dims = np.array(dims)
|
||||
|
||||
if (dims < 1).any():
|
||||
raise ValueError("All input partition dims must be >= 1.")
|
||||
|
||||
# No partitioning, so don't perform further checks.
|
||||
if dims.prod() == 1:
|
||||
return
|
||||
|
||||
if dims.prod() != self._device_assignment.num_cores_per_replica:
|
||||
raise ValueError(
|
||||
"The product of each input parition dim should equal to "
|
||||
"num_cores_per_replica. (dim = {}, num_cores_per_replica "
|
||||
"= {})".format(dims, self._device_assignment.num_cores_per_replica))
|
||||
if dims.shape[0] != tensor.shape.ndims:
|
||||
raise ValueError(
|
||||
"Input partition dims must have the same number of dimensions "
|
||||
"as the `Tensor` to be partitioned. (tensor shape = {}, input "
|
||||
"partition dims = {}).".format(tensor.shape.as_list(), dims))
|
||||
|
||||
tensor.shape.assert_is_fully_defined()
|
||||
|
||||
def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
|
||||
"""Checks dims and partitions or replicates the input tensor.
|
||||
|
||||
The ops inside this function are placed on the host side.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor which will be partitioned or replicated.
|
||||
dims: A list of integers describing how to partition the input tensor.
|
||||
|
||||
Returns:
|
||||
An iterator of `Tensor`s or a list of partitioned tensors.
|
||||
"""
|
||||
self._check_input_partition_dims(tensor, dims)
|
||||
return partition_or_replicate_on_host(tensor, dims)
|
66
tensorflow/python/tpu/tpu_function.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper library for functions used during TPU compilation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
class TpuContext(object):
|
||||
"""A context object holding state about the TPU computation being built."""
|
||||
|
||||
def __init__(self):
|
||||
"""Creates a new TpuContext."""
|
||||
self._number_of_shards = None
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
return self._number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
self._number_of_shards = number_of_shards
|
||||
|
||||
|
||||
# The Tpu context holds the number of shards when a sharded computation is
|
||||
# being built, or None if no computation is being built.
|
||||
_current_tpu_context = TpuContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tpu_shard_context(number_of_shards):
|
||||
if _current_tpu_context.number_of_shards is not None:
|
||||
raise NotImplementedError("tpu_shard_context cannot be nested.")
|
||||
try:
|
||||
_current_tpu_context.set_number_of_shards(number_of_shards)
|
||||
yield
|
||||
finally:
|
||||
_current_tpu_context.set_number_of_shards(None)
|
||||
|
||||
|
||||
def get_tpu_context():
|
||||
return _current_tpu_context
|
||||
|
||||
|
||||
# Decorator function for tpu computation func that was passed to tpu.rewrite()
|
||||
# if there is an embedded training loop in this func, trace tools will generate
|
||||
# step markers for each iteration.
|
||||
def on_device_training_loop(func):
|
||||
# Value for this attribute is from xla.DebugOptions.StepMarkerLocation.
|
||||
setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")
|
||||
return func
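# Editor's note: a minimal sketch (not part of this file) of how the shard count
# propagates: it is visible via get_tpu_context() while the context is open and
# cleared again afterwards.
from tensorflow.python.tpu import tpu_function

with tpu_function.tpu_shard_context(8):
  assert tpu_function.get_tpu_context().number_of_shards == 8
assert tpu_function.get_tpu_context().number_of_shards is None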
|
@ -19,11 +19,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_feed
|
||||
|
||||
|
||||
class InfeedTest(test.TestCase):
|
203
tensorflow/python/tpu/tpu_optimizer.py
Normal file
@ -0,0 +1,203 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Optimizer that implements cross-shard gradient reduction for TPU."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
|
||||
|
||||
class CrossShardOptimizer(optimizer.Optimizer):
|
||||
"""An optimizer that averages gradients across TPU shards."""
|
||||
|
||||
def __init__(self,
|
||||
opt,
|
||||
reduction=losses.Reduction.MEAN,
|
||||
name="CrossShardOptimizer",
|
||||
group_assignment=None):
|
||||
"""Construct a new cross-shard optimizer.
|
||||
|
||||
Args:
|
||||
opt: An existing `Optimizer` to encapsulate.
|
||||
reduction: The reduction to apply to the shard losses.
|
||||
name: Optional name prefix for the operations created when applying
|
||||
gradients. Defaults to "CrossShardOptimizer".
|
||||
group_assignment: Optional 2d int32 list with shape
[num_groups, num_replicas_per_group] which describes how to apply the
optimizer to subgroups.
|
||||
|
||||
Raises:
|
||||
ValueError: If reduction is not a valid cross-shard reduction.
|
||||
"""
|
||||
if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
|
||||
raise ValueError("Unsupported reduction: %s." % reduction)
|
||||
|
||||
super(CrossShardOptimizer, self).__init__(False, name)
|
||||
self._opt = opt
|
||||
self._reduction = reduction
|
||||
self._group_assignment = group_assignment
|
||||
|
||||
def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
|
||||
"""Verify group_assignment and get the subgroup size".
|
||||
|
||||
Args:
|
||||
group_assignment: list of group ids for applying the optimizer
|
||||
to subgroups.
|
||||
num_shards: The number of TPU shards.
|
||||
|
||||
Returns:
|
||||
The size of one subgroup in group_assignment.
|
||||
|
||||
Raises:
|
||||
ValueError: If group_assignment is invalid.
|
||||
"""
|
||||
if not group_assignment:
|
||||
return None
|
||||
if not (isinstance(group_assignment, list) and
|
||||
all(isinstance(i, list) for i in group_assignment)):
|
||||
raise ValueError("group_assignment must be a list of list. Got {}".format(
|
||||
group_assignment))
|
||||
|
||||
replica_ids = set()
|
||||
for g in group_assignment:
|
||||
for i in g:
|
||||
replica_ids.add(i)
|
||||
|
||||
if set(range(num_shards)) != replica_ids:
|
||||
raise ValueError("group_assignment must be a permutation of range({0})."
|
||||
" Got group_assignment={1}".format(
|
||||
num_shards, group_assignment))
|
||||
|
||||
subgroup_size_list = [len(group) for group in group_assignment]
|
||||
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
|
||||
return subgroup_size_list[0]
|
||||
else:
|
||||
raise ValueError("The size of each subgroup in group_assignment must "
|
||||
"be equal. Got group_assignment={}".format(
|
||||
self._group_assignment))
|
||||
|
||||
def compute_gradients(self, loss, var_list=None, **kwargs):
|
||||
"""Compute gradients of "loss" for the variables in "var_list".
|
||||
|
||||
This simply wraps the compute_gradients() from the real optimizer. The
|
||||
gradients will be aggregated in apply_gradients() so that the user can
modify them if needed, e.g. clip them with a per-replica global norm.
Computing the global norm over aggregated gradients can be problematic,
since one replica's large gradients can drown out the gradients from the
other replicas.
|
||||
|
||||
Args:
|
||||
loss: A Tensor containing the value to minimize.
|
||||
var_list: Optional list or tuple of `tf.Variable` to update to minimize
|
||||
`loss`. Defaults to the list of variables collected in the graph
|
||||
under the key `GraphKey.TRAINABLE_VARIABLES`.
|
||||
**kwargs: Keyword arguments for compute_gradients().
|
||||
|
||||
Returns:
|
||||
A list of (gradient, variable) pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If not within a tpu_shard_context or group_assignment is
|
||||
invalid.
|
||||
"""
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
logging.warning(
|
||||
"CrossShardOptimizer should be used within a tpu_shard_context, but "
|
||||
"got unset number_of_shards. Assuming 1.")
|
||||
num_shards = 1
|
||||
|
||||
subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
|
||||
num_shards)
|
||||
|
||||
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
|
||||
if self._group_assignment:
|
||||
scale = 1.0 / subgroup_size
|
||||
else:
|
||||
scale = 1.0 / num_shards
|
||||
loss *= scale
|
||||
|
||||
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
"""Apply gradients to variables.
|
||||
|
||||
Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
|
||||
replicas, and then applies the real optimizer.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs as returned by
|
||||
compute_gradients().
|
||||
global_step: Optional Variable to increment by one after the
|
||||
variables have been updated.
|
||||
name: Optional name for the returned operation. Default to the
|
||||
name passed to the Optimizer constructor.
|
||||
|
||||
Returns:
|
||||
An `Operation` that applies the gradients. If `global_step` was not None,
|
||||
that operation also increments `global_step`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the grads_and_vars is malformed.
|
||||
"""
|
||||
summed_grads_and_vars = []
|
||||
for (grad, var) in grads_and_vars:
|
||||
if grad is None:
|
||||
summed_grads_and_vars.append((grad, var))
|
||||
else:
|
||||
with ops.colocate_with(grad):
|
||||
summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
|
||||
grad, self._group_assignment), var))
|
||||
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
|
||||
|
||||
def get_slot(self, *args, **kwargs):
|
||||
"""Return a slot named "name" created for "var" by the Optimizer.
|
||||
|
||||
This simply wraps the get_slot() from the actual optimizer.
|
||||
|
||||
Args:
|
||||
*args: Arguments for get_slot().
|
||||
**kwargs: Keyword arguments for get_slot().
|
||||
|
||||
Returns:
|
||||
The `Variable` for the slot if it was created, `None` otherwise.
|
||||
"""
|
||||
return self._opt.get_slot(*args, **kwargs)
|
||||
|
||||
def get_slot_names(self, *args, **kwargs):
|
||||
"""Return a list of the names of slots created by the `Optimizer`.
|
||||
|
||||
This simply wraps the get_slot_names() from the actual optimizer.
|
||||
|
||||
Args:
|
||||
*args: Arguments for get_slot().
|
||||
**kwargs: Keyword arguments for get_slot().
|
||||
|
||||
Returns:
|
||||
A list of strings.
|
||||
"""
|
||||
return self._opt.get_slot_names(*args, **kwargs)
|
||||
|
||||
def variables(self):
|
||||
"""Forwarding the variables from the underlying optimizer."""
|
||||
return self._opt.variables()
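# Editor's note: a hypothetical usage sketch (not part of this file), assuming
# TF 1.x; the loss and model function are placeholders. The wrapper scales the
# loss for MEAN reduction, sums per-replica gradients with cross_replica_sum,
# and then delegates to the wrapped optimizer.
import tensorflow as tf
from tensorflow.python.tpu import tpu_optimizer

base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# Average gradients across all shards. To reduce only within subgroups of
# replicas, pass e.g. group_assignment=[[0, 1, 2, 3], [4, 5, 6, 7]].
opt = tpu_optimizer.CrossShardOptimizer(base_opt)
# Inside a model function built under tpu_function.tpu_shard_context(...):
# train_op = opt.minimize(loss, global_step=tf.train.get_global_step())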
|
253
tensorflow/python/tpu/tpu_sharding.py
Normal file
@ -0,0 +1,253 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper library for sharding during TPU compilation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
_DEFAULT_NUMBER_OF_SHARDS = 1
|
||||
_DEFAULT_SHARD_DIMENSION = 0
|
||||
|
||||
|
||||
# TODO(b/36777903) change other parts of tpu.py to use this class.
|
||||
class ShardingPolicy(object):
|
||||
"""An object use to hold the sharding policy for a Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._number_of_shards = None
|
||||
self._shard_dimension = None
|
||||
self._frozen = False
|
||||
|
||||
def __str__(self):
|
||||
if self.number_of_shards is None or self.shard_dimension is None:
|
||||
return "ShardingPolicy(unset)"
|
||||
else:
|
||||
return ("ShardingPolicy(%d shards dimension %d)" %
|
||||
(self.number_of_shards, self.shard_dimension))
|
||||
|
||||
def _fill_default_values(self):
|
||||
if self._number_of_shards is None:
|
||||
self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
|
||||
if self._shard_dimension is None:
|
||||
self._shard_dimension = tensor_shape.as_dimension(
|
||||
_DEFAULT_SHARD_DIMENSION)
|
||||
|
||||
def freeze(self):
|
||||
"""Prevents further modification to the sharding policy.
|
||||
|
||||
Any values that have not been set when freeze is called are set to
|
||||
defaults. If the ShardingPolicy is already frozen, this is a NoOp.
|
||||
"""
|
||||
if not self._frozen:
|
||||
self._fill_default_values()
|
||||
self._frozen = True
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
"""Returns the number of shards in the policy or None if unspecified."""
|
||||
return self._number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
"""Sets the number of shards for the current policy.
|
||||
|
||||
If the policy has been frozen then number_of_shards must match the
|
||||
existing setting.
|
||||
|
||||
Args:
|
||||
number_of_shards: The number of shards to use in the policy.
|
||||
|
||||
Raises:
|
||||
ValueError: If the policy has been frozen and number_of_shards
|
||||
differs from the frozen value; or number_of_shards <= 0.
|
||||
"""
|
||||
if self._frozen:
|
||||
if self._number_of_shards != number_of_shards:
|
||||
raise ValueError(
|
||||
"Can't set sharding policy to use %d shards since it has been "
|
||||
"frozen to use %d." % (number_of_shards, self._number_of_shards))
|
||||
else:
|
||||
if number_of_shards > 0:
|
||||
self._number_of_shards = number_of_shards
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can't set sharding policy to use %s shards; value must be >0",
|
||||
str(number_of_shards))
|
||||
|
||||
@property
|
||||
def shard_dimension(self):
|
||||
"""Returns the shard dimension of the policy or None if unspecified."""
|
||||
return self._shard_dimension
|
||||
|
||||
def set_shard_dimension(self, shard_dimension):
|
||||
"""Sets the shard dimension for the current policy.
|
||||
|
||||
If the policy has been frozen then shard_dimension must match the
|
||||
existing setting.
|
||||
|
||||
Args:
|
||||
shard_dimension: The shard dimension to use in the policy.
|
||||
|
||||
Raises:
|
||||
ValueError: If the policy has been frozen and shard_dimension
|
||||
differs from the frozen value, or shard_dimension can't be
|
||||
interpreted as a Dimension.
|
||||
"""
|
||||
if self._frozen:
|
||||
if self._shard_dimension != shard_dimension:
|
||||
raise ValueError(
|
||||
"Can't set shard dimension to %d since it has been frozen to "
|
||||
"use %d." % (shard_dimension, self._shard_dimension))
|
||||
else:
|
||||
self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
|
||||
|
||||
def merge(self, other):
|
||||
"""Merges the policy of another policy into the current policy.
|
||||
|
||||
Args:
|
||||
other: The policy to merge into this one.
|
||||
|
||||
Raises:
|
||||
ValueError: If this policy has been frozen and the merge conflicts with
|
||||
the frozen policy.
|
||||
"""
|
||||
if other.number_of_shards is not None:
|
||||
self.set_number_of_shards(other.number_of_shards)
|
||||
if other.shard_dimension is not None:
|
||||
self.set_shard_dimension(other.shard_dimension)
|
||||
|
||||
def get_sharded_shape(self, shape, shard_index=None):
|
||||
"""Returns the shape of a shard of a full Tensor.
|
||||
|
||||
When given the shape of a 'full-size' Tensor, returns the shape of
|
||||
the sub-Tensor after it has been sharded. Freezes the policy if it
|
||||
has not yet been frozen.
|
||||
|
||||
Args:
|
||||
shape: The shape of the full-size Tensor to be sharded.
|
||||
shard_index: The index of the shard whose shape should be returned.
|
||||
shard_index can be None for sharding policies that use the same
|
||||
shape for every shard.
|
||||
|
||||
Returns:
|
||||
The shape of the sharded version of the Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If shard_index is None when shards are of different
|
||||
shapes; or shard_index is not None and
|
||||
!(0<=shard_index<number_of_shards); or shape does not have at
|
||||
least self.shard_dimension+1 dimensions; or the value of
|
||||
shape's shard dimension is not a multiple of
|
||||
self.number_of_shards
|
||||
"""
|
||||
if self._shard_dimension is None or self._number_of_shards is None:
|
||||
# Don't raise an error if the config is unset.
|
||||
return None
|
||||
if shard_index is not None:
|
||||
if shard_index < 0 or shard_index >= self.number_of_shards:
|
||||
raise ValueError("shard_index %d, but must be in [0,%d)." %
|
||||
(shard_index, self._number_of_shards))
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if self._number_of_shards == 1:
|
||||
# Don't do anything when there's only one shard.
|
||||
return shape
|
||||
ndims = shape.ndims
|
||||
if ndims is None:
|
||||
raise ValueError("shape must be a specified shape not Unknown")
|
||||
if ndims <= self._shard_dimension:
|
||||
raise ValueError("shape %s does not contain shard_dimension %d" %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
dims = shape.as_list()
|
||||
if dims[self._shard_dimension] is None:
|
||||
raise ValueError("shape %s must have a fixed size for dimension %d "
|
||||
"that is known at graph construction time." %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
if (dims[self._shard_dimension] % self._number_of_shards) != 0:
|
||||
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
|
||||
(shape.as_list(), self._number_of_shards,
|
||||
self._shard_dimension))
|
||||
dims[self._shard_dimension] /= self._number_of_shards
|
||||
return tensor_shape.as_shape(dims)
|
||||
|
||||
def _unshard_shape(self, shape):
|
||||
"""Return the unsharded shape that would generate a given sharded shape.
|
||||
|
||||
Args:
|
||||
shape: the sharded shape to unshard
|
||||
|
||||
Returns:
|
||||
The unsharded shape.
|
||||
|
||||
Raises:
|
||||
ValueError: if shape is unknown or does not contain
|
||||
self.shard_dimension
|
||||
TypeError: if shape is not convertible to a TensorShape
|
||||
"""
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if self._number_of_shards == 1:
|
||||
# Don't do anything when there's only one shard.
|
||||
return shape
|
||||
ndims = shape.ndims
|
||||
if ndims is None:
|
||||
raise ValueError("shape must be a specified shape not Unknown")
|
||||
if ndims <= self._shard_dimension:
|
||||
raise ValueError("shape %s does not contain shard_dimension %d" %
|
||||
(shape.as_list(), self._shard_dimension))
|
||||
dims = shape.as_list()
|
||||
dims[self._shard_dimension] *= self._number_of_shards
|
||||
return tensor_shape.as_shape(dims)
|
||||
|
||||
def get_unsharded_shape(self, shapes):
|
||||
"""Returns the shape of an unsharded Tensor given a list of shards.
|
||||
|
||||
When given a list of shapes of shards, returns the shape of the
|
||||
unsharded Tensor that would generate the shards. Sets defaults for the
|
||||
policy if number_of_shards or shard_dimension is None.
|
||||
|
||||
Args:
|
||||
shapes: The shapes of the Tensor shards to be combined.
|
||||
|
||||
Returns:
|
||||
The shape of the unsharded version of the Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: if shapes is not a list of length
|
||||
self.number_of_shards; or any element of shapes is not a valid
|
||||
shape consistent with the sharding policy; or the list of
|
||||
shapes is not a valid sharding of a full shape.
|
||||
TypeError: if an element of shapes is not convertible to a
|
||||
TensorShape
|
||||
"""
|
||||
self._fill_default_values()
|
||||
if len(shapes) != self.number_of_shards:
|
||||
raise ValueError(
|
||||
"shapes is %s but must be a list of length number_of_shards=%d" % (
|
||||
str(shapes), self.number_of_shards))
|
||||
unsharded_shapes = [self._unshard_shape(s) for s in shapes]
|
||||
for i in xrange(self.number_of_shards - 1):
|
||||
if not unsharded_shapes[i].is_compatible_with(
|
||||
unsharded_shapes[self.number_of_shards - 1]):
|
||||
raise ValueError(
|
||||
"sharded shapes %s are not consistent shards of a full shape "
|
||||
"sharded %d ways along dimension %d" % (
|
||||
str(shapes), self.number_of_shards, self.shard_dimension))
|
||||
return unsharded_shapes[0]
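
To make the shape arithmetic above concrete, here is a minimal sketch (not part of this change) of how a ShardingPolicy splits and reassembles a shape; it assumes the post-move import path tensorflow.python.tpu.tpu_sharding:

from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(4)
policy.set_shard_dimension(0)

# A [32, 10] tensor sharded 4 ways along dimension 0 yields [8, 10] shards.
per_shard = policy.get_sharded_shape([32, 10])
assert per_shard.as_list() == [8, 10]

# get_unsharded_shape reverses the split given one shape per shard.
full_shape = policy.get_unsharded_shape([per_shard] * 4)
assert full_shape.as_list() == [32, 10]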
@ -19,10 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.tpu import tpu_sharding

from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding


class ShardingTest(test.TestCase):
156
tensorflow/python/tpu/tpu_system_metadata.py
Normal file
@ -0,0 +1,156 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu

_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000  # 1 min
_RETRY_TIMES = 120
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins

_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')

# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
    'num_cores',
    'num_hosts',
    'num_of_cores_per_host',
    'topology',
    'devices',
])


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  # TODO(b/120564445): Replace with standard library for retries.
  retry_count = 1
  while True:
    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                 master_address)
    try:
      with ops.Graph().as_default():
        with session_lib.Session(
            master_address,
            config=get_session_config_with_timeout(
                _PINGING_MASTER_TIMEOUT_IN_MS,
                cluster_def)) as sess:
          devices = sess.list_devices()
          for device in devices:
            match = _TPU_DEVICE_REG.match(device.name)
            if match:
              host_id = match.group(1)
              core_id = match.group(2)
              device_dict[host_id].append(core_id)
              tpu_core_count += 1
          break
    except errors.DeadlineExceededError:
      msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
             'not be ready (still scheduling) or the Tensorflow master '
             'address is incorrect: got (%s).' %
             (master_address))

      # TODO(xiejw): For local or grpc master we might not need retry logic
      # here.
      if retry_count <= _RETRY_TIMES:
        logging.warning('%s', msg)
        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
        retry_count += 1
      else:
        raise ValueError(msg)

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata


def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize TPU system with master (%s). '
        'Please double check the TPU system is functional.' % (
            master_address))


def get_session_config_with_timeout(timeout_in_secs, cluster_def):
  """Returns a session config proto with the given timeout and cluster_def.

  Note that callers pass the timeout in milliseconds; the value is forwarded
  directly to `operation_timeout_in_ms`.
  """
  config = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
  return config
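
For orientation, a hedged sketch of how this module is typically driven (not part of the diff; the worker address below is a placeholder and _query_tpu_system_metadata is an internal helper):

from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib

metadata = tpu_system_metadata_lib._query_tpu_system_metadata(
    'grpc://10.0.0.2:8470',  # placeholder TPU worker address
    cluster_def=None,
    query_topology=False)
print('cores=%d hosts=%d cores/host=%d' %
      (metadata.num_cores, metadata.num_hosts, metadata.num_of_cores_per_host))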
@ -19,18 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop

from tensorflow.python.framework import dtypes
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops

from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop


class TPUContextTest(test.TestCase):
222
tensorflow/python/tpu/training_loop.py
Normal file
@ -0,0 +1,222 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import xla


def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed values from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as additional inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs]),
              infeed_queue.number_of_tuple_elements, body_arg_error))
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors corresponds to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as additional inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if `body` has the wrong signature.
  """
  def _convert_to_list(xs):
    if not isinstance(xs, (list, tuple)):
      return [xs]
    else:
      return list(xs)

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    return [i + 1] + _convert_to_list(body(*args))

  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    return outputs[1:]
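
As a usage illustration (not part of the diff), the sketch below runs a fixed-count loop with repeat inside tpu.rewrite, which supplies the shard context that while_loop expects; names other than the public APIs are arbitrary:

from tensorflow.python.framework import constant_op
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import training_loop

def computation():
  def loop_body(total):
    # Single loop-carried tensor; the body returns its updated value.
    return total + 1.0
  return training_loop.repeat(
      10, loop_body, inputs=[constant_op.constant(0.0)])

# tpu.rewrite wraps the computation so the loop builds inside a TPU context.
outputs = tpu.rewrite(computation)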
51
tensorflow/python/tpu/util.py
Normal file
@ -0,0 +1,51 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Utility functions used by TPU modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import six

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training

def check_positive_integer(value, name):
  """Checks whether `value` is a positive integer."""
  if not isinstance(value, six.integer_types):
    raise TypeError('{} must be int, got {}'.format(name, type(value)))

  if value <= 0:
    raise ValueError('{} must be positive, got {}'.format(name, value))


# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
# python/estimator/util.py.
class MultiHostDatasetInitializerHook(training.SessionRunHook):
  """Creates a SessionRunHook that initializes all passed iterators."""

  def __init__(self, dataset_initializers):
    self._initializers = dataset_initializers

  def after_create_session(self, session, coord):
    del coord
    start = time.time()
    session.run(self._initializers)
    logging.info('Initialized dataset iterators in %d seconds',
                 time.time() - start)
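
A brief, hedged example of the helpers above (not part of the diff); initializer_ops stands in for iterator initializer ops collected elsewhere:

from tensorflow.python.tpu import util as util_lib

util_lib.check_positive_integer(8, 'num_shards')  # passes silently

initializer_ops = []  # placeholder list of iterator.initializer ops
hooks = [util_lib.MultiHostDatasetInitializerHook(initializer_ops)]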
106
tensorflow/python/tpu/xla.py
Normal file
@ -0,0 +1,106 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""XLA utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.util import tf_inspect


def is_flat(outputs):
  """Checks if outputs is a flat structure.

  Following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences and
  dictionaries. E.g., this means that if outputs contains a single
  user-defined Object, it is considered to be flat. Errors are raised later on
  if that Object cannot be converted to a Tensor.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
  """
  # If outputs is a list or tuple, check if it has any nested structure. If
  # there is, then outputs is non-flat.
  if isinstance(outputs, collections.Sequence):
    for o in outputs:
      if isinstance(o, collections.Sequence) or isinstance(o, dict):
        return False

  # If outputs is a dict, it is non-flat.
  if isinstance(outputs, dict):
    return False

  # Getting here means either outputs itself is a single non-structured value
  # or it is a flat list of single non-structured values.
  return True


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validate the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if the function can be called with the supplied number of
    arguments, or an error string if it cannot.
  """
  def format_error(complaint, quantity):
    return '%s %d argument%s' % (complaint, quantity, ''
                                 if quantity == 1 else 's')

  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements
  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  if arg_spec.defaults is None:
    num_func_defaults = 0
  else:
    num_func_defaults = len(arg_spec.defaults)
  min_func_args = num_func_args - num_func_defaults
  if num_args_supplied < min_func_args:
    # The number of supplied arguments is not enough to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at least', min_func_args)
  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # The number of supplied arguments is too many to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at most', num_func_args)
  # Reaching here means either
  # 1) There are varargs, func can accept any number of arguments greater than
  # the minimum.
  # 2) Number of supplied arguments falls in range of acceptable argument count
  # of func.
  return None
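
The argument-count check is easiest to read from an example; this sketch (not part of the diff) exercises check_function_argument_count directly:

from tensorflow.python.tpu import xla

def loop_body(a, b, c=0):
  return a + b + c

# Two required arguments plus one default: 2 or 3 supplied arguments pass.
assert xla.check_function_argument_count(loop_body, 2, None) is None
assert xla.check_function_argument_count(loop_body, 3, None) is None
# One argument is too few; an error string like 'at least 2 arguments' is
# returned instead of None.
assert xla.check_function_argument_count(loop_body, 1, None) is not None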