Introduce Tensor Processing Unit-specific operations and estimators. This commit enables TensorFlow users to use functionality available on Cloud TPUs. Please note that these APIs are in alpha on Cloud TPUs and will likely undergo significant changes during the Cloud TPU alpha and beta process.

Interested individuals and organizations should sign up for a Cloud TPU beta at https://ai.google/tools/cloud-tpus/.
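
For orientation, a minimal usage sketch of the new API (not part of this change; the toy model, shapes, and feed values are hypothetical, and the rewrite only executes meaningfully on a TPU system):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def model(features):
  # Any graph built from TPU-supported ops; here a toy squared-sum "loss".
  return tf.reduce_sum(features * features)

features = tf.placeholder(tf.float32, shape=[8, 128])
loss = tpu.rewrite(model, [features])[0]

with tf.Session() as sess:
  sess.run(tpu.initialize_system())   # configure the distributed TPU system
  sess.run(loss, {features: np.ones([8, 128], np.float32)})
  sess.run(tpu.shutdown_system())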

PiperOrigin-RevId: 159636962
Frank Chen 2017-06-20 17:11:19 -07:00 committed by TensorFlower Gardener
parent a936b239cf
commit a4660cce81
28 changed files with 3601 additions and 0 deletions

View File

@@ -304,6 +304,7 @@ filegroup(
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/tpu:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",

View File

@@ -71,6 +71,9 @@ py_library(
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],
@@ -106,6 +109,7 @@ cc_library(
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
],
)

View File

@@ -65,6 +65,7 @@ from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
from tensorflow.contrib import tfprof
from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.ndlstm import python as ndlstm

View File

@@ -84,6 +84,12 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/cross_replica_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/infeed_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/outfeed_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/replication_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc"
)
list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
endif(tensorflow_BUILD_CONTRIB_KERNELS)

View File

@@ -67,6 +67,10 @@ file(GLOB_RECURSE tensor_forest_hybrid_srcs
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/*.cc"
)
file(GLOB_RECURSE tpu_ops_srcs
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/*.cc"
)
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
@@ -83,6 +87,7 @@ GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensor
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")
########################################################

View File

@@ -515,6 +515,11 @@ add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not impl
#add_python_module("tensorflow/contrib/tfprof/python")
#add_python_module("tensorflow/contrib/tfprof/python/tools")
#add_python_module("tensorflow/contrib/tfprof/python/tools/tfprof")
add_python_module("tensorflow/contrib/tpu")
add_python_module("tensorflow/contrib/tpu/ops")
add_python_module("tensorflow/contrib/tpu/python")
add_python_module("tensorflow/contrib/tpu/python/ops")
add_python_module("tensorflow/contrib/tpu/python/tpu")
add_python_module("tensorflow/contrib/training")
add_python_module("tensorflow/contrib/training/python")
add_python_module("tensorflow/contrib/training/python/training")

View File

@@ -0,0 +1,224 @@
# Description: Operations defined for Cloud TPUs
package(
default_visibility = [
"//learning/brain:__subpackages__",
"//tensorflow:__subpackages__",
],
)
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_test")
cc_library(
name = "all_ops",
deps = [
":cross_replica_ops_op_lib",
":infeed_ops_op_lib",
":outfeed_ops_op_lib",
":replication_ops_op_lib",
":tpu_configuration_ops_op_lib",
":tpu_sendrecv_ops_op_lib",
],
)
py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/tpu_config.py",
"python/tpu/tpu_estimator.py",
],
srcs_version = "PY2AND3",
deps = [
":tpu",
":tpu_py",
":training_loop",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
],
)
tf_gen_op_libs(
op_lib_names = [
"cross_replica_ops",
"infeed_ops",
"outfeed_ops",
"replication_ops",
"tpu_configuration_ops",
"tpu_sendrecv_ops",
],
deps = [
"//tensorflow/core:lib",
],
)
tf_custom_op_library(
name = "python/ops/_tpu_ops.so",
srcs = [
"ops/cross_replica_ops.cc",
"ops/infeed_ops.cc",
"ops/outfeed_ops.cc",
"ops/replication_ops.cc",
"ops/tpu_configuration_ops.cc",
"ops/tpu_sendrecv_ops.cc",
],
)
tf_gen_op_wrapper_py(
name = "tpu_ops",
deps = [
":cross_replica_ops_op_lib",
":infeed_ops_op_lib",
":outfeed_ops_op_lib",
":replication_ops_op_lib",
":tpu_configuration_ops_op_lib",
":tpu_sendrecv_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "tpu_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
dso = [":python/ops/_tpu_ops.so"],
kernels = [
":all_ops",
],
srcs_version = "PY2AND3",
deps = [
":tpu_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
py_library(
name = "tpu_helper_library",
srcs_version = "PY2AND3",
deps = [
":tpu",
":tpu_feed",
":tpu_function",
":tpu_py",
":tpu_sharding",
":training_loop",
],
)
py_library(
name = "tpu_function",
srcs = ["python/tpu/tpu_function.py"],
srcs_version = "PY2AND3",
deps = [
":tpu_feed",
":tpu_py",
"//tensorflow/python:framework",
],
)
py_library(
name = "tpu",
srcs = [
"python/tpu/__init__.py",
"python/tpu/tpu.py",
],
srcs_version = "PY2AND3",
deps = [
":tpu_py",
":training_loop",
"//tensorflow/python:framework",
],
)
py_library(
name = "tpu_sharding",
srcs = ["python/tpu/tpu_sharding.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework",
],
)
py_library(
name = "tpu_feed",
srcs = ["python/tpu/tpu_feed.py"],
srcs_version = "PY2AND3",
deps = [
":tpu_py",
":tpu_sharding",
"//tensorflow/python:framework",
],
)
py_library(
name = "training_loop",
srcs = [
"python/tpu/tpu_optimizer.py",
"python/tpu/training_loop.py",
],
srcs_version = "PY2AND3",
deps = [
":tpu_function",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "tpu_sharding_test",
size = "small",
srcs = ["python/tpu/tpu_sharding_test.py"],
additional_deps = [
":tpu_sharding",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
)
tf_py_test(
name = "tpu_infeed_test",
size = "small",
srcs = ["python/tpu/tpu_infeed_test.py"],
additional_deps = [
":tpu_feed",
":tpu_sharding",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_function_test",
size = "small",
srcs = ["python/tpu/tpu_function_test.py"],
additional_deps = [
":tpu_function",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)

View File

@@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Ops related to Tensor Processing Units."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
.Output("output: T")
.Attr("T: {float}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
instance supplies its own input, and the output of each is the sum of
all the inputs.
input: The local input to the sum.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
)doc");
} // namespace tensorflow
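
A hedged usage sketch (not part of this change): summing a per-replica value across two replicas via tpu.replicate(), which is defined later in this change. The wrapper name `cross_replica_sum` is an assumption, following the snake_case naming that tf_gen_op_wrapper_py produces for CrossReplicaSum.

import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(x):
  # Each replica contributes its own x; every replica receives the total sum.
  return tpu_ops.cross_replica_sum(x)  # assumed generated wrapper name

inputs = [[tf.constant([1.0, 2.0])],   # replica 0
          [tf.constant([3.0, 4.0])]]   # replica 1
outputs = tpu.replicate(computation, inputs)  # [[sum], [sum]]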

View File

@@ -0,0 +1,107 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("InfeedDequeue")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
TensorShapeProto shape_proto;
shape.AsProto(&shape_proto);
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
c->set_output(0, out);
return Status::OK();
})
.Doc(R"doc(
A placeholder op for a value that will be fed into the computation.
output: A tensor that will be provided using the infeed mechanism.
dtype: The type of elements in the tensor.
shape: The shape of the tensor.
)doc");
REGISTER_OP("InfeedEnqueue")
.Input("input: dtype")
.Attr("dtype: type")
.Attr("shape: shape = {}")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.Doc(R"doc(
An op which feeds a single Tensor value into the computation.
input: A tensor that will be provided using the infeed mechanism.
dtype: The type of elements in the tensor.
shape: The shape of the tensor.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");
REGISTER_OP("InfeedEnqueueTuple")
.Input("inputs: dtypes")
.Attr("dtypes: list(type)")
.Attr("shapes: list(shape)")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.Doc(R"doc(
An op which feeds multiple Tensor values into the computation as an XLA tuple.
inputs: A list of tensors that will be provided using the infeed mechanism.
dtypes: The element types of each element in `inputs`.
shapes: The shapes of each tensor in `inputs`.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");
REGISTER_OP("InfeedDequeueTuple")
.Output("outputs: dtypes")
.Attr("dtypes: list(type)")
.Attr("shapes: list(shape)")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
std::vector<PartialTensorShape> shapes;
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
for (int i = 0; i < shapes.size(); ++i) {
TensorShapeProto shape_proto;
shapes[i].AsProto(&shape_proto);
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
c->set_output(i, out);
}
return Status::OK();
})
.Doc(R"doc(
A placeholder op for multiple values that will be fed into the computation
simultaneously as an XLA tuple.
outputs: A list of tensors that will be provided using the infeed mechanism.
dtypes: The element types of each element in `outputs`.
shapes: The shapes of each tensor in `outputs`.
)doc");
} // namespace tensorflow
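
A hedged sketch (not part of this change) pairing InfeedEnqueue on the host with InfeedDequeue inside the TPU computation. The snake_case wrapper names are assumptions based on the generated gen_tpu_ops module.

import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu

# Host side: push one batch into the infeed queue of device 0. A session (or an
# input thread such as the one in tpu_estimator.py) would run `enqueue` once
# per TPU iteration.
batch = tf.random_uniform([8, 128])
enqueue = tpu_ops.infeed_enqueue(batch, shape=[8, 128], device_ordinal=0)

# TPU side: the computation pulls the batch back out of the infeed.
def computation():
  x = tpu_ops.infeed_dequeue(dtype=tf.float32, shape=[8, 128])
  return tf.reduce_mean(x)

mean = tpu.rewrite(computation)[0]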

View File

@@ -0,0 +1,106 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("OutfeedEnqueue")
.Input("input: dtype")
.Attr("dtype: type")
.SetIsStateful()
.Doc(R"doc(
An op which emits a single Tensor value from an XLA computation.
input: A tensor that will be inserted into the outfeed queue.
)doc");
REGISTER_OP("OutfeedEnqueueTuple")
.Input("inputs: dtypes")
.Attr("dtypes: list(type)")
.SetIsStateful()
.Doc(R"doc(
An op which emits multiple Tensor values from an XLA computation.
inputs: A list of tensors that will be inserted into the outfeed queue as an
XLA tuple.
)doc");
REGISTER_OP("OutfeedDequeue")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("shape: shape")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
c->set_output(0, out);
return Status::OK();
})
.Doc(R"doc(
Retrieves a single tensor from the computation outfeed. This operation will
block indefinitely until data is available.
output: A tensor that will be read from the device outfeed.
dtype: The type of elements in the tensor.
shape: The shape of the tensor.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");
REGISTER_OP("OutfeedDequeueTuple")
.Output("outputs: dtypes")
.Attr("dtypes: list(type)")
.Attr("shapes: list(shape)")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
std::vector<PartialTensorShape> shapes;
std::vector<DataType> dtypes;
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
if (shapes.size() != dtypes.size()) {
return errors::InvalidArgument(
"Incorrect number of output shapes specified");
}
for (int i = 0; i < shapes.size(); ++i) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
c->set_output(i, out);
}
return Status::OK();
})
.Doc(R"doc(
Retrieve multiple values that will be emitted by the computation as an XLA
tuple. This operation will block indefinitely until data is available.
Output `i` corresponds to XLA tuple element `i`.
outputs: A list of tensors that will be read from the outfeed.
dtypes: The element types of each element in `outputs`.
shapes: The shapes of each tensor in `outputs`.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");
} // namespace tensorflow
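
A hedged sketch (not part of this change): emitting a value from a TPU computation through the outfeed and reading it back on the host. The snake_case wrapper names are assumptions based on the generated gen_tpu_ops module.

import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(x):
  y = x * 2.0
  # Enqueue the result onto the outfeed before returning it.
  with tf.control_dependencies([tpu_ops.outfeed_enqueue(y)]):
    return tf.identity(y)

x = tf.constant([1.0, 2.0, 3.0])
result = tpu.rewrite(computation, [x])[0]

# Host side: blocks until the computation has enqueued a value.
host_copy = tpu_ops.outfeed_dequeue(
    dtype=tf.float32, shape=[3], device_ordinal=0)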

View File

@@ -0,0 +1,87 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("TPUReplicatedInput")
.Input("inputs: N * T")
.Output("output: T")
.Attr("N: int >= 1")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
"From merging shape ", i,
" with other shapes.");
}
c->set_output(0, cur);
return Status::OK();
})
.Doc(
"Operator that connects N unreplicated inputs to an N-way "
"replicated TPU computation.");
REGISTER_OP("TPUReplicatedOutput")
.Input("input: T")
.Output("outputs: num_replicas * T")
.Attr("num_replicas: int >= 1")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->input(0));
}
return Status::OK();
})
.Doc(
"Operator that connects the output of an N-way replicated TPU "
"computation to N separate outputs.");
REGISTER_OP("TPUReplicate")
.Attr("computation: func")
.Attr("num_replicas: int >= 1")
.Attr("global_tpu_id: list(int) = []")
.Attr("Tinputs: list(type) >= 0")
.Attr("Tbroadcast_inputs: list(type) >= 0")
.Attr("NumVariables: int >= 0")
.Attr("output_types: list(type) >= 0")
.Input("inputs: Tinputs")
.Input("broadcast_inputs: Tbroadcast_inputs")
.Input("variables: NumVariables * resource")
.Output("outputs: output_types")
.Doc(R"doc(
Runs replicated computations on a distributed TPU system.
computation: a function containing the computation to run.
num_replicas: the number of replicas of the computation to run.
global_tpu_id: map from device to global tpu id.
Tinputs: the types of the arguments to 'computation'.
inputs: the inputs to 'computation', flattened, in replica-major order.
Tbroadcast_inputs: the types of the additional arguments to broadcast to all
replicas.
broadcast_inputs: additional arguments to broadcast to all replicas. The
broadcast inputs are appended to the per-replica inputs when calling
computation.
output_types: the types of the outputs of 'computation'.
outputs: the outputs of 'computation'.
)doc");
} // namespace tensorflow
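
A hedged sketch of the fan-in/fan-out pattern these ops implement; tpu.replicate() later in this change builds the same structure automatically, and the wrapper names below match their use in tpu.py.

import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops

a = tf.constant(1.0)   # input for replica 0
b = tf.constant(2.0)   # input for replica 1
# Fan-in: N unreplicated inputs become one input of the replicated computation.
replicated_in = tpu_ops.tpu_replicated_input([a, b])
# ... the replicated computation would consume replicated_in here ...
# Fan-out: one replicated output becomes N per-replica outputs.
out0, out1 = tpu_ops.tpu_replicated_output(replicated_in, num_replicas=2)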

View File

@@ -0,0 +1,213 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
// Configuring a distributed TPU system is achieved by running
// the following Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
// host. This is needed in case the system had previously been
// configured. It returns, for each host, the number of TPU chips on
// the host.
//
// 2 Run _ConfigureDistributedTPU on TPU_SYSTEM. Takes as input the
// number of chips on each host. Validates that all hosts have the
// same number of chips, and that the chips are consistent with the
// topology set by flags. Has a single output which is a proto
// describing the requested system configuration, which is sent to all
// hosts.
//
// 3 Run _InitializeHostForDistributedTPU on the CPU of each host,
// taking as input the output from ConfigureDistributedTPU. Has a
// single Tensor output which is a vector of int32 indicating, for
// each TPU on the host, what its global TPU system id is.
//
// 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the
// outputs from all the _InitializeHostForDistributedTPU
// Ops. _WaitForDistributedTPU has an attr host_specs which is a
// vector<string> giving the partial device spec for each host. These
// partial specs are combined in the Op with the outputs from the host
// initialization Ops to construct a mapping from full TPU device
// specs to global TPU ids. Has a single Tensor output which is a
// matrix of int32 indicating, for each host (outer dimension) and for
// each TPU on the host (inner dimension) what that TPU's global id
// is. _WaitForDistributedTPU also waits for the TPU distributed
// system to initialize fully, which may take several minutes for a
// large system.
//
// 5 Run _SetGlobalTPUArray on the CPU of each host, taking as input
// the output from _WaitForDistributedTPU. This Op tells each host the
// global Id of every TPU on every host.
//
// Most user code works by placing the ConfigureDistributedTPU Op on
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
// the subgraph described above.
//
//
// A distributed TPU system can be cleanly shut down by running
// the following Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
// host.
//
// 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where
// _ConfigureDistributedTPU was run. The Op will return an error if no
// system is configured.
//
//
// Most user code works by placing the ShutdownDistributedTPU Op on
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
// the subgraph described above.
REGISTER_OP("_ConfigureDistributedTPU")
.Input("inputs: N * int32")
.Output("output: string")
.Attr("N: int >= 1")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
// Validate that all the inputs are scalars.
for (int i = 0; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input));
}
c->set_output(0, c->Scalar());
return Status::OK();
})
.Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.
inputs: A scalar tensor for each host indicating how many TPU chips
there are on the host.
output: A tensor containing a TPUHostConfiguration proto serialized to
a string, containing the information necessary to initialize the chips
in a host.
)doc");
REGISTER_OP("_WaitForDistributedTPU")
.Input("inputs: N * int32")
.Output("global_tpu_array: int32")
.Attr("host_specs: list(string)")
.Attr("startup_timeout_sec: int = 20")
.Attr("N: int")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
// Validate that all the inputs have the same vector shape.
for (int i = 0; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input));
}
c->set_output(0, c->UnknownShapeOfRank(2));
return ::tensorflow::Status::OK();
})
.Doc(R"doc(
An op that blocks execution until a distributed TPU system has
started up. This Op must be run on the same TPU_SYSTEM device as
_ConfigureDistributedTPU, and takes as inputs the outputs from the
_InitializeHostForDistributedTPU Ops.
inputs: For each initialized host, a vector giving the global TPU id
of each TPU on the host.
global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
host_specs: For each initialized host, the partial device specification
indicating job, replica, and task. Combining this spec with
'/device:TPU:k' gives the full device name of the k'th TPU on the
host.
startup_timeout_sec: The number of seconds to wait for the TPU system
to stabilize.
)doc");
REGISTER_OP("_SetGlobalTPUArray")
.Input("global_tpu_array: int32")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
return ::tensorflow::Status::OK();
})
.Doc(R"doc(
An op that informs a host of the global ids of all of the TPUs in the
system.
global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
)doc");
REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
to start the system, and must be run only after
_DisconnectHostFromDistributedTPUSystem has completed on every host in
the system.
)doc");
REGISTER_OP("_InitializeHostForDistributedTPU")
.Input("input: string")
.Output("tpu_ids: int32")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input));
c->set_output(0, c->Vector(c->UnknownDim()));
return ::tensorflow::Status::OK();
})
.Doc(R"doc(
An op that connects each chip on the host to a centralized UberDriver to allow
them to operate as a distributed system with chips in other hosts.
input: A string containing the address of the UberDriver to connect to.
tpu_ids: A vector containing the global TPU id of each TPU on the host.
)doc");
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
.Output("number_of_tpu_chips: int32")
.SetIsStateful()
.Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.
number_of_tpu_chips: A scalar tensor containing the number of TPU
chips on the host.
)doc");
REGISTER_OP("ConfigureDistributedTPU")
.Output("global_tpu_array: int32")
.Attr("embedding_config: string = ''")
.SetIsStateful()
.Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.
global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
embedding_config: Internal use.
)doc");
REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running.
)doc");
} // namespace tensorflow
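
A hedged sketch of how the public configuration ops are placed and run; this mirrors initialize_system()/shutdown_system() in tpu.py later in this change, and the wrapper names match their use there.

import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops

with tf.device("/replica:0/task:0/device:TPU_SYSTEM:0"):
  init = tpu_ops.configure_distributed_tpu(embedding_config="")
  shutdown = tpu_ops.shutdown_distributed_tpu()

with tf.Session() as sess:
  global_tpu_array = sess.run(init)  # global ids of the TPUs on each host
  # ... run TPU computations here ...
  sess.run(shutdown)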

View File

@@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("_TPUSend")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
.Doc(R"doc(
Sends the named tensor over the TPU fabric.
tensor: The tensor to send.
tensor_name: The name of the tensor to send.
)doc");
REGISTER_OP("_TPURecv")
.Output("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
.Doc(R"doc(
Receives the named tensor over the TPU fabric.
tensor: The tensor to receive.
tensor_name: The name of the tensor to receive.
shape: The shape of the tensor to receive.
)doc");
} // namespace tensorflow

View File

@@ -0,0 +1,38 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.contrib.tpu.ops.gen_tpu_ops import *
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
_tpu_ops = loader.load_op_library(
resource_loader.get_path_to_datafile("_tpu_ops.so"))
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this.
pass

View File

@@ -0,0 +1,20 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Ops related to Tensor Processing Units."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -0,0 +1,583 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
def initialize_system(embedding_config=None, job=None):
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
embedding_config: If not None, an EmbeddingLayerConfiguration proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX)
that contains the TPU devices that will be initialized. If job=None
it is assumed there is only one job in the TensorFlow flock, and an
error will be returned if this assumption does not hold.
Returns:
Op which, when executed, will initialize the system.
"""
if job is None:
device_name = "/replica:0/task:0/device:TPU_SYSTEM:0"
else:
device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job
config_string = ("" if embedding_config is None else
embedding_config.SerializeToString())
with ops.device(device_name):
init_distributed_tpu = tpu_ops.configure_distributed_tpu(
embedding_config=config_string)
return init_distributed_tpu
def shutdown_system(job=None):
"""Shuts down a running a distributed TPU system."""
if job is None:
device_name = "/replica:0/task:0/device:TPU_SYSTEM:0"
else:
device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job
with ops.device(device_name):
shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
return shutdown_distributed_tpu
def core(num):
"""Returns the device name for a core in a replicated TPU computation.
Args:
num: the virtual core number within each replica to which operators should
be assigned.
Returns:
A device name, suitable for passing to tf.device().
"""
return "device:TPU_REPLICATED_CORE:{}".format(num)
# Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.) context.
# In
#
# XXX
# with tpu.rewrite(...):
# YYY
# with tpu.outside_all_rewrites():
# ZZZ
#
# the Ops in ZZZ are added outside the scope of the rewrite().
# TODO(phawkins): currently outside_all_rewrites() pops out of all nested
# control flow scopes, for example loops. It would make more sense if it only
# popped out of a single scope.
@contextlib.contextmanager
def outside_all_rewrites():
"""Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.)."""
with ops.control_dependencies(None):
yield
class TPUReplicateContext(control_flow_ops.ControlFlowContext):
"""A ControlFlowContext for nodes inside a TPU computation.
The primary role of TPUReplicateContext is to mark operators inside a
tpu.replicate() computation with attributes:
* _tpu_replicate=XYZ, where XYZ is a unique name, and
* _tpu_num_replicas=k, where k is the number of replicas.
We use a ControlFlowContext to perform the annotation since it
integrates with TensorFlow constructs like ResourceVariables. For example,
if a ResourceVariable is constructed inside a tpu.replicate() block, the
ResourceVariable implementation can use "with ops.control_dependencies(None)"
to build the variable's definition outside the replicated computation.
"""
def __init__(self, name, num_replicas, global_tpu_id=None):
control_flow_ops.ControlFlowContext.__init__(self)
self._name = name
self._num_replicas = num_replicas
self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id
def AddOp(self, op):
self._AddOpInternal(op)
def _AddOpInternal(self, op):
# pylint: disable=protected-access
if any(x.dtype._is_ref_dtype for x in op.inputs):
raise NotImplementedError(
"Non-resource Variables are not supported inside TPU computations "
"(operator name: %s)" % op.name)
# pylint: enable=protected-access
if "_tpu_replicate" in op.node_def.attr:
raise ValueError("TPU computations cannot be nested")
op.node_def.attr["_tpu_replicate"].s = self._name
op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas
op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id)
op.graph.prevent_feeding(op)
op.graph.prevent_fetching(op)
def AddValue(self, val):
result = val
if self._outer_context:
result = self._outer_context.AddValue(val)
return result
def AddInnerOp(self, op):
self._AddOpInternal(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
def replicate(computation,
inputs=None,
infeed_queue=None,
global_tpu_id=None,
name=None):
"""Builds a graph operator that runs a replicated TPU computation.
Args:
computation: a Python function that builds the computation to replicate.
inputs: a list of lists of input tensors or None (equivalent to
[[]]), indexed by [replica_num][input_num]. All replicas must
have the same number of inputs.
infeed_queue: if not None, the InfeedQueue from which to append a tuple
of arguments as inputs to computation.
global_tpu_id: if not None, a Numpy 2D array indicating the global
id of each TPU device in the system. The outer dimension of the
array is host task id, and the inner dimension is device ordinal,
so e.g., global_tpu_id[x][y] indicates the global id of device
/task:x/device:TPU_NODE:y.
name: name of the operator.
Returns:
A list of lists of output tensors, indexed by [replica_num][output_num].
Raises:
ValueError: if all replicas do not have equal numbers of input tensors.
ValueError: if the number of inputs per replica does not match
the number of formal parameters to `computation`.
"""
if name is None:
name = "TPUReplicate"
inputs = [[]] if inputs is None else inputs
if global_tpu_id is not None:
# Turn the Numpy array into a flattened list.
global_tpu_id = global_tpu_id.flatten().tolist()
if ((not isinstance(inputs, list)) or
any(not isinstance(inp, (list, tuple)) for inp in inputs)):
raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")
num_replicas = len(inputs)
# No replicas? Nothing to do.
if num_replicas == 0:
return []
# Converts inputs to Tensors.
inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
# Verifies that all replicas have matching numbers and types of inputs
input_types = [x.dtype for x in inputs[0]]
input_arity = len(input_types)
for i in range(num_replicas):
if len(inputs[i]) != input_arity:
raise ValueError("Replicas must have the same number of inputs. "
"Replica 0 had {} inputs, replica {} had {} "
"inputs.".format(input_arity, i, len(inputs[i])))
types = [x.dtype for x in inputs[i]]
if types != input_types:
raise ValueError(
"Replicas must have matching input types. Replica 0 had "
"input types {}, replica {} had input types {}".format(
input_types, i, types))
arg_error = tpu_function.check_function_argument_count(
computation, input_arity, infeed_queue)
if arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied computation cannot be called with the specified inputs. "
"You specified %d inputs: %s, but the computation needs %s" % (
input_arity, str([i.name for i in inputs[0]]), arg_error))
else:
raise TypeError(
"Supplied computation cannot be called with the specified inputs. "
"You specified %d inputs: %s and %d additional inputs from infeed,"
" but the computation needs %s" % (input_arity, str(
[i.name
for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
arg_error))
graph = ops.get_default_graph()
with ops.name_scope(name, "replicate"):
# Fan-in: Builds a TPUReplicatedInput node for each input.
computation_inputs = []
for i in range(0, input_arity):
replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
context = TPUReplicateContext(
name=graph.unique_name("cluster"),
num_replicas=num_replicas,
global_tpu_id=global_tpu_id)
try:
context.Enter()
with tpu_function.tpu_shard_context(num_replicas):
# The EncapsulateTPUComputations rewrite needs to identify the
# replicated arguments inside each computation. Adds identity operators
# tagged with an attribute _tpu_replicated_input to identify the
# replicated inputs.
# pylint: disable=protected-access
with graph._attr_scope({"_tpu_replicated_input":
attr_value_pb2.AttrValue(b=True)}):
computation_inputs = [
array_ops.identity(x, name="replicated_input_{}".format(i))
for i, x in enumerate(computation_inputs)]
# pylint: enable=protected-access
# If there is an infeed queue, adds the dequeued values to the
# computation's inputs.
if infeed_queue is not None:
infeed_queue.set_number_of_shards(num_replicas)
for t in infeed_queue.generate_dequeue_op():
computation_inputs.append(t)
# Only resource variables work inside a TPU computation, so turn on
# resource variables for the computation.
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
vscope = variable_scope.get_variable_scope()
saved_use_resource = vscope.use_resource
vscope.set_use_resource(True)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
# If the computation only returned one value, makes it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
try:
with ops.device(core(0)):
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
except Exception as e:
raise ValueError(
"TPU function return values must all either be Operations or "
"convertible to Tensors. Got '%s'" % str(e))
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs
if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
"TPU functions must return zero-or more Tensor values followed by "
"zero or more Operations.")
output_arity = len(output_tensors)
# Wraps outputs in Identity ops. Otherwise a replicated input copied
# straight to an output would bypass the replicate(). This would be bad
# because the TPUReplicatedInput/TPUReplicatedOutput operator would not
# be rewritten away, leading to a runtime error.
# TODO(phawkins): extend the rewrite to elide these nodes instead.
with ops.device(core(0)):
output_tensors = [array_ops.identity(x) for x in output_tensors]
finally:
context.Exit()
# Fan-out: Builds a TPUReplicatedOutput node for each output.
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
name="output{}".format(i))
for i in xrange(output_arity)]
with ops.control_dependencies(output_operations):
if output_arity == 0:
# Returns a list of NoOps dependent on the replication Op, indexed by
# [replica_num].
return [
control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
for i in range(num_replicas)
]
else:
# Wraps the outputs in identity operators so the names of any possible
# `fetch` nodes are preserved by the replication rewrite.
return [
[array_ops.identity(outputs[out][replica],
name="output_%d_shard_%d" % (out, replica))
for out in xrange(output_arity)]
for replica in xrange(num_replicas)
]
def shard(computation,
inputs=None,
num_shards=1,
input_shard_axes=None,
outputs_from_all_shards=True,
output_shard_axes=None,
infeed_queue=None,
global_tpu_id=None,
name=None):
"""Shards `computation` for parallel execution.
`inputs` must be a list of Tensors or None (equivalent to an empty
list), each of which has a corresponding split axis (from
`input_shard_axes`). Each input is split into `num_shards` pieces
along the corresponding axis, and computation is applied to each
shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
x = tf.constant(7)
def computation():
return x + 3
... = shard(computation, ...)
TODO(phawkins): consider adding support for broadcasting Tensors passed
as inputs.
If `outputs_from_all_shards` is true, the outputs from all shards of
`computation` are concatenated back together along their `output_shard_axes`.
Otherwise, each output is taken from an arbitrary shard.
Inputs and outputs of the computation must be at least rank-1 Tensors.
Args:
computation: a Python function that builds a computation to apply to each
shard of the input.
inputs: a list of input tensors or None (equivalent to an empty
list). Each input tensor has a corresponding shard axis, given
by `input_shard_axes`, which must have size divisible by
`num_shards`.
num_shards: the number of shards.
input_shard_axes: a list of dimensions along which to shard `inputs`, or
`None`. `None` means "shard all inputs along dimension 0". If not `None`,
there must be one dimension per input.
outputs_from_all_shards: boolean or list of boolean. For each output, if
`True`, outputs from all shards are concatenated along the corresponding
`output_shard_axes` entry. Otherwise, each output is taken
from an arbitrary shard. If the argument is a boolean, the argument's
value is used for each output.
output_shard_axes: a list of dimensions along which to concatenate the
outputs of `computation`, or `None`. `None` means "concatenate all outputs
along dimension 0". If not `None`, there must be one dimension per output.
Ignored if `outputs_from_all_shards` is False.
infeed_queue: if not None, the InfeedQueue to use to augment the inputs of
`computation`.
global_tpu_id: if not None, a Numpy 2D array indicating the global
id of each TPU device in the system. The outer dimension of the
array is host task id, and the inner dimension is device ordinal,
so e.g., global_tpu_id[x][y] indicates the global id of device
/task:x/device:TPU_NODE:y.
name: name of the operator.
Returns:
A list of output tensors.
Raises:
ValueError: if num_shards <= 0
ValueError: if len(input_shard_axes) != len(inputs)
ValueError: if len(output_shard_axes) != len(outputs from `computation`)
"""
if num_shards <= 0:
raise ValueError("num_shards must be a positive integer.")
# Converts inputs to Tensors.
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
if input_shard_axes is None:
input_shard_axes = [0] * len(inputs)
if len(inputs) != len(input_shard_axes):
raise ValueError("Length of input_shard_axes must be equal to the number "
"of inputs.")
if inputs:
# Splits the `inputs` along the corresponding `input_shard_axes`, giving
# lists with layout [input][shard]
split_inputs = [
array_ops.split(x, num_shards, axis=axis)
for (axis, x) in zip(input_shard_axes, inputs)]
# Transposes the input lists to have layout [shard][input]
transposed_inputs = [list(i) for i in zip(*split_inputs)]
else:
transposed_inputs = [[]] * num_shards
outputs = replicate(
computation,
transposed_inputs,
infeed_queue=infeed_queue,
global_tpu_id=global_tpu_id,
name=name)
# There must be at least one shard since num_shards > 0.
# TODO(b/36647078) remove disable when pylint bug is fixed.
# pylint: disable=indexing-exception
if isinstance(outputs[0], ops.Operation):
# pylint: enable=indexing-exception
# There were no outputs from the computation and replicate returned a list
# of NoOps with control dependencies on the computation. Return the first
# one so it can be used as a control dependency or fetch node.
# TODO(b/36647078) remove disable when pylint bug is fixed.
# pylint: disable=indexing-exception
return [outputs[0]]
# pylint: enable=indexing-exception
# TODO(b/36647078) remove disable when pylint bug is fixed.
# pylint: disable=indexing-exception
num_outputs = len(outputs[0])
# pylint: enable=indexing-exception
if output_shard_axes is None:
output_shard_axes = [0] * num_outputs
if num_outputs != len(output_shard_axes):
raise ValueError("Length of output_shard_axes must be equal to the number "
"of outputs.")
if isinstance(outputs_from_all_shards, bool):
outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
if num_outputs != len(outputs_from_all_shards):
raise ValueError("Length of outputs_from_all_shards must be equal to the "
"number of outputs.")
results = []
for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
zip(*outputs)):
if all_shards:
# Concatenate all of the outputs together.
results.append(array_ops.concat(list(x), axis=axis))
else:
# TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
results.append(x[0])
return results
def batch_parallel(computation,
inputs=None,
num_shards=1,
infeed_queue=None,
global_tpu_id=None,
name=None):
"""Shards `computation` along the batch dimension for parallel execution.
Convenience wrapper around shard().
`inputs` must be a list of Tensors or None (equivalent to an empty
list). Each input is split into `num_shards` pieces along the 0-th
dimension, and computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
x = tf.constant(7)
def computation():
return x + 3
... = batch_parallel(computation, ...)
The outputs from all shards are concatenated back together along their 0-th
dimension.
Inputs and outputs of the computation must be at least rank-1 Tensors.
Args:
computation: a Python function that builds a computation to apply to each
shard of the input.
inputs: a list of input tensors or None (equivalent to an empty
list). The 0-th dimension of each Tensor must have size
divisible by `num_shards`.
num_shards: the number of shards.
infeed_queue: if not None, the InfeedQueue from which to append a tuple
of arguments as inputs to `computation`.
global_tpu_id: if not None, a Numpy 2D array indicating the global
id of each TPU device in the system. The outer dimension of the
array is host task id, and the inner dimension is device ordinal,
so e.g., global_tpu_id[x][y] indicates the global id of device
/task:x/device:TPU_NODE:y.
name: name of the operator.
Returns:
A list of output tensors.
Raises:
ValueError: if num_shards <= 0
"""
return shard(
computation,
inputs,
num_shards=num_shards,
infeed_queue=infeed_queue,
global_tpu_id=global_tpu_id,
name=name)
def rewrite(computation,
inputs=None,
infeed_queue=None,
global_tpu_id=None,
name=None):
"""Rewrites `computation` for execution on a TPU system.
Args:
computation: a Python function that builds a computation to apply
to the input. If the function takes n inputs, 'inputs' should be
a list of n tensors. If the function returns m outputs, rewrite
will return a list of m tensors.
inputs: a list of input tensors or None (equivalent to an empty list).
infeed_queue: if not None, the InfeedQueue from which to append a tuple
of arguments as inputs to `computation`.
global_tpu_id: if not None, a Numpy 2D array indicating the global
id of each TPU device in the system. The outer dimension of the
array is host task id, and the inner dimension is device ordinal,
so e.g., global_tpu_id[x][y] indicates the global id of device
/task:x/device:TPU_NODE:y.
name: name of the operator.
Returns:
A list of output tensors.
"""
if inputs is not None and not isinstance(inputs, (list, tuple)):
raise TypeError("tpu.rewrite() inputs must be a list or tuple")
# TODO(b/36647078) remove disable when pylint bug is fixed.
# pylint: disable=indexing-exception
return replicate(
computation,
None if inputs is None else [inputs],
infeed_queue=infeed_queue,
global_tpu_id=global_tpu_id,
name=name)[0]
# pylint: enable=indexing-exception
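
A hedged usage sketch for the sharding helpers defined above (not part of this change; the toy computation and shapes are hypothetical):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(images):
  # Per-shard work: reduce each example to a scalar.
  return tf.reduce_mean(images, axis=[1])

images = tf.random_uniform([32, 128])  # batch of 32, split two ways below

# Each shard sees a [16, 128] slice; the per-shard [16] results are
# concatenated back to a single [32] output.
outputs = tpu.shard(computation, [images], num_shards=2)

# batch_parallel() is the same call specialized to splitting along dimension 0.
same = tpu.batch_parallel(computation, [images], num_shards=2)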

View File

@@ -0,0 +1,47 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""A RunConfig subclass with TPU support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
class TpuConfig(collections.namedtuple(
'TpuConfig', ['iterations_per_loop', 'num_shards'])):
"""TPU related configuration required by `TPUEstimator`."""
def __new__(cls, iterations_per_loop=2, num_shards=2):
return super(TpuConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards)
class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
def __init__(self, tpu_config=None, **kwargs):
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TpuConfig()
@property
def tpu_config(self):
return self._tpu_config
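
A hedged sketch (not part of this change) of constructing a TPU-aware RunConfig for use with the estimator below; the field values are illustrative only.

from tensorflow.contrib.tpu.python.tpu import tpu_config

config = tpu_config.RunConfig(
    tpu_config=tpu_config.TpuConfig(iterations_per_loop=10, num_shards=8))
assert config.tpu_config.num_shards == 8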

View File

@@ -0,0 +1,361 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tpu Estimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
def _tpu_job(run_config):
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
return None if run_config.master in ['', 'local'] else 'tpu_worker'
class _SIGNAL(object):
"""Signal used to control the input thread of infeed."""
NEXT_BATCH = 1
STOP = 2
class InfeedThreadController(object):
"""This wraps the infeed thread and stops when Estimator train finishes.
For the model_fn wrapper, it is not possible to know when the `train` API will
stop. It could be the case that `max_steps` is reached or that some hook
requests the stop in the monitored session.
This controller (with coordination with `TpuInfeedSessionHook`) does the
following:
1) It pre-enqueues one batch of data for the current TPU iterations.
2) When `before_run` of `TpuInfeedSessionHook` is called, one more batch of
data will be enqueued.
3) When `end` of `TpuInfeedSessionHook` is called, the thread will end
gracefully.
So, we might need to adjust the algorithrm here if the IO is slower than the
computation.
"""
def __init__(self, session, enqueue_ops, iterations):
self._signal_queue = Queue.Queue()
self._input_thd = threading.Thread(target=self._input_thread_fn_for_loading,
args=(session, enqueue_ops, iterations))
self._input_thd.daemon = True
self._input_thd.start()
def _input_thread_fn_for_loading(self, session, enqueue_ops, iterations):
count = 0
while True:
signal = self._signal_queue.get()
if signal == _SIGNAL.STOP:
logging.info('Stop Infeed input thread.')
return
for i in range(iterations):
logging.debug('InfeedEnqueue data for iteration (%d, %d)', count, i)
session.run(enqueue_ops)
count += 1
def load_next_batch(self):
self._signal_queue.put(_SIGNAL.NEXT_BATCH)
def join(self):
logging.info('Waiting for InputThread to exit.')
self._signal_queue.put(_SIGNAL.STOP)
self._input_thd.join()
class TpuInfeedSessionHook(session_run_hook.SessionRunHook):
"""A Session hook setting up the TPU initialization and infeed.
This hook does two major things:
  1. initializes and shuts down the TPU system (maybe a separate hook)
  2. launches and joins the input thread for infeed.
"""
def __init__(self, run_config, enqueue_fn):
self._iterations = run_config.tpu_config.iterations_per_loop
self._enqueue_fn = enqueue_fn
self._tpu_job = _tpu_job(run_config)
def begin(self):
self._enqueue_ops = self._enqueue_fn()
logging.info('TPU job name %s', self._tpu_job)
self._init_op = [tpu.initialize_system(job=self._tpu_job)]
self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)]
def after_create_session(self, session, coord):
logging.info('Init TPU system')
session.run(self._init_op)
logging.info('Start infeed input thread controller')
self._infeed_thd_controller = InfeedThreadController(
session, self._enqueue_ops, self._iterations)
def before_run(self, run_context):
logging.info('Load next batch of data to infeed.')
self._infeed_thd_controller.load_next_batch()
def end(self, session):
logging.info('Stop infeed input thread controller')
self._infeed_thd_controller.join()
logging.info('Shutdown TPU system.')
session.run(self._finalize_op)
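# Illustrative usage sketch, not part of the original change: attaches the
# hook to a monitored session directly. `run_config`, `enqueue_fn` and
# `train_op` are hypothetical placeholders supplied by the caller, and the
# use of MonitoredTrainingSession here is an assumption for illustration.
def _example_infeed_hook(run_config, enqueue_fn, train_op):
  hook = TpuInfeedSessionHook(run_config, enqueue_fn)
  with training.MonitoredTrainingSession(hooks=[hook]) as session:
    # Each run of the training Op consumes a batch previously loaded by the
    # infeed thread controller.
    session.run(train_op)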
class TpuEstimator(estimator_lib.Estimator):
"""Estimator with TPU support.
  The only difference is that a wrapped model_fn is set in the constructor.
"""
def __init__(self,
model_fn=None,
model_dir=None,
config=None,
params=None,
use_tpu=True):
if use_tpu:
model_function = wrapped_model_fn(model_fn, config)
else:
model_function = model_fn
super(TpuEstimator, self).__init__(
model_fn=model_function,
model_dir=model_dir,
config=config,
params=params)
if not isinstance(config, tpu_config.RunConfig):
raise ValueError('`config` must be `tpu_config.RunConfig`')
def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
Args:
graph: The graph in which to create the global step.
Returns:
A global step `Tensor`.
Raises:
ValueError: if the global step tensor is already defined.
"""
graph = graph or ops.get_default_graph()
if training.get_global_step(graph) is not None:
raise ValueError('"global_step" already exists.')
# Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None):
return variable_scope.get_variable(
ops.GraphKeys.GLOBAL_STEP,
shape=[],
dtype=dtypes.int32,
initializer=init_ops.zeros_initializer(),
trainable=False,
use_resource=True,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
# TODO(xiejw): Improve the structure of this input_fn to infeed conversion.
# The code does not currently follow the Estimator style; many details need
# to be abstracted.
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
"""Utility to convert input_fn to enqueue and dequeue fns for TPU.
  Mainly, three things need to be done here:
  1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU.
  2. Creates a dequeue_fn used by the train_step inside TPU execution to
     dequeue the tensors.
  3. Sets up the input thread for infeed.
Args:
    run_config: A `RunConfig` with TPU support.
    features: The features returned by the `input_fn`.
    labels: The labels returned by the `input_fn`.
Returns:
    A tuple of (dequeue_fn, enqueue_fn).
"""
infeed_names = None
infeed_tuple = []
if isinstance(features, dict):
# We need a fixed ordering for enqueueing and dequeueing.
infeed_names = [name for name in features]
infeed_tuple.extend([features[name] for name in infeed_names])
else:
infeed_tuple.append(features)
# TODO(jhseu): Handle multi-head and None labels
infeed_tuple.append(labels)
# TODO(jhseu): Update when b/36470756 is settled.
infeed_queue = tpu_feed.InfeedQueue(
tuple_types=[t.dtype for t in infeed_tuple],
tuple_shapes=[t.shape for t in infeed_tuple])
infeed_queue.set_number_of_shards(run_config.tpu_config.num_shards)
def dequeue_fn():
"""dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
values = infeed_queue.generate_dequeue_op()
if infeed_names is None:
return values
# Restore the feature dictionary and label.
dequeued_features = {}
for i in range(len(values) - 1):
dequeued_features[infeed_names[i]] = values[i]
label = values[-1]
return dequeued_features, label
def enqueue_fn():
"""enqueue_fn is used to add ops to the graph to send tensors."""
job = _tpu_job(run_config)
def placement_function(index):
if job is None:
return '/replica:0/task:0/device:CPU:0'
else:
return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
return infeed_queue.split_inputs_and_generate_enqueue_ops(
infeed_tuple, placement_function=placement_function)
return (dequeue_fn, enqueue_fn)
def wrapped_model_fn(model_fn, run_config):
"""Returns a new model_fn, which wraps the TPU support."""
# Verifies the model_fn signature according to Estimator framework.
estimator_lib._verify_model_fn_args(model_fn, params=None) # pylint: disable=protected-access
def _model_fn(features, labels, mode):
"""model_fn."""
    # TODO(jhseu): Move EVAL and PREDICT to TPU.
if mode != model_fn_lib.ModeKeys.TRAIN:
return model_fn(features, labels, mode)
dequeue_fn, enqueue_fn = (
_create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels))
loss = _train_on_tpu_shards(
run_config,
train_step=_convert_model_fn_to_train_step(
model_fn, dequeue_fn, mode, run_config))
# Gets the variables back from TPU nodes. This means the variables updated
# by TPU will now be *synced* to host memory.
update_ops = [
array_ops.check_numerics(v.read_value(),
'Gradient for %s is NaN' % v.name).op
for v in variables.trainable_variables()
]
hooks = [
TpuInfeedSessionHook(run_config, enqueue_fn),
training.LoggingTensorHook(
{'loss': array_ops.identity(loss),
'step': training.get_global_step()},
every_n_secs=30)
]
return model_fn_lib.EstimatorSpec(
mode,
loss=array_ops.identity(loss),
training_hooks=hooks,
train_op=control_flow_ops.group(*update_ops))
return _model_fn
def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config):
"""generates a train step based on the model_fn."""
def _call_model_fn(features, labels):
"""Calls the model_fn with required parameters."""
model_fn_args = estimator_lib._model_fn_args(model_fn) # pylint: disable=protected-access
kwargs = {}
if 'mode' in model_fn_args:
kwargs['mode'] = mode
# Uncomment the following lines once `params` is supported.
# if 'params' in model_fn_args:
# kwargs['params'] = params
if 'config' in model_fn_args:
kwargs['config'] = run_config
return model_fn(features=features, labels=labels, **kwargs)
def _verify_estimator_spec(estimator_spec):
"""Validates the estimator_spec."""
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
if estimator_spec.training_chief_hooks:
raise ValueError(err_msg.format('training_chief_hooks'))
if estimator_spec.training_hooks:
raise ValueError(err_msg.format('training_hooks'))
return estimator_spec
def train_step(loss):
"""Training step function for use inside a while loop."""
del loss # unused; required in function signature.
features, labels = dequeue_fn()
    # TODO(xiejw): How do we support hooks and savers in the original
    # model_fn? Realistically, the original model_fn will be executed on TPU
    # chips in a replicated way. The hooks returned by the model_fn cannot be
    # supported at all. If we have to, the graph construction part in the
    # model_fn should be separated from the control part (such as hooks and
    # savers). That way the graph construction could be deferred to the TPU
    # chip, while the control logic stays on the host.
estimator_spec = _verify_estimator_spec(_call_model_fn(features, labels))
loss, train_op = estimator_spec.loss, estimator_spec.train_op
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
return train_step
def _train_on_tpu_shards(run_config, train_step):
"""Executes the `train_step` on all shards."""
def train_shard():
return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
train_step,
[1e7], # initial_loss
name='loop')
(loss,) = tpu.shard(train_shard,
inputs=[],
num_shards=run_config.tpu_config.num_shards,
outputs_from_all_shards=False)
return loss
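# Illustrative usage sketch, not part of the original change: wires
# TpuEstimator together with the TPU RunConfig. `my_model_fn`, `my_input_fn`,
# the master address, the model directory and the step counts are all
# hypothetical placeholders.
def _example_tpu_estimator(my_model_fn, my_input_fn):
  run_config = tpu_config.RunConfig(
      master='grpc://10.240.1.2:8470',  # assumed Cloud TPU master address
      tpu_config=tpu_config.TpuConfig(iterations_per_loop=100, num_shards=8))
  estimator = TpuEstimator(
      model_fn=my_model_fn,
      model_dir='/tmp/tpu_model',  # hypothetical checkpoint directory
      config=run_config,
      use_tpu=True)
  estimator.train(input_fn=my_input_fn, max_steps=1000)
  return estimator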

View File

@ -0,0 +1,613 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Helper library for handling infeed between hosts and TPUs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
class InfeedQueue(object):
"""A helper object to build a device infeed queue.
The InfeedQueue builds the host-side and device-side Ops to enqueue and
dequeue elements, respectively, and ensures that their types and
shapes match.
"""
def __init__(self,
number_of_tuple_elements=None,
tuple_types=None,
tuple_shapes=None,
shard_dimensions=None,
name=None):
"""Creates a new InfeedQueue with the given configuration.
The configuration need not be fully specified at creation since it
can be modified subsequently by methods that set the values
explicitly or infer them from the shapes of inputs.
Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue; it must be present unless it can be inferred from other arguments.
tuple_types: if not None, a list of types of the elements of the queue.
tuple_shapes: if not None, a list of shapes of the elements of the queue.
shard_dimensions: if not None, a list of dimensions on which the
elements of the queue should be sharded during automatic
parallelization.
name: the name of the queue.
Raises:
ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
shard_dimensions are all None; or the length of tuple_types,
tuple_shapes, or shard_dimensions is not equal to
number_of_tuple_elements; or any element of shard_dimensions
can't be converted to a Dimension.
TypeError: if any element of tuple_types or tuple_shapes can't
be converted to a dtype or TensorShape, respectively.
"""
self._frozen = False
self._generated_enqueue_ops = False
self._generated_dequeue_op = False
self._name = "InfeedQueue" if name is None else name
if number_of_tuple_elements is None:
if tuple_types is not None:
number_of_tuple_elements = len(tuple_types)
elif tuple_shapes is not None:
number_of_tuple_elements = len(tuple_shapes)
elif shard_dimensions is not None:
number_of_tuple_elements = len(shard_dimensions)
else:
raise ValueError(
"number of tuple elements cannot be inferred from InfeedQueue "
"constructor"
)
if number_of_tuple_elements <= 0:
raise ValueError("number_of_tuple_elements %d must be > 0" %
number_of_tuple_elements)
# Make an empty sharding policy for each tuple element.
self._sharding_policies = [
tpu_sharding.ShardingPolicy()
for _ in xrange(number_of_tuple_elements)
]
if tuple_types is not None:
self.set_tuple_types(tuple_types)
else:
self._tuple_types = None
if tuple_shapes is not None:
self.set_tuple_shapes(tuple_shapes)
else:
self._tuple_shapes = None
if shard_dimensions is not None:
self.set_shard_dimensions(shard_dimensions)
self._validate()
def _validate(self):
"""Checks that the configuration is self-consistent.
Raises:
ValueError: if the shapes and sharding policies don't match.
"""
if self.tuple_shapes is not None:
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
# Raise an error if the policy is incompatible with the shape.
_ = policy.get_sharded_shape(shape)
@property
def number_of_tuple_elements(self):
"""Returns the number of InfeedQueue tuple elements."""
return len(self._sharding_policies)
@property
def tuple_types(self):
"""Returns the types of the InfeedQueue tuple elements."""
return self._tuple_types
def set_tuple_types(self, tuple_types):
"""Sets the type of each element of the queue.
tuple_types must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a dtype.
Args:
tuple_types: the types of each queue element.
Raises:
ValueError: if tuple_types is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_types cannot be converted to a
dtype.
"""
if len(tuple_types) != self.number_of_tuple_elements:
raise ValueError("tuple_types is %s, but must be a list of length %d" %
(str(tuple_types), self.number_of_tuple_elements))
if self._frozen:
for (frozen, updated) in zip(self._tuple_types, tuple_types):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible type. Frozen types are %s, updated types are %s" % (
str(self._tuple_types), str(tuple_types)))
else:
try:
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
except (TypeError) as e:
raise TypeError(
"tuple_types is %s, but must be a list of elements each "
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
@property
def tuple_shapes(self):
"""Returns the shapes of the InfeedQueue tuple elements."""
return self._tuple_shapes
def set_tuple_shapes(self, tuple_shapes):
"""Sets the shape of each element of the queue.
tuple_shapes must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a TensorShape.
Args:
tuple_shapes: the shapes of each queue element.
Raises:
ValueError: if tuple_shapes is not of length
self.number_of_tuple_elements.
TypeError: if an element of tuple_shapes cannot be converted to
a TensorShape.
"""
if len(tuple_shapes) != self.number_of_tuple_elements:
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
(str(tuple_shapes), self.number_of_tuple_elements))
try:
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
except (ValueError, TypeError) as e:
raise TypeError(
"tuple_shapes is %s, but must be a list of elements each "
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
str(e)))
if self._frozen:
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
if frozen != updated:
raise ValueError(
"Trying to update InfeedQueue with frozen configuration with an "
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
% (str(self._tuple_shapes), str(tuple_shapes)))
else:
self._tuple_shapes = tuple_shapes
self._validate()
@property
def sharding_policies(self):
"""Returns the sharding policies of the InfeedQueue tuple elements."""
return self._sharding_policies
@property
def shard_dimensions(self):
"""Gets the shard dimension of each tuple element.
Returns:
A list of length number_of_tuple_elements, where each list entry
is the shard dimension of that tuple element or None if the
shard dimension has not been set.
"""
# The number of shards is always the same for all the policies.
return [policy.shard_dimension for policy in self._sharding_policies]
def set_shard_dimensions(self, shard_dimensions):
"""Sets the shard_dimension of each element of the queue.
shard_dimensions must be a list of length
self.number_of_tuple_elements, and each element must be
convertible to a Dimension compatible with self.tuple_shapes.
Args:
shard_dimensions: the dimensions of each queue element.
Raises:
ValueError: if shard_dimensions is not of length
self.number_of_tuple_elements; or an element of
shard_dimensions cannot be converted to a Dimension; or an
element of shard_dimensions is a Dimension that is out of
range for the corresponding tuple element shape.
"""
if len(shard_dimensions) != self.number_of_tuple_elements:
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
% (str(shard_dimensions),
self.number_of_tuple_elements))
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
policy.set_shard_dimension(dimension)
self._validate()
@property
def number_of_shards(self):
"""Gets the number of shards to use for the InfeedQueue.
Returns:
Number of shards or None if the number of shards has not been set.
"""
# The number of shards is always the same for all the policies.
return self._sharding_policies[0].number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards to use for the InfeedQueue.
Args:
number_of_shards: number of ways to shard the InfeedQueue.
Raises:
ValueError: if number_of_shards is not > 0; or the policies have
been frozen and number_of_shards was already set to something
else.
"""
for policy in self._sharding_policies:
policy.set_number_of_shards(number_of_shards)
self._validate()
def set_configuration_from_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of Tensors whose types and shapes are used
to set the queue configuration.
Args:
input_tensors: list of Tensors of the same types and shapes as
the desired queue Tuple.
Raises:
ValueError: if input_tensors is not a list of length
self.number_of_tuple_elements
"""
if len(input_tensors) != self.number_of_tuple_elements:
raise ValueError(
"input_tensors is %s, but should be a list of %d Tensors", (
str(input_tensors), self.number_of_tuple_elements))
self.set_tuple_shapes([t.shape for t in input_tensors])
self.set_tuple_types([t.dtype for t in input_tensors])
def set_configuration_from_sharded_input_tensors(self, input_tensors):
"""Sets the shapes and types of the queue tuple elements.
input_tensors is a list of lists of Tensors whose types and shapes are used
to set the queue configuration. The length of the outer list is the number
of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
depends on the shard dimension, and calling it freezes the shard policy.
Args:
input_tensors: list of lists of Tensors. The outer list length corresponds
to the desired number of shards, and each inner list is the size
and shape of the desired configuration of the corresponding shard.
Raises:
ValueError: if any inner list is not a list of length
self.number_of_tuple_elements; or the inner lists do not combine to
form a consistent unsharded shape.
TypeError: if the types of the Tensors in the inner lists do not match.
"""
if not self._frozen:
# Unset the tuple shapes in case the configuration becomes
# transiently inconsistent.
self._tuple_shapes = None
number_of_shards = len(input_tensors)
self.set_number_of_shards(number_of_shards)
for t in input_tensors:
if len(t) != self.number_of_tuple_elements:
raise ValueError(
"input_tensors is %s but must be a list of lists, where each inner"
" list has length number_of_tuple_elements=%d" % (
str(input_tensors), self.number_of_tuple_elements))
# Transpose the inputs to make a list of shard shapes for each tuple
# element.
sharded_shapes = [[t[i].shape for t in input_tensors]
for i in xrange(self.number_of_tuple_elements)]
# For each tuple, get the unsharded shape using that tuple's policy.
unsharded_shapes = [
policy.get_unsharded_shape(s)
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
]
self.set_tuple_shapes(unsharded_shapes)
for i in xrange(1, self.number_of_shards):
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
if t1.dtype != t2.dtype:
raise TypeError(
"types of the tuple elements of input_tensors %s are not "
"consistent" % str(input_tensors))
self.set_tuple_types([t.dtype for t in input_tensors[0]])
def freeze(self):
"""Freezes the InfeedQueue so it can no longer be modified.
The configuration is implicitly frozen before any host-side or
device-side Ops are generated. The configuration cannot be frozen
until the types and shapes of the tuple elements have been set.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set.
"""
self._frozen = True
if self._tuple_types is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple types.")
if self._tuple_shapes is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for shape in self._tuple_shapes:
if shape.dims is None:
raise ValueError(
"Can't freeze an InfeedQueue without setting all tuple shapes.")
for policy in self._sharding_policies:
policy.freeze()
self._validate()
def generate_dequeue_op(self):
"""Generates the device-side Op to dequeue a tuple from the queue.
Implicitly freezes the queue configuration if it is not already
frozen, which will raise errors if the shapes and types have not
been fully specified.
Returns:
A list of Outputs corresponding to a shard of infeed dequeued
into XLA, suitable for use within a replicated block.
Raises:
ValueError: if the types or shapes of the tuple elements have not been
set; or if a dequeue op has already been generated.
"""
self.freeze()
if self._generated_dequeue_op:
raise ValueError("Can't generate two dequeue Ops from the same queue")
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
return tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
def _generate_enqueue_op(self,
inputs,
name_prefix,
index,
device=None,
tpu_ordinal=-1):
"""Generate a host-side Op to enqueue a tuple to the queue.
If device is None the inputs are all required to have the same
device specification, and the enqueue Op is colocated with
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
Args:
inputs: a list of Tensors with the types and shapes of the tuple elements.
name_prefix: the base name for the Op.
index: the shard index, used to uniquify the Op name.
device: device to place the Op on, or None if it should be
colocated with the inputs.
tpu_ordinal: ordinal of the TPU device on the host to use for
infeed if device is a CPU device. Should be set to -1 if device
is a TPU device.
Returns:
An Op corresponding to a shard of infeed enqueued at the host,
suitable for use within a replicated block.
Raises:
ValueError: if device is None and inputs do not all have the
same device specification.
"""
full_name = "%s/%d" % (name_prefix, index)
shapes = [t.shape for t in inputs]
if device is None:
devices = [t.device for t in inputs]
for i in xrange(1, self.number_of_tuple_elements):
if devices[0] != devices[i]:
raise ValueError(
"input devices for shard %d are %s, but should all be the same",
index, str(devices))
with ops.colocate_with(inputs[0]):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
else:
with ops.device(device):
return tpu_ops.infeed_enqueue_tuple(
inputs=inputs,
shapes=shapes,
name=full_name,
device_ordinal=tpu_ordinal)
def generate_enqueue_ops(self, sharded_inputs):
"""Generates the host-side Ops to enqueue the shards of a tuple.
sharded_inputs is a list, one for each shard, of lists of
Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
for shard i.
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of sharded_inputs, an error
will be raised.
Args:
sharded_inputs: a list of lists of Tensors. The length of the outer list
determines the number of shards. Each inner list indicates the types
and shapes of the tuples in the corresponding shard.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the shapes of the elements of sharded_inputs
don't form a consistent unsharded tuple; or if the elements of a tuple
have different device constraints.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of sharded_inputs are not compatible with the
frozen configuration; or if the types of the elements of sharded_inputs
don't form a consistent unsharded tuple.
"""
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(shard, name_prefix, index)
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
# TODO(misard) Generalize this to the case of systems that don't
# have 8 devices per host, and figure out what to do with
# model-parallelism.
def _default_placement_function(self, index):
return "/task:%d/device:CPU:0" % (index / 8)
def _default_ordinal_function(self, index):
return index % 8
# TODO(b/36470756) remove this from tutorials once we have a better story
# for automatic placement of input pipelines.
def split_inputs_and_generate_enqueue_ops(self,
inputs,
global_tpu_id=None,
placement_function=None,
tpu_ordinal_function=None):
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
Generates the host-side Ops to enqueue a tuple.
This method performs poorly because it takes an entire input on a single
host, splits it, and distributes it to all of the cores. It is present only
to simplify tutorial examples.
inputs is a list of Tensors to use to feed the queue. Each input is split
into self.number_of_shards shards. Returns an Op for each shard to enqueue
the shard. The Op for shard i is placed on device placement_function(i).
Implicitly freezes the queue configuration if it is not already
frozen. If the configuration has already been frozen, and is not
compatible with the types and shapes of inputs, an error
will be raised.
Args:
inputs: a list of Tensors which indicates the types and shapes of the
queue tuple.
global_tpu_id: if not None, a Numpy 2D array indicating the global
id of each TPU device in the system. The outer dimension of the
array is host task id, and the inner dimension is device ordinal,
so e.g., global_tpu_id[x][y] indicates the global id of device
/task:x/device:TPU_NODE:y. If global_tpu_id is not None, but
placement_function and ordinal_function are None, then global_tpu_id
will be used to place infeed on the TPUs with the first k global ids,
where k is the number of shards in the queue.
placement_function: if not None, a function that takes the shard
index as input and returns a device string indicating which
device the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
tpu_ordinal_function: if not None, a function that takes the
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. If placement_function
and tpu_ordinal_function are None, inputs are sharded round-robin
across the devices in the system.
Returns:
A list of host-side Ops, one for each shard, that when executed together
will enqueue a full-size element of infeed.
Raises:
ValueError: if the queue configuration has previously been frozen and the
shapes of the elements of inputs are not compatible with the frozen
configuration.
TypeError: if the queue configuration has previously been frozen and the
types of the elements of inputs are not compatible with the frozen
configuration.
"""
if global_tpu_id is None:
if placement_function is None:
placement_function = self._default_placement_function
if tpu_ordinal_function is None:
tpu_ordinal_function = self._default_ordinal_function
else:
global_id_map = {}
for host, devices in enumerate(global_tpu_id):
for ordinal, global_id in enumerate(devices):
global_id_map[global_id] = (host, ordinal)
def _placement_function_from_map(index):
return "/task:%d/device:CPU:0" % global_id_map[index][0]
def _ordinal_function_from_map(index):
return global_id_map[index][1]
if placement_function is None:
placement_function = _placement_function_from_map
if tpu_ordinal_function is None:
tpu_ordinal_function = _ordinal_function_from_map
self.set_configuration_from_input_tensors(inputs)
self.freeze()
if self._generated_enqueue_ops:
raise ValueError("Can't generate two enqueue Ops from the same queue")
self._generated_enqueue_ops = True
split_name_prefix = "%s/split" % self._name
if self.number_of_shards == 1:
transposed_sharded_inputs = [[inp] for inp in inputs]
else:
transposed_sharded_inputs = [
array_ops.split(
inp,
self.number_of_shards,
axis=policy.shard_dimension,
name="%s/%d" % (split_name_prefix, index))
for (inp, policy, index) in zip(inputs, self._sharding_policies,
xrange(self.number_of_tuple_elements))
]
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
for i in xrange(self.number_of_shards)]
name_prefix = "%s/enqueue" % self._name
return [
self._generate_enqueue_op(
shard,
name_prefix,
index,
device=placement_function(index),
tpu_ordinal=tpu_ordinal_function(index))
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
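# Illustrative usage sketch, not part of the original change: builds a
# two-element InfeedQueue from pre-sharded tensors, generating the host-side
# enqueue Ops and the matching device-side dequeue Op. The shapes, dtypes and
# shard count below are arbitrary placeholder values.
def _example_infeed_queue():
  images = array_ops.zeros([8, 28, 28, 1], dtype=dtypes.float32)
  labels = array_ops.zeros([8], dtype=dtypes.int32)
  # Manually shard the batch two ways along dimension 0.
  sharded_inputs = [list(shard) for shard in zip(
      array_ops.split(images, 2, axis=0), array_ops.split(labels, 2, axis=0))]
  queue = InfeedQueue(number_of_tuple_elements=2)
  enqueue_ops = queue.generate_enqueue_ops(sharded_inputs)
  dequeued_tensors = queue.generate_dequeue_op()
  return enqueue_ops, dequeued_tensors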

View File

@ -0,0 +1,106 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for functions used during TPU compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.util import tf_inspect
class TpuContext(object):
"""A context object holding state about the TPU computation being built."""
def __init__(self):
"""Creates a new TpuContext."""
self._number_of_shards = None
@property
def number_of_shards(self):
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
self._number_of_shards = number_of_shards
# The Tpu context holds the number of shards when a sharded computation is
# being built, or None if no computation is being built.
_current_tpu_context = TpuContext()
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
if _current_tpu_context.number_of_shards is not None:
raise NotImplementedError("tpu_shard_context cannot be nested.")
try:
_current_tpu_context.set_number_of_shards(number_of_shards)
yield
finally:
_current_tpu_context.set_number_of_shards(None)
def get_tpu_context():
return _current_tpu_context
def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to a tpu function.
Args:
func: the Python function that will be called to generate the body
of a TPUFunction.
input_arity: the number of explicit arguments supplied by the
caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
Returns:
None if function can be called with the supplied number of
arguments, or an error string if it cannot.
"""
def format_error(complaint, quantity):
return "%s %d argument%s" % (complaint, quantity, ""
if quantity == 1 else "s")
number_of_arguments_needed = input_arity
if infeed_queue is not None:
number_of_arguments_needed += infeed_queue.number_of_tuple_elements
arg_spec = tf_inspect.getargspec(func)
number_of_args = len(arg_spec.args)
if arg_spec.defaults is None:
number_of_defaults = 0
else:
number_of_defaults = len(arg_spec.defaults)
min_required_arguments = number_of_args - number_of_defaults
if number_of_arguments_needed < min_required_arguments:
# The required number of arguments is not enough to call the function.
if number_of_defaults == 0 and arg_spec.varargs is None:
return format_error("exactly", number_of_args)
else:
return format_error("at least", min_required_arguments)
if arg_spec.varargs is None and number_of_arguments_needed > number_of_args:
# The required number of arguments is too many to call the function.
if number_of_defaults == 0:
return format_error("exactly", number_of_args)
else:
return format_error("at most", number_of_args)
# Since there are varargs, func can accept any number of arguments
# greater than the minimum.
return None
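# Illustrative usage sketch, not part of the original change: shows the shard
# context recording the shard count while a computation is built, and the
# argument checker validating a body function. The body below is a
# hypothetical placeholder.
def _example_shard_context():
  def computation(x, y):
    return x + y

  # Two explicit arguments and no infeed queue: the call is valid, so the
  # checker returns None rather than an error string.
  assert check_function_argument_count(computation, 2, None) is None

  with tpu_shard_context(4):
    assert get_tpu_context().number_of_shards == 4
  # Outside the context the shard count is cleared again.
  assert get_tpu_context().number_of_shards is None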

View File

@ -0,0 +1,125 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tpu_function helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.platform import test
class FunctionArgCheckTest(test.TestCase):
def testSimple(self):
"""Tests that arg checker works for functions with no varargs or defaults.
"""
def func(x, y, z):
return x + y + z
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, None))
self.assertEqual("exactly 3 arguments",
tpu_function.check_function_argument_count(func, 2, None))
queue = tpu_feed.InfeedQueue(2)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 1, queue))
self.assertEqual("exactly 3 arguments",
tpu_function.check_function_argument_count(func, 2, queue))
def testDefaultArgs(self):
"""Tests that arg checker works for a function with no varargs."""
def func(x, y, z=17):
return x + y + z
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, None))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 1, None))
self.assertEqual("at most 3 arguments",
tpu_function.check_function_argument_count(func, 4, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 1, queue))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 0, queue))
self.assertEqual("at most 3 arguments",
tpu_function.check_function_argument_count(func, 4, queue))
def testVarArgs(self):
"""Tests that arg checker works for a function with varargs."""
def func(x, y, *z):
return x + y + len(z)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 4, None))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 1, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, queue))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 0, queue))
def testVarArgsAndDefaults(self):
"""Tests that arg checker works for a function with varargs and defaults."""
def func(x, y, z=17, *q):
return x + y + z + len(q)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 4, None))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 5, None))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 1, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 2, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 3, queue))
self.assertEqual(None,
tpu_function.check_function_argument_count(func, 4, queue))
self.assertEqual("at least 2 arguments",
tpu_function.check_function_argument_count(func, 0, queue))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for TPU InfeedQueue methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
class InfeedTest(test.TestCase):
def testConstructor(self):
"""Tests that the constructor can be called with different arguments."""
i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
self.assertEqual(i.number_of_tuple_elements, 2)
self.assertEqual(i.tuple_types, None)
self.assertEqual(i.tuple_shapes, None)
self.assertEqual(i.number_of_shards, None)
i = tpu_feed.InfeedQueue(
tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
self.assertEqual(i.number_of_tuple_elements, 3)
self.assertEqual(i.tuple_types,
[dtypes.float32, dtypes.int32, dtypes.int32])
self.assertEqual(i.tuple_shapes, None)
self.assertEqual(i.number_of_shards, None)
i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
self.assertEqual(i.number_of_tuple_elements, 2)
self.assertEqual(i.tuple_types, None)
self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
self.assertEqual(i.number_of_shards, None)
i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
self.assertEqual(i.number_of_tuple_elements, 3)
self.assertEqual(i.tuple_types, None)
self.assertEqual(i.tuple_shapes, None)
self.assertEqual([p.shard_dimension
for p in i.sharding_policies], [1, 0, 7])
with self.assertRaises(ValueError):
i = tpu_feed.InfeedQueue()
with self.assertRaises(ValueError):
i = tpu_feed.InfeedQueue(
number_of_tuple_elements=2, tuple_types=[dtypes.float32])
with self.assertRaises(ValueError):
i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
with self.assertRaises(ValueError):
i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
with self.assertRaises(ValueError):
i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])
def testModification(self):
"""Tests modification of the queue post-construction."""
i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
i.set_tuple_types([dtypes.float32, dtypes.int32])
self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
i.set_tuple_types([dtypes.float32, dtypes.float32])
self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
with self.assertRaises(ValueError):
i.set_tuple_types([dtypes.float32])
i.set_tuple_shapes([[1], [2, 3]])
self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
i.set_tuple_shapes([[1, 2], [3, 4]])
self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
with self.assertRaises(ValueError):
i.set_tuple_shapes([[1, 2]])
i.set_number_of_shards(2)
self.assertEqual(i.number_of_shards, 2)
i.set_number_of_shards(3)
self.assertEqual(i.number_of_shards, 3)
t1 = constant_op.constant(1, dtypes.int32, shape=[6])
t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
i.set_configuration_from_input_tensors([t1, t2])
self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
self.assertEqual(i.number_of_shards, 2)
self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
i.set_shard_dimensions([1, 0])
i.set_number_of_shards(3)
with self.assertRaises(ValueError):
i.set_number_of_shards(4)
def testFreezing(self):
"""Tests freezing the queue."""
i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
t1 = constant_op.constant(1, dtypes.int32, shape=[2])
t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
self.assertEqual(i.number_of_shards, 2)
self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
self.assertEqual(i.shard_dimensions, [0, 0])
i.freeze()
i.set_number_of_shards(2)
i.set_tuple_shapes([[4, 4], [4]])
i.set_tuple_types([dtypes.float32, dtypes.int32])
i.set_shard_dimensions([0, 0])
with self.assertRaises(ValueError):
i.set_number_of_shards(1)
with self.assertRaises(ValueError):
i.set_tuple_shapes([[8, 8], [8]])
with self.assertRaises(ValueError):
i.set_tuple_types([dtypes.int32, dtypes.float32])
with self.assertRaises(ValueError):
i.set_shard_dimensions([1, 0])
self.assertEqual(i.number_of_shards, 2)
self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
self.assertEqual(i.shard_dimensions, [0, 0])
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,106 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Optimizer that implements cross-shard gradient reduction for TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.python.training import optimizer
class CrossShardOptimizer(optimizer.Optimizer):
"""A optimizer sums gradients across TPU shards."""
def __init__(self, opt, name="CrossShardOptimizer"):
super(CrossShardOptimizer, self).__init__(False, name)
self._opt = opt
def compute_gradients(self, *args, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in apply_gradients() so that the user can
    modify the gradients, for example by clipping with a per-replica global
    norm if needed. The global norm with aggregated gradients can be bad, as
    one replica's huge gradients can hurt the gradients from other replicas.
Args:
*args: Arguments for compute_gradients().
**kwargs: Keyword arguments for compute_gradients().
Returns:
A list of (gradient, variable) pairs.
"""
return self._opt.compute_gradients(*args, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
replicas, and then applies the real optimizer.
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
compute_gradients().
global_step: Optional Variable to increment by one after the
variables have been updated.
      name: Optional name for the returned operation. Defaults to the
name passed to the Optimizer constructor.
Returns:
An `Operation` that applies the gradients. If `global_step` was not None,
that operation also increments `global_step`.
Raises:
ValueError: If the grads_and_vars is malformed.
"""
summed_grads_and_vars = []
for (grad, var) in grads_and_vars:
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
"""Return a slot named "name" created for "var" by the Optimizer.
This simply wraps the get_slot() from the actual optimizer.
Args:
*args: Arguments for get_slot().
**kwargs: Keyword arguments for get_slot().
Returns:
The `Variable` for the slot if it was created, `None` otherwise.
"""
return self._opt.get_slot(*args, **kwargs)
def get_slot_names(self, *args, **kwargs):
"""Return a list of the names of slots created by the `Optimizer`.
This simply wraps the get_slot_names() from the actual optimizer.
Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().
Returns:
A list of strings.
"""
return self._opt.get_slot_names(*args, **kwargs)
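# Illustrative usage sketch, not part of the original change: wraps a plain
# gradient-descent optimizer so gradients are summed across TPU replicas
# before being applied. `loss` and the learning rate are hypothetical
# placeholders.
def _example_cross_shard_optimizer(loss):
  from tensorflow.python.training import gradient_descent  # pylint: disable=g-import-not-at-top

  base_optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
  cross_shard_optimizer = CrossShardOptimizer(base_optimizer)
  # minimize() calls compute_gradients() and then apply_gradients(), which
  # inserts a cross_replica_sum() for every non-None gradient.
  return cross_shard_optimizer.minimize(loss)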

View File

@ -0,0 +1,248 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import tensor_shape
_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0
# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
"""An object use to hold the sharding policy for a Tensor.
"""
def __init__(self):
self._number_of_shards = None
self._shard_dimension = None
self._frozen = False
def __str__(self):
if self.number_of_shards is None or self.shard_dimension is None:
return "ShardingPolicy(unset)"
else:
return ("ShardingPolicy(%d shards dimension %d)" %
(self.number_of_shards, self.shard_dimension))
def _fill_default_values(self):
if self._number_of_shards is None:
self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
if self._shard_dimension is None:
self._shard_dimension = tensor_shape.as_dimension(
_DEFAULT_SHARD_DIMENSION)
def freeze(self):
"""Prevents further modification to the sharding policy.
Any values that have not been set when freeze is called are set to
defaults. If the ShardingPolicy is already frozen, this is a NoOp.
"""
if not self._frozen:
self._fill_default_values()
self._frozen = True
@property
def number_of_shards(self):
"""Returns the number of shards in the policy or None if unspecified."""
return self._number_of_shards
def set_number_of_shards(self, number_of_shards):
"""Sets the number of shards for the current policy.
If the policy has been frozen then number_of_shards must match the
existing setting.
Args:
number_of_shards: The number of shards to use in the policy.
Raises:
ValueError: If the policy has been frozen and number_of_shards
differs from the frozen value; or number_of_shards <= 0.
"""
if self._frozen:
if self._number_of_shards != number_of_shards:
raise ValueError(
"Can't set sharding policy to use %d shards since it has been "
"frozen to use %d." % (number_of_shards, self._number_of_shards))
else:
if number_of_shards > 0:
self._number_of_shards = number_of_shards
else:
raise ValueError(
"Can't set sharding policy to use %s shards; value must be >0",
str(number_of_shards))
@property
def shard_dimension(self):
"""Returns the shard dimension of the policy or None if unspecified."""
return self._shard_dimension
def set_shard_dimension(self, shard_dimension):
"""Sets the shard dimension for the current policy.
If the policy has been frozen then shard_dimension must match the
existing setting.
Args:
shard_dimension: The shard dimension to use in the policy.
Raises:
ValueError: If the policy has been frozen and shard_dimension
differs from the frozen value, or shard_dimension can't be
interpreted as a Dimension.
"""
if self._frozen:
if self._shard_dimension != shard_dimension:
raise ValueError(
"Can't set shard dimension to %d since it has been frozen to "
"use %d." % (shard_dimension, self._shard_dimension))
else:
self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
def merge(self, other):
"""Merges the policy of another policy into the current policy.
Args:
other: The policy to merge into this one.
Raises:
ValueError: If this policy has been frozen and the merge conflicts with
the frozen policy.
"""
if other.number_of_shards is not None:
self.set_number_of_shards(other.number_of_shards)
if other.shard_dimension is not None:
self.set_shard_dimension(other.shard_dimension)
def get_sharded_shape(self, shape, shard_index=None):
"""Returns the shape of a shard of a full Tensor.
When given the shape of a 'full-size' Tensor, returns the shape of
the sub-Tensor after it has been sharded. Freezes the policy if it
has not yet been frozen.
Args:
shape: The shape of the full-size Tensor to be sharded.
shard_index: The index of the shard whose shape should be returned.
shard_index can be None for sharding policies that use the same
shape for every shard.
Returns:
The shape of the sharded version of the Tensor.
Raises:
ValueError: If shard_index is None when shards are of different
shapes; or shard_index is not None and
!(0<=shard_index<number_of_shards); or shape does not have at
least self.shard_dimension+1 dimensions; or the value of
shape's shard dimension is not a multiple of
self.number_of_shards
"""
if self._shard_dimension is None or self._number_of_shards is None:
# Don't raise an error if the config is unset.
return None
if shard_index is not None:
if shard_index < 0 or shard_index >= self.number_of_shards:
raise ValueError("shard_index %d, but must be in [0,%d)." %
(shard_index, self._number_of_shards))
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
if (dims[self._shard_dimension] % self._number_of_shards) != 0:
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
(shape.as_list(), self._number_of_shards,
self._shard_dimension))
    dims[self._shard_dimension] //= self._number_of_shards
return tensor_shape.as_shape(dims)
def _unshard_shape(self, shape):
"""Return the unsharded shape that would generate a given sharded shape.
Args:
shape: the sharded shape to unshard
Returns:
The unsharded shape.
Raises:
ValueError: if shape is unknown or does not contain
self.shard_dimension
TypeError: if shape is not convertible to a TensorShape
"""
shape = tensor_shape.as_shape(shape)
if self._number_of_shards == 1:
# Don't do anything when there's only one shard.
return shape
ndims = shape.ndims
if ndims is None:
raise ValueError("shape must be a specified shape not Unknown")
if ndims <= self._shard_dimension:
raise ValueError("shape %s does not contain shard_dimension %d" %
(shape.as_list(), self._shard_dimension))
dims = shape.as_list()
dims[self._shard_dimension] *= self._number_of_shards
return tensor_shape.as_shape(dims)
def get_unsharded_shape(self, shapes):
"""Returns the shape of an unsharded Tensor given a list of shards.
When given a list of shapes of shards, returns the shape of the
unsharded Tensor that would generate the shards. Sets defaults for the
policy if number_of_shards or shard_dimension is None.
Args:
shapes: The shapes of the Tensor shards to be combined.
Returns:
The shape of the unsharded version of the Tensor.
Raises:
ValueError: if shapes is not a list of length
self.number_of_shards; or any element of shapes is not a valid
shape consistent with the sharding policy; or the list of
shapes is not a valid sharding of a full shape.
TypeError: if an element of shapes is not convertible to a
TensorShape
"""
self._fill_default_values()
if len(shapes) != self.number_of_shards:
raise ValueError(
"shapes is %s but must be a list of length number_of_shards=%d" % (
str(shapes), self.number_of_shards))
unsharded_shapes = [self._unshard_shape(s) for s in shapes]
for i in xrange(self.number_of_shards - 1):
if unsharded_shapes[i] != unsharded_shapes[self.number_of_shards - 1]:
raise ValueError(
"sharded shapes %s are not consistent shards of a full shape "
"sharded %d ways along dimension %d" % (
str(shapes), self.number_of_shards, self.shard_dimension))
return unsharded_shapes[0]
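# Illustrative usage sketch, not part of the original change: splits an [8, 16]
# shape four ways along dimension 0 and reconstructs the original shape from
# the shard shapes. The sizes are arbitrary placeholder values.
def _example_sharding_policy():
  policy = ShardingPolicy()
  policy.set_number_of_shards(4)
  policy.set_shard_dimension(0)
  sharded_shape = policy.get_sharded_shape([8, 16])    # TensorShape([2, 16])
  unsharded_shape = policy.get_unsharded_shape([sharded_shape] * 4)
  return sharded_shape, unsharded_shape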

View File

@ -0,0 +1,138 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tpu_function helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
class ShardingTest(test.TestCase):
def testFreeze(self):
"""Tests that freezing a policy applies default values."""
p1 = tpu_sharding.ShardingPolicy()
p1.freeze()
self.assertEqual(p1.number_of_shards,
tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
p2 = tpu_sharding.ShardingPolicy()
p2.set_number_of_shards(17)
p2.set_shard_dimension(23)
p2.freeze()
self.assertEqual(p2.number_of_shards, 17)
self.assertEqual(p2.shard_dimension, 23)
def testFrozen(self):
"""Tests that frozen policies can't be changed."""
p1 = tpu_sharding.ShardingPolicy()
p1.freeze()
with self.assertRaises(ValueError):
p1.set_number_of_shards(17)
with self.assertRaises(ValueError):
p1.set_shard_dimension(22)
def testStr(self):
"""Tests the string representation."""
p1 = tpu_sharding.ShardingPolicy()
self.assertEqual(str(p1), "ShardingPolicy(unset)")
p1.set_number_of_shards(17)
self.assertEqual(str(p1), "ShardingPolicy(unset)")
p1.set_shard_dimension(8)
self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")
def testMerge(self):
"""Tests that merging works."""
p1 = tpu_sharding.ShardingPolicy()
p1.set_number_of_shards(17)
p1.set_shard_dimension(23)
p2 = tpu_sharding.ShardingPolicy()
p2.merge(p1)
self.assertEqual(p2.number_of_shards, 17)
self.assertEqual(p2.shard_dimension, 23)
p1 = tpu_sharding.ShardingPolicy()
p1.set_shard_dimension(12)
p2.merge(p1)
self.assertEqual(p2.number_of_shards, 17)
self.assertEqual(p2.shard_dimension, 12)
p2.freeze()
p2.merge(p1)
self.assertEqual(p2.number_of_shards, 17)
self.assertEqual(p2.shard_dimension, 12)
p1.set_number_of_shards(1)
with self.assertRaises(ValueError):
p2.merge(p1)
p1 = tpu_sharding.ShardingPolicy()
p1.set_number_of_shards(17)
p2.merge(p1)
p1.set_shard_dimension(2)
with self.assertRaises(ValueError):
p2.merge(p1)
def testGetShardedShape(self):
"""Tests getting a sharded shape."""
p = tpu_sharding.ShardingPolicy()
p.set_number_of_shards(3)
p.set_shard_dimension(1)
self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
p.freeze()
with self.assertRaises(ValueError):
p.set_shard_dimension(0)
with self.assertRaises(ValueError):
_ = p.get_sharded_shape([4, 9], shard_index=4)
with self.assertRaises(ValueError):
_ = p.get_sharded_shape([4, 9], shard_index=-1)
with self.assertRaises(TypeError):
_ = p.get_sharded_shape("not_a_shape")
with self.assertRaises(ValueError):
_ = p.get_sharded_shape(tensor_shape.TensorShape(None))
with self.assertRaises(ValueError):
_ = p.get_sharded_shape([4, 10], shard_index=-1)
def testGetUnshardedShape(self):
"""Tests getting an unsharded shape."""
p = tpu_sharding.ShardingPolicy()
p.set_number_of_shards(2)
p.set_shard_dimension(1)
self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
with self.assertRaises(ValueError):
_ = p.get_unsharded_shape([[4, 3]])
with self.assertRaises(ValueError):
_ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
with self.assertRaises(ValueError):
_ = p.get_unsharded_shape([[4, 3], [4, 2]])
with self.assertRaises(TypeError):
_ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
with self.assertRaises(ValueError):
_ = p.get_unsharded_shape([None, [4, 3]])
with self.assertRaises(ValueError):
_ = p.get_unsharded_shape([[2], [4, 3]])
def testScalar(self):
"""Tests sharding and unsharding scalars."""
p = tpu_sharding.ShardingPolicy()
p.freeze()
self.assertEqual(p.get_sharded_shape([]), [])
self.assertEqual(p.get_unsharded_shape([[]]), [])
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,213 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Library for constructing a training loop, suitable for TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
`condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed tensors dequeued from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
single boolean value that determines whether iteration
continues. `body` must return an updated list of values for the
loop-carried tensors.
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body`.
name: an optional name for the loop.
Returns:
The final values of the loop-carried tensors.
Raises:
TypeError: if body or condition has the wrong signature.
"""
# Converts inputs to Tensors.
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
x in inputs]
input_types = [x.dtype for x in inputs]
input_arity = len(inputs)
body_arg_error = tpu_function.check_function_argument_count(
body, input_arity, infeed_queue)
if body_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s, but the loop body needs %s" % (
input_arity, str([i.name for i in inputs]), body_arg_error))
else:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s and %d additional inputs from "
"infeed, but the computation needs %s" % (input_arity, str(
[i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
body_arg_error))
condition_arg_error = tpu_function.check_function_argument_count(
condition, input_arity, None)
if condition_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s" % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
else:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s. Note that infeed is not passed to the loop "
"condition." % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
def condition_wrapper(*inputs):
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
return condition(*inputs)
def body_wrapper(*inputs):
"""Wrapper around `body` that handles infeed queues and control deps."""
inputs = list(inputs)
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
# Runs `body` with the dequeue_ops appended.
if infeed_queue:
number_of_shards = tpu_function.get_tpu_context().number_of_shards
if number_of_shards is None:
raise ValueError("Can't build training loop with infeed when there is "
"no tpu_shard_context. Are you building a loop or "
"graph directly rather than from inside tpu.rewrite, "
"tpu.batch_parallel, tpu.shard, or tpu.replicate?")
infeed_queue.set_number_of_shards(number_of_shards)
dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
else:
dequeue_ops = []
outputs = body(*(inputs + dequeue_ops))
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs
if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
"TPU training loop body must return zero or more Tensor values "
"followed by zero or more Operations.")
output_types = [op.dtype for op in output_tensors]
if input_types != output_types:
raise TypeError(
"Mismatch between input types and output types for training loop "
"body: {} vs {}".format(input_types, output_types))
# Add the dequeue operations to output_operations to ensure they are run
# by the loop, even if the programmer's loop body does not use them.
output_operations += dequeue_ops
# Add a dummy output, if needed.
if not output_tensors:
output_tensors = array_ops.constant(0)
if output_operations:
# TODO(phawkins): in principle this is too restrictive since it serializes
# the training loop steps. In practice it does not matter since this loop
# will be compiled by XLA.
return control_flow_ops.tuple(output_tensors,
control_inputs=output_operations)
else:
return output_tensors
# If the body has arity 0, add a dummy loop-carried value to which we can add
# control dependencies from any side-effecting operations.
if input_arity == 0:
inputs = [array_ops.constant(0)]
return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs,
name=name)
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop that executes a fixed number of interations.
The set of loop-carried tensors correspond to `inputs`.
`body` must be a function that takes and returns the values of the
loop-carried tensors.
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body`.
name: an optional name for the loop.
Returns:
The final values of the loop-carried tensors.
Raises:
    ValueError: if there is a type or arity mismatch between `body` and the
      loop-carried values.
"""
def _convert_to_list(xs):
if not isinstance(xs, (list, tuple)):
return [xs]
else:
return list(xs)
def cond(i, *args):
del args
return i < n
def body_wrapper(i, *args):
return [i + 1] + _convert_to_list(body(*args))
inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
outputs = while_loop(
cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
outputs = _convert_to_list(outputs)
if len(outputs) == 1:
# Returns the Op rather than an empty list.
return outputs[0].op
else:
return outputs[1:]
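
As a rough usage sketch (graph construction only): assuming this file lands as `tensorflow.contrib.tpu.python.tpu.training_loop`, a fixed-length loop over a single loop-carried tensor could be built as below. Executing the resulting ops still requires a TPU context such as `tpu.shard` or `tpu.rewrite`.

# Hypothetical import path; the module name is assumed from this file's role.
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.ops import math_ops

def _step(total):
  # One loop-carried tensor; each iteration adds 2 to it.
  return math_ops.add(total, 2)

# Run the step 10 times starting from 0. `repeat` returns the final
# loop-carried tensors as a list; here the single tensor evaluates to 20.
[final_total] = training_loop.repeat(10, _step, inputs=[0])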

View File

@ -144,6 +144,9 @@ sh_binary(
"//tensorflow/contrib/slim:slim",
"//tensorflow/contrib/slim/python/slim/data:data_pip",
"//tensorflow/contrib/slim/python/slim/nets:nets_pip",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/specs:specs",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",