Introduce Tensor Processing Unit-specific operations and estimators. This commit enables TensorFlow users to utilize functionality available in Cloud TPUs. Please note that these APIs are in alpha with Cloud TPUs, and will likely undergo significant changes during the Cloud TPU alpha and beta process.
Interested individuals and organizations should sign up for a Cloud TPU beta at https://ai.google/tools/cloud-tpus/. PiperOrigin-RevId: 159636962
This commit is contained in:
parent a936b239cf
commit a4660cce81
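For orientation, here is a minimal usage sketch of the API this commit introduces. It is illustrative only: it relies on the `replicate`, `initialize_system` and `shutdown_system` helpers defined below in tensorflow/contrib/tpu/python/tpu/tpu.py, it needs actual TPU devices to run, and the alpha API may change.

    # Illustrative sketch (alpha API; subject to change). Requires TPU devices.
    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tpu

    def computation(x):
      # Per-replica computation: each replica receives its own input tensor.
      return x * 2.0

    # inputs is indexed [replica_num][input_num]; two replicas here.
    inputs = [[tf.constant([1.0])], [tf.constant([2.0])]]
    outputs = tpu.replicate(computation, inputs)  # [replica_num][output_num]

    with tf.Session() as sess:
      sess.run(tpu.initialize_system())  # configures the distributed TPU system
      print(sess.run(outputs))
      sess.run(tpu.shutdown_system())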
@@ -304,6 +304,7 @@ filegroup(
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/tpu:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",

@@ -71,6 +71,9 @@ py_library(
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],

@@ -106,6 +109,7 @@ cc_library(
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
],
)

@@ -65,6 +65,7 @@ from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
from tensorflow.contrib import tfprof
from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.ndlstm import python as ndlstm

@@ -84,6 +84,12 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc"
"${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/cross_replica_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/infeed_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/outfeed_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/replication_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc"
)
list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
endif(tensorflow_BUILD_CONTRIB_KERNELS)

@@ -67,6 +67,10 @@ file(GLOB_RECURSE tensor_forest_hybrid_srcs
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/*.cc"
)

file(GLOB_RECURSE tpu_ops_srcs
"${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/*.cc"
)

GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")

@@ -83,6 +87,7 @@ GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensor
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")

########################################################

@@ -515,6 +515,11 @@ add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not impl
#add_python_module("tensorflow/contrib/tfprof/python")
#add_python_module("tensorflow/contrib/tfprof/python/tools")
#add_python_module("tensorflow/contrib/tfprof/python/tools/tfprof")
add_python_module("tensorflow/contrib/tpu")
add_python_module("tensorflow/contrib/tpu/ops")
add_python_module("tensorflow/contrib/tpu/python")
add_python_module("tensorflow/contrib/tpu/python/ops")
add_python_module("tensorflow/contrib/tpu/python/tpu")
add_python_module("tensorflow/contrib/training")
add_python_module("tensorflow/contrib/training/python")
add_python_module("tensorflow/contrib/training/python/training")

224 tensorflow/contrib/tpu/BUILD Normal file
@@ -0,0 +1,224 @@
|
||||
# Description: Operations defined for Cloud TPUs
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//learning/brain:__subpackages__",
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_gen_op_wrapper_py",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
cc_library(
|
||||
name = "all_ops",
|
||||
deps = [
|
||||
":cross_replica_ops_op_lib",
|
||||
":infeed_ops_op_lib",
|
||||
":outfeed_ops_op_lib",
|
||||
":replication_ops_op_lib",
|
||||
":tpu_configuration_ops_op_lib",
|
||||
":tpu_sendrecv_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_estimator",
|
||||
srcs = [
|
||||
"python/tpu/tpu_config.py",
|
||||
"python/tpu/tpu_estimator.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu",
|
||||
":tpu_py",
|
||||
":training_loop",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"cross_replica_ops",
|
||||
"infeed_ops",
|
||||
"outfeed_ops",
|
||||
"replication_ops",
|
||||
"tpu_configuration_ops",
|
||||
"tpu_sendrecv_ops",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_tpu_ops.so",
|
||||
srcs = [
|
||||
"ops/cross_replica_ops.cc",
|
||||
"ops/infeed_ops.cc",
|
||||
"ops/outfeed_ops.cc",
|
||||
"ops/replication_ops.cc",
|
||||
"ops/tpu_configuration_ops.cc",
|
||||
"ops/tpu_sendrecv_ops.cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "tpu_ops",
|
||||
deps = [
|
||||
":cross_replica_ops_op_lib",
|
||||
":infeed_ops_op_lib",
|
||||
":outfeed_ops_op_lib",
|
||||
":replication_ops_op_lib",
|
||||
":tpu_configuration_ops_op_lib",
|
||||
":tpu_sendrecv_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "tpu_py",
|
||||
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
|
||||
dso = [":python/ops/_tpu_ops.so"],
|
||||
kernels = [
|
||||
":all_ops",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_ops",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_helper_library",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu",
|
||||
":tpu_feed",
|
||||
":tpu_function",
|
||||
":tpu_py",
|
||||
":tpu_sharding",
|
||||
":training_loop",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_function",
|
||||
srcs = ["python/tpu/tpu_function.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_feed",
|
||||
":tpu_py",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu",
|
||||
srcs = [
|
||||
"python/tpu/__init__.py",
|
||||
"python/tpu/tpu.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_py",
|
||||
":training_loop",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_sharding",
|
||||
srcs = ["python/tpu/tpu_sharding.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_feed",
|
||||
srcs = ["python/tpu/tpu_feed.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_py",
|
||||
":tpu_sharding",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "training_loop",
|
||||
srcs = [
|
||||
"python/tpu/tpu_optimizer.py",
|
||||
"python/tpu/training_loop.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tpu_function",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_sharding_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_sharding_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_sharding",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_infeed_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_infeed_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_feed",
|
||||
":tpu_sharding",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_function_test",
|
||||
size = "small",
|
||||
srcs = ["python/tpu/tpu_function_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_function",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
28 tensorflow/contrib/tpu/__init__.py Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Ops related to Tensor Processing Units."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
|
||||
from tensorflow.contrib.tpu.python.tpu import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
37 tensorflow/contrib/tpu/ops/cross_replica_ops.cc Normal file
@@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("CrossReplicaSum")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
An Op to sum inputs across replicated TPU instances. Each
|
||||
instance supplies its own input, and the output of each is the sum of
|
||||
all the inputs.
|
||||
|
||||
input: The local input to the sum.
|
||||
output: The sum of all the distributed inputs.
|
||||
T: The type of elements to be summed.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
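A hedged usage sketch for CrossReplicaSum: inside a replicated computation, the generated Python wrapper (assumed here to be `tpu_ops.cross_replica_sum`, following the usual snake_case naming of generated op wrappers) gives every replica the sum of all replicas' inputs.

    # Sketch only; the wrapper name cross_replica_sum is an assumption.
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    def computation(per_replica_value):
      # Each replica contributes its local float tensor; every replica
      # receives the element-wise sum (e.g. for gradient aggregation).
      return tpu_ops.cross_replica_sum(per_replica_value)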
|
107 tensorflow/contrib/tpu/ops/infeed_ops.cc Normal file
@@ -0,0 +1,107 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("InfeedDequeue")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
|
||||
TensorShapeProto shape_proto;
|
||||
shape.AsProto(&shape_proto);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
A placeholder op for a value that will be fed into the computation.
|
||||
|
||||
output: A tensor that will be provided using the infeed mechanism.
|
||||
dtype: The type of elements in the tensor.
|
||||
shape: The shape of the tensor.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("InfeedEnqueue")
|
||||
.Input("input: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape = {}")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op which feeds a single Tensor value into the computation.
|
||||
|
||||
input: A tensor that will be provided using the infeed mechanism.
|
||||
dtype: The type of elements in the tensor.
|
||||
shape: The shape of the tensor.
|
||||
device_ordinal: The TPU device to use. This should be -1 when the Op
|
||||
is running on a TPU device, and >= 0 when the Op is running on the CPU
|
||||
device.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("InfeedEnqueueTuple")
|
||||
.Input("inputs: dtypes")
|
||||
.Attr("dtypes: list(type)")
|
||||
.Attr("shapes: list(shape)")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op which feeds multiple Tensor values into the computation as an XLA tuple.
|
||||
|
||||
inputs: A list of tensors that will be provided using the infeed mechanism.
|
||||
dtypes: The element types of each element in `inputs`.
|
||||
shapes: The shapes of each tensor in `inputs`.
|
||||
device_ordinal: The TPU device to use. This should be -1 when the Op
|
||||
is running on a TPU device, and >= 0 when the Op is running on the CPU
|
||||
device.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("InfeedDequeueTuple")
|
||||
.Output("outputs: dtypes")
|
||||
.Attr("dtypes: list(type)")
|
||||
.Attr("shapes: list(shape)")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
TensorShapeProto shape_proto;
|
||||
shapes[i].AsProto(&shape_proto);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
||||
c->set_output(i, out);
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
A placeholder op for multiple values that will be fed into the computation
|
||||
simultaneously as an XLA tuple.
|
||||
|
||||
outputs: A list of tensors that will be provided using the infeed mechanism.
|
||||
dtypes: The element types of each element in `outputs`.
|
||||
shapes: The shapes of each tensor in `outputs`.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
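A rough sketch of how the enqueue/dequeue pair is meant to be used. The snake_case wrapper names `infeed_enqueue` and `infeed_dequeue` are assumptions based on the op names above; the InfeedQueue helper in python/tpu/tpu_feed.py wraps this plumbing.

    # Sketch only: the host enqueues, the TPU computation dequeues.
    import tensorflow as tf
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    # Host (CPU) side: device_ordinal >= 0 selects the target TPU.
    features = tf.placeholder(tf.float32, shape=[128, 784])
    enqueue_op = tpu_ops.infeed_enqueue(
        features, shape=[128, 784], device_ordinal=0)

    def computation():
      # TPU side: device_ordinal is left at its default of -1.
      x = tpu_ops.infeed_dequeue(dtype=tf.float32, shape=[128, 784])
      return tf.reduce_sum(x)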
|
106 tensorflow/contrib/tpu/ops/outfeed_ops.cc Normal file
@@ -0,0 +1,106 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("OutfeedEnqueue")
|
||||
.Input("input: dtype")
|
||||
.Attr("dtype: type")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op which emits a single Tensor value from an XLA computation.
|
||||
|
||||
input: A tensor that will be inserted into the outfeed queue.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("OutfeedEnqueueTuple")
|
||||
.Input("inputs: dtypes")
|
||||
.Attr("dtypes: list(type)")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op which emits multiple Tensor values from an XLA computation.
|
||||
|
||||
inputs: A list of tensors that will be inserted into the outfeed queue as an
|
||||
XLA tuple.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("OutfeedDequeue")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Retrieves a single tensor from the computation outfeed. This operation will
|
||||
block indefinitely until data is available.
|
||||
|
||||
output: A tensor that will be read from the device outfeed.
|
||||
dtype: The type of elements in the tensor.
|
||||
shape: The shape of the tensor.
|
||||
device_ordinal: The TPU device to use. This should be -1 when the Op
|
||||
is running on a TPU device, and >= 0 when the Op is running on the CPU
|
||||
device.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("OutfeedDequeueTuple")
|
||||
.Output("outputs: dtypes")
|
||||
.Attr("dtypes: list(type)")
|
||||
.Attr("shapes: list(shape)")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
std::vector<DataType> dtypes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
|
||||
if (shapes.size() != dtypes.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Incorrect number of output shapes specified");
|
||||
}
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
|
||||
c->set_output(i, out);
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Retrieves multiple values that will be emitted by the computation as an XLA
|
||||
tuple. This operation will block indefinitely until data is available.
|
||||
Output `i` corresponds to XLA tuple element `i`.
|
||||
|
||||
outputs: A list of tensors that will be read from the outfeed.
|
||||
dtypes: The element types of each element in `outputs`.
|
||||
shapes: The shapes of each tensor in `outputs`.
|
||||
device_ordinal: The TPU device to use. This should be -1 when the Op
|
||||
is running on a TPU device, and >= 0 when the Op is running on the CPU
|
||||
device.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
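And the mirror image for outfeed, again as a hedged sketch with assumed snake_case wrapper names: the computation enqueues a value, and the host dequeues it with matching dtype and shape.

    # Sketch only: the TPU computation enqueues, the host dequeues.
    import tensorflow as tf
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    def computation(x):
      y = x * x
      enqueue_op = tpu_ops.outfeed_enqueue(y)
      with tf.control_dependencies([enqueue_op]):
        return tf.identity(y)

    # Host side: dtype/shape must match what the computation enqueued.
    result = tpu_ops.outfeed_dequeue(
        dtype=tf.float32, shape=[128], device_ordinal=0)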
|
87 tensorflow/contrib/tpu/ops/replication_ops.cc Normal file
@@ -0,0 +1,87 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("TPUReplicatedInput")
|
||||
.Input("inputs: N * T")
|
||||
.Output("output: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle cur = c->input(c->num_inputs() - 1);
|
||||
for (int i = c->num_inputs() - 2; i >= 0; --i) {
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
|
||||
"From merging shape ", i,
|
||||
" with other shapes.");
|
||||
}
|
||||
c->set_output(0, cur);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(
|
||||
"Operator that connects N unreplicated inputs to an N-way "
|
||||
"replicated TPU computation.");
|
||||
|
||||
REGISTER_OP("TPUReplicatedOutput")
|
||||
.Input("input: T")
|
||||
.Output("outputs: num_replicas * T")
|
||||
.Attr("num_replicas: int >= 1")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->input(0));
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(
|
||||
"Operator that connects the output of an N-way replicated TPU "
|
||||
"computation to N separate outputs.");
|
||||
|
||||
REGISTER_OP("TPUReplicate")
|
||||
.Attr("computation: func")
|
||||
.Attr("num_replicas: int >= 1")
|
||||
.Attr("global_tpu_id: list(int) = []")
|
||||
.Attr("Tinputs: list(type) >= 0")
|
||||
.Attr("Tbroadcast_inputs: list(type) >= 0")
|
||||
.Attr("NumVariables: int >= 0")
|
||||
.Attr("output_types: list(type) >= 0")
|
||||
.Input("inputs: Tinputs")
|
||||
.Input("broadcast_inputs: Tbroadcast_inputs")
|
||||
.Input("variables: NumVariables * resource")
|
||||
.Output("outputs: output_types")
|
||||
.Doc(R"doc(
|
||||
Runs replicated computations on a distributed TPU system.
|
||||
|
||||
computation: a function containing the computation to run.
|
||||
num_replicas: the number of replicas of the computation to run.
|
||||
global_tpu_id: map from device to global tpu id.
|
||||
Tinputs: the types of the arguments to 'computation'.
|
||||
inputs: the inputs to 'computation', flattened, in replica-major order.
|
||||
Tbroadcast_inputs: the types of the additional arguments to broadcast to all
|
||||
replicas.
|
||||
broadcast_inputs: additional arguments to broadcast to all replicas. The
|
||||
broadcast inputs are appended to the per-replica inputs when calling
|
||||
computation.
|
||||
output_types: the types of the outputs of 'computation'.
|
||||
outputs: the outputs of 'computation'.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
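These ops are normally emitted by tpu.replicate() in python/tpu/tpu.py rather than written by hand; a minimal manual sketch of the fan-in/fan-out structure it builds:

    # Sketch of the fan-in/fan-out that tpu.replicate() constructs.
    import tensorflow as tf
    from tensorflow.contrib.tpu.python.ops import tpu_ops

    x0, x1 = tf.constant(1.0), tf.constant(2.0)      # one input per replica
    merged = tpu_ops.tpu_replicated_input([x0, x1])  # fan-in: N inputs -> 1
    # ... the replicated computation consumes `merged` ...
    outs = tpu_ops.tpu_replicated_output(merged, 2)  # fan-out: 1 -> N outputs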
|
213 tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc Normal file
@@ -0,0 +1,213 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
// Configuring a distributed TPU system is achieved by running
|
||||
// the following Ops:
|
||||
//
|
||||
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
|
||||
// host. This is needed in case the system had previously been
|
||||
// configured. It returns, for each host, the number of TPU chips on
|
||||
// the host.
|
||||
//
|
||||
// 2 Run _ConfigureDistributedTPU on TPU_SYSTEM. Takes as input the
|
||||
// number of chips on each host. Validates that all hosts have the
|
||||
// same number of chips, and that the chips are consistent with the
|
||||
// topology set by flags. Has a single output which is a proto
|
||||
// describing the requested system configuration, which is sent to all
|
||||
// hosts.
|
||||
//
|
||||
// 3 Run _InitializeHostForDistributedTPU on the CPU of each host,
|
||||
// taking as input the output from ConfigureDistributedTPU. Has a
|
||||
// single Tensor output which is a vector of int32 indicating, for
|
||||
// each TPU on the host, what its global TPU system id is.
|
||||
//
|
||||
// 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the
|
||||
// outputs from all the _InitializeHostForDistributedTPU
|
||||
// Ops. _WaitForDistributedTPU has an attr host_specs which is a
|
||||
// vector<string> giving the partial device spec for each host. These
|
||||
// partial specs are combined in the Op with the outputs from the host
|
||||
// initialization Ops to construct a mapping from full TPU device
|
||||
// specs to global TPU ids. Has a single Tensor output which is a
|
||||
// matrix of int32 indicating, for each host (outer dimension) and for
|
||||
// each TPU on the host (inner dimension) what that TPU's global id
|
||||
// is. _WaitForDistributedTPU also waits for the TPU distributed
|
||||
// system to initialize fully, which may take several minutes for a
|
||||
// large system.
|
||||
//
|
||||
// 5 Run _SetGlobalTPUArray on the CPU of each host, taking as input
|
||||
// the output from _WaitForDistributedTPU. This Op tells each host the
|
||||
// global Id of every TPU on every host.
|
||||
//
|
||||
// Most user code works by placing the ConfigureDistributedTPU Op on
|
||||
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
|
||||
// the subgraph described above.
|
||||
//
|
||||
//
|
||||
// A distributed TPU system can be cleanly shut down by running
|
||||
// the following Ops:
|
||||
//
|
||||
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
|
||||
// host.
|
||||
//
|
||||
// 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where
|
||||
// _ConfigureDistributedTPU was run. The Op will return an error if no
|
||||
// system is configured.
|
||||
//
|
||||
//
|
||||
// Most user code works by placing the ShutdownDistributedTPU Op on
|
||||
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
|
||||
// the subgraph described above.
|
||||
|
||||
REGISTER_OP("_ConfigureDistributedTPU")
|
||||
.Input("inputs: N * int32")
|
||||
.Output("output: string")
|
||||
.Attr("N: int >= 1")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
// Validate that all the inputs are scalars.
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input));
|
||||
}
|
||||
c->set_output(0, c->Scalar());
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
An op that sets up the centralized structures for a distributed TPU
|
||||
system.
|
||||
|
||||
inputs: A scalar tensor for each host indicating how many TPU chips
|
||||
there are on the host.
|
||||
output: A tensor containing a TPUHostConfiguration proto serialized to
|
||||
a string, containing the information necessary to initialize the chips
|
||||
in a host.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_WaitForDistributedTPU")
|
||||
.Input("inputs: N * int32")
|
||||
.Output("global_tpu_array: int32")
|
||||
.Attr("host_specs: list(string)")
|
||||
.Attr("startup_timeout_sec: int = 20")
|
||||
.Attr("N: int")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
// Validate that all the inputs have the same vector shape.
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input));
|
||||
}
|
||||
c->set_output(0, c->UnknownShapeOfRank(2));
|
||||
return ::tensorflow::Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
An op that blocks execution until a distributed TPU system has
|
||||
started up. This Op must be run on the same TPU_SYSTEM device as
|
||||
_ConfigureDistributedTPU, and takes as inputs the outputs from the
|
||||
_InitializeHostForDistributedTPU Ops.
|
||||
|
||||
inputs: For each initialized host, a vector giving the global TPU id
|
||||
of each TPU on the host.
|
||||
global_tpu_array: A two-dimensional array. For each host (the outer
|
||||
dimension) the array lists the global ids of the TPUs on that host.
|
||||
host_specs: For each initialized host, the partial device specification
|
||||
indicating job, replica, and task. Combining this spec with
|
||||
'/device:TPU:k' gives the full device name of the k'th TPU on the
|
||||
host.
|
||||
startup_timeout_sec: The number of seconds to wait for the TPU system
|
||||
to stabilize.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_SetGlobalTPUArray")
|
||||
.Input("global_tpu_array: int32")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
|
||||
return ::tensorflow::Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
An op that informs a host of the global ids of all of the TPUs in the
|
||||
system.
|
||||
|
||||
global_tpu_array: A two-dimensional array. For each host (the outer
|
||||
dimension) the array lists the global ids of the TPUs on that host.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
||||
An op that shuts down a running distributed TPU system. The Op returns
|
||||
an error if no system is running. This Op must be run on the same
|
||||
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
|
||||
to start the system, and must be run only after
|
||||
_DisconnectHostFromDistributedTPUSystem has completed on every host in
|
||||
the system.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_InitializeHostForDistributedTPU")
|
||||
.Input("input: string")
|
||||
.Output("tpu_ids: int32")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input));
|
||||
c->set_output(0, c->Vector(c->UnknownDim()));
|
||||
return ::tensorflow::Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
An op that connects each chip on the host to a centralized UberDriver to allow
|
||||
them to operate as a distributed system with chips in other hosts.
|
||||
|
||||
input: A string containing the address of the UberDriver to connect to.
|
||||
tpu_ids: A vector containing the global TPU id of each TPU on the host.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
|
||||
.Output("number_of_tpu_chips: int32")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op that disconnects the TPUs on a host from a running distributed
|
||||
TPU system.
|
||||
|
||||
number_of_tpu_chips: A scalar tensor containing the number of TPU
|
||||
chips on the host.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ConfigureDistributedTPU")
|
||||
.Output("global_tpu_array: int32")
|
||||
.Attr("embedding_config: string = ''")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
An op that sets up the centralized structures for a distributed TPU
|
||||
system.
|
||||
|
||||
global_tpu_array: A two-dimensional array. For each host (the outer
|
||||
dimension) the array lists the global ids of the TPUs on that host.
|
||||
embedding_config: Internal use.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
||||
An op that shuts down a running distributed TPU system. The Op returns
|
||||
an error if no system is running.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
46 tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc Normal file
@@ -0,0 +1,46 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("_TPUSend")
|
||||
.Input("tensor: T")
|
||||
.Attr("T: type")
|
||||
.Attr("tensor_name: string")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor over the TPU fabric.
|
||||
|
||||
tensor: The tensor to send.
|
||||
tensor_name: The name of the tensor to send.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_TPURecv")
|
||||
.Output("tensor: T")
|
||||
.Attr("T: type")
|
||||
.Attr("tensor_name: string")
|
||||
.Attr("shape: shape")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor over the TPU fabric.
|
||||
|
||||
tensor: The tensor to receive.
|
||||
tensor_name: The name of the tensor to receive.
|
||||
shape: The shape of the input tensor.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
38 tensorflow/contrib/tpu/python/ops/tpu_ops.py Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Operations for TPUs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
||||
from tensorflow.contrib.tpu.ops.gen_tpu_ops import *
|
||||
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.platform import resource_loader
|
||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
||||
|
||||
_tpu_ops = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_tpu_ops.so"))
|
||||
else:
|
||||
# We have already built the appropriate libraries into the binary via CMake
|
||||
# if we have built contrib, so we don't need this
|
||||
pass
|
20 tensorflow/contrib/tpu/python/tpu/__init__.py Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Ops related to Tensor Processing Units."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
583 tensorflow/contrib/tpu/python/tpu/tpu.py Normal file
@@ -0,0 +1,583 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ======================================
|
||||
|
||||
"""Library of TPU helper functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
def initialize_system(embedding_config=None, job=None):
|
||||
"""Initializes a distributed TPU system for use with TensorFlow.
|
||||
|
||||
Args:
|
||||
embedding_config: If not None, an EmbeddingLayerConfiguration proto
|
||||
describing the desired configuration of the hardware embedding lookup
|
||||
tables. If embedding_config is None, no hardware embeddings can be used.
|
||||
job: The job (the XXX in TensorFlow device specification /job:XXX)
|
||||
that contains the TPU devices that will be initialized. If job=None
|
||||
it is assumed there is only one job in the TensorFlow flock, and an
|
||||
error will be returned if this assumption does not hold.
|
||||
Returns:
|
||||
Op which, when executed, will initialize the system.
|
||||
"""
|
||||
if job is None:
|
||||
device_name = "/replica:0/task:0/device:TPU_SYSTEM:0"
|
||||
else:
|
||||
device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job
|
||||
config_string = ("" if embedding_config is None else
|
||||
embedding_config.SerializeToString())
|
||||
with ops.device(device_name):
|
||||
init_distributed_tpu = tpu_ops.configure_distributed_tpu(
|
||||
embedding_config=config_string)
|
||||
return init_distributed_tpu
|
||||
|
||||
|
||||
def shutdown_system(job=None):
|
||||
"""Shuts down a running a distributed TPU system."""
|
||||
if job is None:
|
||||
device_name = "/replica:0/task:0/device:TPU_SYSTEM:0"
|
||||
else:
|
||||
device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job
|
||||
with ops.device(device_name):
|
||||
shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
|
||||
return shutdown_distributed_tpu
|
||||
|
||||
|
||||
def core(num):
|
||||
"""Returns the device name for a core in a replicated TPU computation.
|
||||
|
||||
Args:
|
||||
num: the virtual core number within each replica to which operators should
|
||||
be assigned.
|
||||
Returns:
|
||||
A device name, suitable for passing to tf.device().
|
||||
"""
|
||||
return "device:TPU_REPLICATED_CORE:{}".format(num)
|
||||
|
||||
|
||||
# Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.) context.
|
||||
# In
|
||||
#
|
||||
# XXX
|
||||
# with tpu.rewrite(...):
|
||||
# YYY
|
||||
# with tpu.outside_all_rewrites():
|
||||
# ZZZ
|
||||
#
|
||||
# the Ops in ZZZ are added outside the scope of the rewrite().
|
||||
# TODO(phawkins): currently outside_all_rewrites() pops out of all nested
|
||||
# control flow scopes, for example loops. It would make more sense if it only
|
||||
# popped out of a single scope.
|
||||
@contextlib.contextmanager
|
||||
def outside_all_rewrites():
|
||||
"""Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.)."""
|
||||
with ops.control_dependencies(None):
|
||||
yield
|
||||
|
||||
|
||||
class TPUReplicateContext(control_flow_ops.ControlFlowContext):
|
||||
"""A ControlFlowContext for nodes inside a TPU computation.
|
||||
|
||||
The primary role of TPUReplicateContext is to mark operators inside a
|
||||
tpu.replicate() computation with attributes:
|
||||
* _tpu_replicate=XYZ, where XYZ is a unique name, and
|
||||
* _tpu_num_replicas=k, where k is the number of replicas.
|
||||
|
||||
We use a ControlFlowContext to perform the annotation since it
|
||||
integrates with Tensorflow constructs like ResourceVariables. For example,
|
||||
if a ResourceVariable is constructed inside a tpu.replicate() block, the
|
||||
ResourceVariable implementation can use "with ops.control_dependencies(None)"
|
||||
to build the variable's definition outside the replicated computation.
|
||||
"""
|
||||
|
||||
def __init__(self, name, num_replicas, global_tpu_id=None):
|
||||
control_flow_ops.ControlFlowContext.__init__(self)
|
||||
self._name = name
|
||||
self._num_replicas = num_replicas
|
||||
self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id
|
||||
|
||||
def AddOp(self, op):
|
||||
self._AddOpInternal(op)
|
||||
|
||||
def _AddOpInternal(self, op):
|
||||
# pylint: disable=protected-access
|
||||
if any(x.dtype._is_ref_dtype for x in op.inputs):
|
||||
raise NotImplementedError(
|
||||
"Non-resource Variables are not supported inside TPU computations "
|
||||
"(operator name: %s)" % op.name)
|
||||
# pylint: enable=protected-access
|
||||
if "_tpu_replicate" in op.node_def.attr:
|
||||
raise ValueError("TPU computations cannot be nested")
|
||||
op.node_def.attr["_tpu_replicate"].s = self._name
|
||||
op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas
|
||||
op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id)
|
||||
op.graph.prevent_feeding(op)
|
||||
op.graph.prevent_fetching(op)
|
||||
|
||||
def AddValue(self, val):
|
||||
result = val
|
||||
if self._outer_context:
|
||||
result = self._outer_context.AddValue(val)
|
||||
return result
|
||||
|
||||
def AddInnerOp(self, op):
|
||||
self._AddOpInternal(op)
|
||||
if self._outer_context:
|
||||
self._outer_context.AddInnerOp(op)
|
||||
|
||||
|
||||
def replicate(computation,
|
||||
inputs=None,
|
||||
infeed_queue=None,
|
||||
global_tpu_id=None,
|
||||
name=None):
|
||||
"""Builds a graph operator that runs a replicated TPU computation.
|
||||
|
||||
Args:
|
||||
computation: a Python function that builds the computation to replicate.
|
||||
inputs: a list of lists of input tensors or None (equivalent to
|
||||
[[]]), indexed by [replica_num][input_num]. All replicas must
|
||||
have the same number of inputs.
|
||||
infeed_queue: if not None, the InfeedQueue from which to append a tuple
|
||||
of arguments as inputs to computation.
|
||||
global_tpu_id: if not None, a Numpy 2D array indicating the global
|
||||
id of each TPU device in the system. The outer dimension of the
|
||||
array is host task id, and the inner dimension is device ordinal,
|
||||
so e.g., global_tpu_id[x][y] indicates the global id of device
|
||||
/task:x/device:TPU_NODE:y.
|
||||
name: name of the operator.
|
||||
Returns:
|
||||
A list of lists of output tensors, indexed by [replica_num][output_num].
|
||||
Raises:
|
||||
ValueError: if all replicas do not have equal numbers of input tensors.
|
||||
ValueError: if the number of inputs per replica does not match
|
||||
the number of formal parameters to `computation`.
|
||||
"""
|
||||
if name is None:
|
||||
name = "TPUReplicate"
|
||||
inputs = [[]] if inputs is None else inputs
|
||||
|
||||
if global_tpu_id is not None:
|
||||
# Turn the Numpy array into a flattened list.
|
||||
global_tpu_id = global_tpu_id.flatten().tolist()
|
||||
|
||||
if ((not isinstance(inputs, list)) or
|
||||
any(not isinstance(inp, (list, tuple)) for inp in inputs)):
|
||||
raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")
|
||||
|
||||
num_replicas = len(inputs)
|
||||
|
||||
# No replicas? Nothing to do.
|
||||
if num_replicas == 0:
|
||||
return []
|
||||
|
||||
# Converts inputs to Tensors.
|
||||
inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
|
||||
|
||||
# Verifies that all replicas have matching numbers and types of inputs
|
||||
input_types = [x.dtype for x in inputs[0]]
|
||||
input_arity = len(input_types)
|
||||
for i in range(num_replicas):
|
||||
if len(inputs[i]) != input_arity:
|
||||
raise ValueError("Replicas must have the same number of inputs. "
|
||||
"Replica 0 had {} inputs, replica {} had {} "
|
||||
"inputs.".format(input_arity, i, len(inputs[i])))
|
||||
|
||||
types = [x.dtype for x in inputs[i]]
|
||||
if types != input_types:
|
||||
raise ValueError(
|
||||
"Replicas must have matching input types. Replica 0 had "
|
||||
"input types {}, replica {} had input types {}".format(
|
||||
input_types, i, types))
|
||||
|
||||
arg_error = tpu_function.check_function_argument_count(
|
||||
computation, input_arity, infeed_queue)
|
||||
if arg_error is not None:
|
||||
if infeed_queue is None:
|
||||
raise TypeError(
|
||||
"Supplied computation cannot be called with the specified inputs. "
|
||||
"You specified %d inputs: %s, but the computation needs %s" % (
|
||||
input_arity, str([i.name for i in inputs[0]]), arg_error))
|
||||
else:
|
||||
raise TypeError(
|
||||
"Supplied computation cannot be called with the specified inputs. "
|
||||
"You specified %d inputs: %s and %d additional inputs from infeed,"
|
||||
" but the computation needs %s" % (input_arity, str(
|
||||
[i.name
|
||||
for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
|
||||
arg_error))
|
||||
|
||||
graph = ops.get_default_graph()
|
||||
|
||||
with ops.name_scope(name, "replicate"):
|
||||
# Fan-in: Builds a TPUReplicatedInput node for each input.
|
||||
computation_inputs = []
|
||||
for i in range(0, input_arity):
|
||||
replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
|
||||
computation_inputs.append(
|
||||
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
|
||||
|
||||
context = TPUReplicateContext(
|
||||
name=graph.unique_name("cluster"),
|
||||
num_replicas=num_replicas,
|
||||
global_tpu_id=global_tpu_id)
|
||||
try:
|
||||
context.Enter()
|
||||
|
||||
with tpu_function.tpu_shard_context(num_replicas):
|
||||
|
||||
# The EncapsulateTPUComputations rewrite needs to identify the
|
||||
# replicated arguments inside each computation. Adds identity operators
|
||||
# tagged with an attribute _tpu_replicated_input to identify the
|
||||
# replicated inputs.
|
||||
# pylint: disable=protected-access
|
||||
with graph._attr_scope({"_tpu_replicated_input":
|
||||
attr_value_pb2.AttrValue(b=True)}):
|
||||
computation_inputs = [
|
||||
array_ops.identity(x, name="replicated_input_{}".format(i))
|
||||
for i, x in enumerate(computation_inputs)]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# If there is an infeed queue, adds the dequeued values to the
|
||||
# computation's inputs.
|
||||
if infeed_queue is not None:
|
||||
infeed_queue.set_number_of_shards(num_replicas)
|
||||
for t in infeed_queue.generate_dequeue_op():
|
||||
computation_inputs.append(t)
|
||||
|
||||
# Only resource variables work inside a TPU computation, so turn on
|
||||
# resource variables for the computation.
|
||||
# TODO(phawkins): consider removing this code. It will
|
||||
# be less confusing to clients if they knowingly choose to use resource
|
||||
# variables.
|
||||
vscope = variable_scope.get_variable_scope()
|
||||
saved_use_resource = vscope.use_resource
|
||||
vscope.set_use_resource(True)
|
||||
|
||||
outputs = computation(*computation_inputs)
|
||||
|
||||
vscope.set_use_resource(saved_use_resource)
|
||||
|
||||
# If the computation only returned one value, makes it a tuple.
|
||||
if not isinstance(outputs, (list, tuple)):
|
||||
outputs = (outputs,)
|
||||
|
||||
try:
|
||||
with ops.device(core(0)):
|
||||
outputs = [
|
||||
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
|
||||
for o in outputs
|
||||
]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"TPU function return values must all either be Operations or "
|
||||
"convertible to Tensors. Got '%s'" % str(e))
|
||||
|
||||
# Separates the returned Operations and Tensors.
|
||||
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
|
||||
output_tensors = [o for o in outputs
|
||||
if not isinstance(o, ops.Operation)]
|
||||
|
||||
if outputs != output_tensors + output_operations:
|
||||
raise ValueError(
|
||||
"TPU functions must return zero-or more Tensor values followed by "
|
||||
"zero or more Operations.")
|
||||
output_arity = len(output_tensors)
|
||||
|
||||
# Wraps outputs in Identity ops. Otherwise a replicated input copied
|
||||
# straight to an output would bypass the replicate(). This would be bad
|
||||
# because the TPUReplicatedInput/TPUReplicatedOutput operator would not
|
||||
# be rewritten away, leading to a runtime error.
|
||||
# TODO(phawkins): extend the rewrite to elide these nodes instead.
|
||||
with ops.device(core(0)):
|
||||
output_tensors = [array_ops.identity(x) for x in output_tensors]
|
||||
finally:
|
||||
context.Exit()
|
||||
|
||||
# Fan-out: Builds a TPUReplicatedOutput node for each output.
|
||||
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
|
||||
name="output{}".format(i))
|
||||
for i in xrange(output_arity)]
|
||||
|
||||
with ops.control_dependencies(output_operations):
|
||||
if output_arity == 0:
|
||||
# Returns a list of NoOps dependent on the replication Op, indexed by
|
||||
# [replica_num].
|
||||
return [
|
||||
control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
|
||||
for i in range(num_replicas)
|
||||
]
|
||||
else:
|
||||
# Wraps the outputs in identity operators so the names of any possible
|
||||
# `fetch` nodes are preserved by the replication rewrite.
|
||||
return [
|
||||
[array_ops.identity(outputs[out][replica],
|
||||
name="output_%d_shard_%d" % (out, replica))
|
||||
for out in xrange(output_arity)]
|
||||
for replica in xrange(num_replicas)
|
||||
]
|
||||
|
||||
|
||||
def shard(computation,
|
||||
inputs=None,
|
||||
num_shards=1,
|
||||
input_shard_axes=None,
|
||||
outputs_from_all_shards=True,
|
||||
output_shard_axes=None,
|
||||
infeed_queue=None,
|
||||
global_tpu_id=None,
|
||||
name=None):
|
||||
"""Shards `computation` for parallel execution.
|
||||
|
||||
`inputs` must be a list of Tensors or None (equivalent to an empty
|
||||
list), each of which has a corresponding split axis (from
|
||||
`input_shard_axes`). Each input is split into `num_shards` pieces
|
||||
along the corresponding axis, and computation is applied to each
|
||||
shard in parallel.
|
||||
|
||||
Tensors are broadcast to all shards if they are lexically captured by
|
||||
`computation`. e.g.,
|
||||
|
||||
x = tf.constant(7)
|
||||
def computation():
|
||||
return x + 3
|
||||
... = shard(computation, ...)
|
||||
|
||||
TODO(phawkins): consider adding support for broadcasting Tensors passed
|
||||
as inputs.
|
||||
|
||||
If `outputs_from_all_shards` is true, the outputs from all shards of
|
||||
`computation` are concatenated back together along their `output_shard_axes`.
|
||||
Otherwise, each output is taken from an arbitrary shard.
|
||||
|
||||
Inputs and outputs of the computation must be at least rank-1 Tensors.
|
||||
|
||||
Args:
|
||||
computation: a Python function that builds a computation to apply to each
|
||||
shard of the input.
|
||||
inputs: a list of input tensors or None (equivalent to an empty
|
||||
list). Each input tensor has a corresponding shard axis, given
|
||||
by `input_shard_axes`, which must have size divisible by
|
||||
`num_shards`.
|
||||
num_shards: the number of shards.
|
||||
input_shard_axes: a list of dimensions along which to shard `inputs`, or
|
||||
`None`. `None` means "shard all inputs along dimension 0". If not `None`,
|
||||
there must be one dimension per input.
|
||||
outputs_from_all_shards: boolean or list of boolean. For each output, if
|
||||
`True`, outputs from all shards are concatenated along the corresponding
|
||||
`output_shard_axes` entry. Otherwise, each output is taken
|
||||
from an arbitrary shard. If the argument is a boolean, the argument's
|
||||
value is used for each output.
|
||||
output_shard_axes: a list of dimensions along which to concatenate the
|
||||
outputs of `computation`, or `None`. `None` means "concatenate all outputs
|
||||
along dimension 0". If not `None`, there must be one dimension per output.
|
||||
Ignored if `outputs_from_all_shards` is False.
|
||||
infeed_queue: if not None, the InfeedQueue to use to augment the inputs of
|
||||
`computation`.
|
||||
global_tpu_id: if not None, a Numpy 2D array indicating the global
|
||||
id of each TPU device in the system. The outer dimension of the
|
||||
array is host task id, and the inner dimension is device ordinal,
|
||||
so e.g., global_tpu_id[x][y] indicates the global id of device
|
||||
/task:x/device:TPU_NODE:y.
|
||||
name: name of the operator.
|
||||
Returns:
|
||||
A list of output tensors.
|
||||
Raises:
|
||||
ValueError: if num_shards <= 0
|
||||
ValueError: if len(input_shard_axes) != len(inputs)
|
||||
ValueError: if len(output_shard_axes) != len(outputs from `computation`)
|
||||
"""
|
||||
|
||||
if num_shards <= 0:
|
||||
raise ValueError("num_shards must be a positive integer.")
|
||||
|
||||
# Converts inputs to Tensors.
|
||||
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
|
||||
|
||||
if input_shard_axes is None:
|
||||
input_shard_axes = [0] * len(inputs)
|
||||
if len(inputs) != len(input_shard_axes):
|
||||
raise ValueError("Length of input_shard_axes must be equal to the number "
|
||||
"of inputs.")
|
||||
|
||||
if inputs:
|
||||
# Splits the `inputs` along the corresponding `input_shard_axes`, giving
|
||||
# lists with layout [input][shard]
|
||||
split_inputs = [
|
||||
array_ops.split(x, num_shards, axis=axis)
|
||||
for (axis, x) in zip(input_shard_axes, inputs)]
|
||||
|
||||
# Transposes the input lists to have layout [shard][input]
|
||||
transposed_inputs = [list(i) for i in zip(*split_inputs)]
|
||||
else:
|
||||
transposed_inputs = [[]] * num_shards
|
||||
|
||||
outputs = replicate(
|
||||
computation,
|
||||
transposed_inputs,
|
||||
infeed_queue=infeed_queue,
|
||||
global_tpu_id=global_tpu_id,
|
||||
name=name)
|
||||
|
||||
# There must be at least one shard since num_shards > 0.
|
||||
# TODO(b/36647078) remove disable when pylint bug is fixed.
|
||||
# pylint: disable=indexing-exception
|
||||
if isinstance(outputs[0], ops.Operation):
|
||||
# pylint: enable=indexing-exception
|
||||
# There were no outputs from the computation and replicate returned a list
|
||||
# of NoOps with control dependencies on the computation. Return the first
|
||||
# one so it can be used as a control dependency or fetch node.
|
||||
# TODO(b/36647078) remove disable when pylint bug is fixed.
|
||||
# pylint: disable=indexing-exception
|
||||
return [outputs[0]]
|
||||
# pylint: enable=indexing-exception
|
||||
|
||||
# TODO(b/36647078) remove disable when pylint bug is fixed.
|
||||
# pylint: disable=indexing-exception
|
||||
num_outputs = len(outputs[0])
|
||||
# pylint: enable=indexing-exception
|
||||
|
||||
if output_shard_axes is None:
|
||||
output_shard_axes = [0] * num_outputs
|
||||
if num_outputs != len(output_shard_axes):
|
||||
raise ValueError("Length of output_shard_axes must be equal to the number "
|
||||
"of outputs.")
|
||||
|
||||
if isinstance(outputs_from_all_shards, bool):
|
||||
outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
|
||||
|
||||
if num_outputs != len(outputs_from_all_shards):
|
||||
raise ValueError("Length of outputs_from_all_shards must be equal to the "
|
||||
"number of outputs.")
|
||||
|
||||
results = []
|
||||
for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
|
||||
zip(*outputs)):
|
||||
if all_shards:
|
||||
# Concatenate all of the outputs together.
|
||||
results.append(array_ops.concat(list(x), axis=axis))
|
||||
else:
|
||||
# TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
|
||||
results.append(x[0])
|
||||
|
||||
return results
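
As a quick illustration of the sharding behavior described above, here is a minimal sketch (not part of this change; the tensor shapes and the `double` helper are made up for illustration):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def double(x):
  # Per-shard computation; outputs must stay at least rank-1.
  return x * 2.0

x = tf.ones([8, 4])
# x is split into 4 shards of shape [2, 4] along axis 0 (the default),
# and the shard outputs are concatenated back into a single [8, 4] result.
[y] = tpu.shard(double, inputs=[x], num_shards=4)
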
|
||||
|
||||
|
||||
def batch_parallel(computation,
|
||||
inputs=None,
|
||||
num_shards=1,
|
||||
infeed_queue=None,
|
||||
global_tpu_id=None,
|
||||
name=None):
|
||||
"""Shards `computation` along the batch dimension for parallel execution.
|
||||
|
||||
Convenience wrapper around shard().
|
||||
|
||||
`inputs` must be a list of Tensors or None (equivalent to an empty
|
||||
list). Each input is split into `num_shards` pieces along the 0-th
|
||||
dimension, and computation is applied to each shard in parallel.
|
||||
|
||||
Tensors are broadcast to all shards if they are lexically captured by
|
||||
`computation`. e.g.,
|
||||
|
||||
x = tf.constant(7)
|
||||
def computation():
|
||||
return x + 3
|
||||
... = shard(computation, ...)
|
||||
|
||||
The outputs from all shards are concatenated back together along their 0-th
|
||||
dimension.
|
||||
|
||||
Inputs and outputs of the computation must be at least rank-1 Tensors.
|
||||
|
||||
Args:
|
||||
computation: a Python function that builds a computation to apply to each
|
||||
shard of the input.
|
||||
inputs: a list of input tensors or None (equivalent to an empty
|
||||
list). The 0-th dimension of each Tensor must have size
|
||||
divisible by `num_shards`.
|
||||
num_shards: the number of shards.
|
||||
infeed_queue: if not None, the InfeedQueue from which to append a tuple
|
||||
of arguments as inputs to `computation`.
|
||||
global_tpu_id: if not None, a Numpy 2D array indicating the global
|
||||
id of each TPU device in the system. The outer dimension of the
|
||||
array is host task id, and the inner dimension is device ordinal,
|
||||
so e.g., global_tpu_id[x][y] indicates the global id of device
|
||||
/task:x/device:TPU_NODE:y.
|
||||
name: name of the operator.
|
||||
Returns:
|
||||
A list of output tensors.
|
||||
Raises:
|
||||
ValueError: if num_shards <= 0
|
||||
"""
|
||||
return shard(
|
||||
computation,
|
||||
inputs,
|
||||
num_shards=num_shards,
|
||||
infeed_queue=infeed_queue,
|
||||
global_tpu_id=global_tpu_id,
|
||||
name=name)
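
A one-line comparison may help: `batch_parallel` is simply `shard` with every shard axis pinned to 0. A hypothetical sketch (shapes are illustrative assumptions):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

features = tf.zeros([64, 16])
# Equivalent to tpu.shard(..., input_shard_axes=[0], output_shard_axes=[0]).
[activations] = tpu.batch_parallel(
    lambda x: tf.tanh(x), inputs=[features], num_shards=8)
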
|
||||
|
||||
|
||||
def rewrite(computation,
|
||||
inputs=None,
|
||||
infeed_queue=None,
|
||||
global_tpu_id=None,
|
||||
name=None):
|
||||
"""Rewrites `computation` for execution on a TPU system.
|
||||
|
||||
Args:
|
||||
computation: a Python function that builds a computation to apply
|
||||
to the input. If the function takes n inputs, 'inputs' should be
|
||||
a list of n tensors. If the function returns m outputs, rewrite
|
||||
will return a list of m tensors.
|
||||
inputs: a list of input tensors or None (equivalent to an empty list).
|
||||
infeed_queue: if not None, the InfeedQueue from which to append a tuple
|
||||
of arguments as inputs to `computation`.
|
||||
global_tpu_id: if not None, a Numpy 2D array indicating the global
|
||||
id of each TPU device in the system. The outer dimension of the
|
||||
array is host task id, and the inner dimension is device ordinal,
|
||||
so e.g., global_tpu_id[x][y] indicates the global id of device
|
||||
/task:x/device:TPU_NODE:y.
|
||||
name: name of the operator.
|
||||
Returns:
|
||||
A list of output tensors.
|
||||
"""
|
||||
if inputs is not None and not isinstance(inputs, (list, tuple)):
|
||||
raise TypeError("tpu.rewrite() inputs must be a list or tuple")
|
||||
|
||||
# TODO(b/36647078) remove disable when pylint bug is fixed.
|
||||
# pylint: disable=indexing-exception
|
||||
return replicate(
|
||||
computation,
|
||||
None if inputs is None else [inputs],
|
||||
infeed_queue=infeed_queue,
|
||||
global_tpu_id=global_tpu_id,
|
||||
name=name)[0]
|
||||
# pylint: enable=indexing-exception
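
By contrast with `shard`, `rewrite` runs a single replica of the computation. A minimal hedged sketch (the operands are placeholders for illustration):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

a = tf.ones([2, 2])
b = tf.ones([2, 2])
# Returns the list of outputs of the single rewritten replica.
[c] = tpu.rewrite(lambda x, y: tf.matmul(x, y), inputs=[a, b])
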
|
47
tensorflow/contrib/tpu/python/tpu/tpu_config.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""A RunConfig subclass with TPU support."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||
|
||||
|
||||
class TpuConfig(collections.namedtuple(
|
||||
'TpuConfig', ['iterations_per_loop', 'num_shards'])):
|
||||
"""TPU related configuration required by `TPUEstimator`."""
|
||||
|
||||
def __new__(cls, iterations_per_loop=2, num_shards=2):
|
||||
return super(TpuConfig, cls).__new__(
|
||||
cls,
|
||||
iterations_per_loop=iterations_per_loop,
|
||||
num_shards=num_shards)
|
||||
|
||||
|
||||
class RunConfig(run_config_lib.RunConfig):
|
||||
"""RunConfig with TPU support."""
|
||||
|
||||
def __init__(self, tpu_config=None, **kwargs):
|
||||
super(RunConfig, self).__init__(**kwargs)
|
||||
self._tpu_config = tpu_config or TpuConfig()
|
||||
|
||||
@property
|
||||
def tpu_config(self):
|
||||
return self._tpu_config
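
A short construction sketch, assuming the `master` keyword of the underlying contrib.learn RunConfig and a placeholder endpoint address:

from tensorflow.contrib.tpu.python.tpu import tpu_config

config = tpu_config.RunConfig(
    tpu_config=tpu_config.TpuConfig(iterations_per_loop=10, num_shards=8),
    master='grpc://10.0.0.1:8470')  # placeholder endpoint
assert config.tpu_config.num_shards == 8
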
|
361
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
Normal file
@ -0,0 +1,361 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Tpu Estimator class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_config
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
from tensorflow.contrib.tpu.python.tpu import training_loop
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training
|
||||
|
||||
|
||||
def _tpu_job(run_config):
|
||||
# The tpu job is determined by the run_config. Right now, this method is
|
||||
# required as tpu_config is not part of the RunConfig.
|
||||
return None if run_config.master in ['', 'local'] else 'tpu_worker'
|
||||
|
||||
|
||||
class _SIGNAL(object):
|
||||
"""Signal used to control the input thread of infeed."""
|
||||
NEXT_BATCH = 1
|
||||
STOP = 2
|
||||
|
||||
|
||||
class InfeedThreadController(object):
|
||||
"""This wraps the infeed thread and stops when Estimator train finishes.
|
||||
|
||||
For the model_fn wrapper, it is not possible to know when the `train` API will
|
||||
stop. It could be the case that `max_steps` is reached or that some hook
|
||||
requests the stop in the monitored_session.
|
||||
|
||||
This controller (with coordination with `TpuInfeedSessionHook`) does the
|
||||
following:
|
||||
|
||||
1) It pre-infeeds one `batch` of data for the current TPU iterations.
|
||||
|
||||
2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch`
|
||||
of data will be fed to the infeed.
|
||||
|
||||
3) When `end` of `TpuInfeedSessionHook` is called, the thread will end
|
||||
gracefully.
|
||||
|
||||
So, we might need to adjust the algorithm here if the IO is slower than the
|
||||
computation.
|
||||
"""
|
||||
|
||||
def __init__(self, session, enqueue_ops, iterations):
|
||||
self._signal_queue = Queue.Queue()
|
||||
self._input_thd = threading.Thread(target=self._input_thread_fn_for_loading,
|
||||
args=(session, enqueue_ops, iterations))
|
||||
self._input_thd.daemon = True
|
||||
self._input_thd.start()
|
||||
|
||||
def _input_thread_fn_for_loading(self, session, enqueue_ops, iterations):
|
||||
count = 0
|
||||
while True:
|
||||
signal = self._signal_queue.get()
|
||||
if signal == _SIGNAL.STOP:
|
||||
logging.info('Stop Infeed input thread.')
|
||||
return
|
||||
|
||||
for i in range(iterations):
|
||||
logging.debug('InfeedEnqueue data for iteration (%d, %d)', count, i)
|
||||
session.run(enqueue_ops)
|
||||
count += 1
|
||||
|
||||
def load_next_batch(self):
|
||||
self._signal_queue.put(_SIGNAL.NEXT_BATCH)
|
||||
|
||||
def join(self):
|
||||
logging.info('Waiting for InputThread to exit.')
|
||||
self._signal_queue.put(_SIGNAL.STOP)
|
||||
self._input_thd.join()
|
||||
|
||||
|
||||
class TpuInfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
"""A Session hook setting up the TPU initialization and infeed.
|
||||
|
||||
This hook does two major things:
|
||||
1. Initializes and shuts down the TPU system (maybe a separate hook).
|
||||
2. Launches and joins the input thread for infeed.
|
||||
"""
|
||||
|
||||
def __init__(self, run_config, enqueue_fn):
|
||||
self._iterations = run_config.tpu_config.iterations_per_loop
|
||||
self._enqueue_fn = enqueue_fn
|
||||
self._tpu_job = _tpu_job(run_config)
|
||||
|
||||
def begin(self):
|
||||
self._enqueue_ops = self._enqueue_fn()
|
||||
logging.info('TPU job name %s', self._tpu_job)
|
||||
self._init_op = [tpu.initialize_system(job=self._tpu_job)]
|
||||
self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)]
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
logging.info('Init TPU system')
|
||||
session.run(self._init_op)
|
||||
|
||||
logging.info('Start infeed input thread controller')
|
||||
self._infeed_thd_controller = InfeedThreadController(
|
||||
session, self._enqueue_ops, self._iterations)
|
||||
|
||||
def before_run(self, run_context):
|
||||
logging.info('Load next batch of data to infeed.')
|
||||
self._infeed_thd_controller.load_next_batch()
|
||||
|
||||
def end(self, session):
|
||||
logging.info('Stop infeed input thread controller')
|
||||
self._infeed_thd_controller.join()
|
||||
|
||||
logging.info('Shutdown TPU system.')
|
||||
session.run(self._finalize_op)
|
||||
|
||||
|
||||
class TpuEstimator(estimator_lib.Estimator):
|
||||
"""Estimator with TPU support.
|
||||
|
||||
The only difference is that a wrapped model_fn is set in the constructor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_fn=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
params=None,
|
||||
use_tpu=True):
|
||||
if use_tpu:
|
||||
model_function = wrapped_model_fn(model_fn, config)
|
||||
else:
|
||||
model_function = model_fn
|
||||
|
||||
super(TpuEstimator, self).__init__(
|
||||
model_fn=model_function,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
params=params)
|
||||
if not isinstance(config, tpu_config.RunConfig):
|
||||
raise ValueError('`config` must be `tpu_config.RunConfig`')
|
||||
|
||||
def _create_global_step(self, graph):
|
||||
"""Creates a global step suitable for TPUs.
|
||||
|
||||
Args:
|
||||
graph: The graph in which to create the global step.
|
||||
|
||||
Returns:
|
||||
A global step `Tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the global step tensor is already defined.
|
||||
"""
|
||||
graph = graph or ops.get_default_graph()
|
||||
if training.get_global_step(graph) is not None:
|
||||
raise ValueError('"global_step" already exists.')
|
||||
# Create in proper graph and base name_scope.
|
||||
with graph.as_default() as g, g.name_scope(None):
|
||||
return variable_scope.get_variable(
|
||||
ops.GraphKeys.GLOBAL_STEP,
|
||||
shape=[],
|
||||
dtype=dtypes.int32,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
trainable=False,
|
||||
use_resource=True,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.GLOBAL_STEP])
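
To show how the pieces fit together, here is a hedged end-to-end sketch; the model_fn, input_fn, and shapes are placeholders and are not part of this commit:

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

def model_fn(features, labels, mode):
  # Placeholder model: one dense layer and a plain SGD step.
  logits = tf.layers.dense(features, 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
  # Placeholder input pipeline with made-up shapes.
  features = tf.random_uniform([128, 784])
  labels = tf.random_uniform([128], maxval=10, dtype=tf.int32)
  return features, labels

config = tpu_config.RunConfig(
    tpu_config=tpu_config.TpuConfig(iterations_per_loop=10, num_shards=8))
estimator = tpu_estimator.TpuEstimator(model_fn=model_fn, config=config)
# estimator.train(input_fn, max_steps=100)  # requires an actual TPU system
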
|
||||
|
||||
|
||||
# TODO(xiejw): Improve the structure of this input_fn-to-infeed conversion.
|
||||
# The code currently does not look like Estimator style. We need to abstract many
|
||||
# details.
|
||||
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
|
||||
"""Utility to convert input_fn to enqueue and dequeue fns for TPU.
|
||||
|
||||
Mainly, three things need to be done here.
|
||||
1. Calls the input_fn many times (`num_shards`) to infeed the data into the TPU.
|
||||
2. Creates a dequeue_fn used by the train_step inside TPU execution to
|
||||
dequeue the tensors.
|
||||
3. Sets up the input thread to infeed.
|
||||
|
||||
Args:
|
||||
run_config: run_config
|
||||
features: features
|
||||
labels: labels
|
||||
|
||||
Returns:
|
||||
A tuple of (dequeue_fn, enqueue_fn).
|
||||
"""
|
||||
infeed_names = None
|
||||
infeed_tuple = []
|
||||
if isinstance(features, dict):
|
||||
# We need a fixed ordering for enqueueing and dequeueing.
|
||||
infeed_names = [name for name in features]
|
||||
infeed_tuple.extend([features[name] for name in infeed_names])
|
||||
else:
|
||||
infeed_tuple.append(features)
|
||||
# TODO(jhseu): Handle multi-head and None labels
|
||||
infeed_tuple.append(labels)
|
||||
# TODO(jhseu): Update when b/36470756 is settled.
|
||||
infeed_queue = tpu_feed.InfeedQueue(
|
||||
tuple_types=[t.dtype for t in infeed_tuple],
|
||||
tuple_shapes=[t.shape for t in infeed_tuple])
|
||||
infeed_queue.set_number_of_shards(run_config.tpu_config.num_shards)
|
||||
|
||||
def dequeue_fn():
|
||||
"""dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
|
||||
values = infeed_queue.generate_dequeue_op()
|
||||
if infeed_names is None:
|
||||
return values
|
||||
# Restore the feature dictionary and label.
|
||||
dequeued_features = {}
|
||||
for i in range(len(values) - 1):
|
||||
dequeued_features[infeed_names[i]] = values[i]
|
||||
label = values[-1]
|
||||
return dequeued_features, label
|
||||
|
||||
def enqueue_fn():
|
||||
"""enqueue_fn is used to add ops to the graph to send tensors."""
|
||||
job = _tpu_job(run_config)
|
||||
def placement_function(index):
|
||||
if job is None:
|
||||
return '/replica:0/task:0/device:CPU:0'
|
||||
else:
|
||||
return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
|
||||
return infeed_queue.split_inputs_and_generate_enqueue_ops(
|
||||
infeed_tuple, placement_function=placement_function)
|
||||
|
||||
return (dequeue_fn, enqueue_fn)
|
||||
|
||||
|
||||
def wrapped_model_fn(model_fn, run_config):
|
||||
"""Returns a new model_fn, which wraps the TPU support."""
|
||||
|
||||
# Verifies the model_fn signature according to Estimator framework.
|
||||
estimator_lib._verify_model_fn_args(model_fn, params=None) # pylint: disable=protected-access
|
||||
|
||||
def _model_fn(features, labels, mode):
|
||||
"""model_fn."""
|
||||
# TODO(jhseu): Move EVAL and PREDICT to TPU.
|
||||
if mode != model_fn_lib.ModeKeys.TRAIN:
|
||||
return model_fn(features, labels, mode)
|
||||
|
||||
dequeue_fn, enqueue_fn = (
|
||||
_create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels))
|
||||
|
||||
loss = _train_on_tpu_shards(
|
||||
run_config,
|
||||
train_step=_convert_model_fn_to_train_step(
|
||||
model_fn, dequeue_fn, mode, run_config))
|
||||
|
||||
# Gets the variables back from TPU nodes. This means the variables updated
|
||||
# by TPU will now be *synced* to host memory.
|
||||
update_ops = [
|
||||
array_ops.check_numerics(v.read_value(),
|
||||
'Gradient for %s is NaN' % v.name).op
|
||||
for v in variables.trainable_variables()
|
||||
]
|
||||
|
||||
hooks = [
|
||||
TpuInfeedSessionHook(run_config, enqueue_fn),
|
||||
training.LoggingTensorHook(
|
||||
{'loss': array_ops.identity(loss),
|
||||
'step': training.get_global_step()},
|
||||
every_n_secs=30)
|
||||
]
|
||||
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=array_ops.identity(loss),
|
||||
training_hooks=hooks,
|
||||
train_op=control_flow_ops.group(*update_ops))
|
||||
return _model_fn
|
||||
|
||||
|
||||
def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config):
|
||||
"""generates a train step based on the model_fn."""
|
||||
|
||||
def _call_model_fn(features, labels):
|
||||
"""Calls the model_fn with required parameters."""
|
||||
model_fn_args = estimator_lib._model_fn_args(model_fn) # pylint: disable=protected-access
|
||||
kwargs = {}
|
||||
if 'mode' in model_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
# Uncomment the following lines once `params` is supported.
|
||||
# if 'params' in model_fn_args:
|
||||
# kwargs['params'] = params
|
||||
if 'config' in model_fn_args:
|
||||
kwargs['config'] = run_config
|
||||
return model_fn(features=features, labels=labels, **kwargs)
|
||||
|
||||
def _verify_estimator_spec(estimator_spec):
|
||||
"""Validates the estimator_spec."""
|
||||
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
|
||||
if estimator_spec.training_chief_hooks:
|
||||
raise ValueError(err_msg.format('training_chief_hooks'))
|
||||
if estimator_spec.training_hooks:
|
||||
raise ValueError(err_msg.format('training_hooks'))
|
||||
return estimator_spec
|
||||
|
||||
def train_step(loss):
|
||||
"""Training step function for use inside a while loop."""
|
||||
del loss # unused; required in function signature.
|
||||
features, labels = dequeue_fn()
|
||||
|
||||
# TODO(xiejw): How do we support hooks and savers in the original model_fn?
# Realistically, the original model_fn will be executed on TPU chips in a
# replicated way. The hooks returned by the model_fn cannot be supported at
# all. If we have to support them, the graph construction part of the
# model_fn should be separated from the control part (such as hooks and
# savers), so that the graph construction can be deferred to the TPU chip
# while the control logic stays on the host.
|
||||
estimator_spec = _verify_estimator_spec(_call_model_fn(features, labels))
|
||||
loss, train_op = estimator_spec.loss, estimator_spec.train_op
|
||||
with ops.control_dependencies([train_op]):
|
||||
return array_ops.identity(loss)
|
||||
return train_step
|
||||
|
||||
|
||||
def _train_on_tpu_shards(run_config, train_step):
|
||||
"""Executes the `train_step` on all shards."""
|
||||
def train_shard():
|
||||
return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
|
||||
train_step,
|
||||
[1e7], # initial_loss
|
||||
name='loop')
|
||||
|
||||
(loss,) = tpu.shard(train_shard,
|
||||
inputs=[],
|
||||
num_shards=run_config.tpu_config.num_shards,
|
||||
outputs_from_all_shards=False)
|
||||
return loss
|
613
tensorflow/contrib/tpu/python/tpu/tpu_feed.py
Normal file
@ -0,0 +1,613 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Helper library for handling infeed between hosts and TPUs.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
class InfeedQueue(object):
|
||||
"""A helper object to build a device infeed queue.
|
||||
|
||||
The InfeedQueue builds the host-side and device-side Ops to enqueue and
|
||||
dequeue elements, respectively, and ensures that their types and
|
||||
shapes match.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
number_of_tuple_elements=None,
|
||||
tuple_types=None,
|
||||
tuple_shapes=None,
|
||||
shard_dimensions=None,
|
||||
name=None):
|
||||
"""Creates a new InfeedQueue with the given configuration.
|
||||
|
||||
The configuration need not be fully specified at creation since it
|
||||
can be modified subsequently by methods that set the values
|
||||
explicitly or infer them from the shapes of inputs.
|
||||
|
||||
Args:
|
||||
number_of_tuple_elements: the number of Tensors fed atomically through the
|
||||
queue, must be present unless it can be inferred from other arguments.
|
||||
tuple_types: if not None, a list of types of the elements of the queue.
|
||||
tuple_shapes: if not None, a list of shapes of the elements of the queue.
|
||||
shard_dimensions: if not None, a list of dimensions on which the
|
||||
elements of the queue should be sharded during automatic
|
||||
parallelization.
|
||||
name: the name of the queue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_tuple_elements <= 0; or
|
||||
number_of_tuple_elements, tuple_types, tuple_shapes, and
|
||||
shard_dimensions are all None; or the length of tuple_types,
|
||||
tuple_shapes, or shard_dimensions is not equal to
|
||||
number_of_tuple_elements; or any element of shard_dimensions
|
||||
can't be converted to a Dimension.
|
||||
TypeError: if any element of tuple_types or tuple_shapes can't
|
||||
be converted to a dtype or TensorShape, respectively.
|
||||
"""
|
||||
self._frozen = False
|
||||
self._generated_enqueue_ops = False
|
||||
self._generated_dequeue_op = False
|
||||
self._name = "InfeedQueue" if name is None else name
|
||||
if number_of_tuple_elements is None:
|
||||
if tuple_types is not None:
|
||||
number_of_tuple_elements = len(tuple_types)
|
||||
elif tuple_shapes is not None:
|
||||
number_of_tuple_elements = len(tuple_shapes)
|
||||
elif shard_dimensions is not None:
|
||||
number_of_tuple_elements = len(shard_dimensions)
|
||||
else:
|
||||
raise ValueError(
|
||||
"number of tuple elements cannot be inferred from InfeedQueue "
|
||||
"constructor"
|
||||
)
|
||||
if number_of_tuple_elements <= 0:
|
||||
raise ValueError("number_of_tuple_elements %d must be > 0" %
|
||||
number_of_tuple_elements)
|
||||
# Make an empty sharding policy for each tuple element.
|
||||
self._sharding_policies = [
|
||||
tpu_sharding.ShardingPolicy()
|
||||
for _ in xrange(number_of_tuple_elements)
|
||||
]
|
||||
if tuple_types is not None:
|
||||
self.set_tuple_types(tuple_types)
|
||||
else:
|
||||
self._tuple_types = None
|
||||
if tuple_shapes is not None:
|
||||
self.set_tuple_shapes(tuple_shapes)
|
||||
else:
|
||||
self._tuple_shapes = None
|
||||
if shard_dimensions is not None:
|
||||
self.set_shard_dimensions(shard_dimensions)
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
"""Checks that the configuration is self-consistent.
|
||||
|
||||
Raises:
|
||||
ValueError: if the shapes and sharding policies don't match.
|
||||
"""
|
||||
if self.tuple_shapes is not None:
|
||||
for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
|
||||
# Raise an error if the policy is incompatible with the shape.
|
||||
_ = policy.get_sharded_shape(shape)
|
||||
|
||||
@property
|
||||
def number_of_tuple_elements(self):
|
||||
"""Returns the number of InfeedQueue tuple elements."""
|
||||
return len(self._sharding_policies)
|
||||
|
||||
@property
|
||||
def tuple_types(self):
|
||||
"""Returns the types of the InfeedQueue tuple elements."""
|
||||
return self._tuple_types
|
||||
|
||||
def set_tuple_types(self, tuple_types):
|
||||
"""Sets the type of each element of the queue.
|
||||
|
||||
tuple_types must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a dtype.
|
||||
|
||||
Args:
|
||||
tuple_types: the types of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_types is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_types cannot be converted to a
|
||||
dtype.
|
||||
"""
|
||||
if len(tuple_types) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_types is %s, but must be a list of length %d" %
|
||||
(str(tuple_types), self.number_of_tuple_elements))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_types, tuple_types):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible type. Frozen types are %s, updated types are %s" % (
|
||||
str(self._tuple_types), str(tuple_types)))
|
||||
else:
|
||||
try:
|
||||
self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
|
||||
except (TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_types is %s, but must be a list of elements each "
|
||||
"convertible to dtype: got error %s" % (str(tuple_types), str(e)))
|
||||
|
||||
@property
|
||||
def tuple_shapes(self):
|
||||
"""Returns the shapes of the InfeedQueue tuple elements."""
|
||||
return self._tuple_shapes
|
||||
|
||||
def set_tuple_shapes(self, tuple_shapes):
|
||||
"""Sets the shape of each element of the queue.
|
||||
|
||||
tuple_shapes must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a TensorShape.
|
||||
|
||||
Args:
|
||||
tuple_shapes: the shapes of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if tuple_shapes is not of length
|
||||
self.number_of_tuple_elements.
|
||||
TypeError: if an element of tuple_shapes cannot be converted to
|
||||
a TensorShape.
|
||||
"""
|
||||
if len(tuple_shapes) != self.number_of_tuple_elements:
|
||||
raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
|
||||
(str(tuple_shapes), self.number_of_tuple_elements))
|
||||
try:
|
||||
tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(
|
||||
"tuple_shapes is %s, but must be a list of elements each "
|
||||
"convertible to TensorShape: got error %s" % (str(tuple_shapes),
|
||||
str(e)))
|
||||
if self._frozen:
|
||||
for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
|
||||
if frozen != updated:
|
||||
raise ValueError(
|
||||
"Trying to update InfeedQueue with frozen configuration with an "
|
||||
"incompatible shape. Frozen shapes are %s, updated shapes are %s"
|
||||
% (str(self._tuple_shapes), str(tuple_shapes)))
|
||||
else:
|
||||
self._tuple_shapes = tuple_shapes
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def sharding_policies(self):
|
||||
"""Returns the sharding policies of the InfeedQueue tuple elements."""
|
||||
return self._sharding_policies
|
||||
|
||||
@property
|
||||
def shard_dimensions(self):
|
||||
"""Gets the shard dimension of each tuple element.
|
||||
|
||||
Returns:
|
||||
A list of length number_of_tuple_elements, where each list entry
|
||||
is the shard dimension of that tuple element or None if the
|
||||
shard dimension has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return [policy.shard_dimension for policy in self._sharding_policies]
|
||||
|
||||
def set_shard_dimensions(self, shard_dimensions):
|
||||
"""Sets the shard_dimension of each element of the queue.
|
||||
|
||||
shard_dimensions must be a list of length
|
||||
self.number_of_tuple_elements, and each element must be
|
||||
convertible to a Dimension compatible with self.tuple_shapes.
|
||||
|
||||
Args:
|
||||
shard_dimensions: the dimensions of each queue element.
|
||||
|
||||
Raises:
|
||||
ValueError: if shard_dimensions is not of length
|
||||
self.number_of_tuple_elements; or an element of
|
||||
shard_dimensions cannot be converted to a Dimension; or an
|
||||
element of shard_dimensions is a Dimension that is out of
|
||||
range for the corresponding tuple element shape.
|
||||
"""
|
||||
if len(shard_dimensions) != self.number_of_tuple_elements:
|
||||
raise ValueError("shard_dimensions is %s, but must be a list of length %d"
|
||||
% (str(shard_dimensions),
|
||||
self.number_of_tuple_elements))
|
||||
for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
|
||||
policy.set_shard_dimension(dimension)
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
"""Gets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Returns:
|
||||
Number of shards or None if the number of shards has not been set.
|
||||
"""
|
||||
# The number of shards is always the same for all the policies.
|
||||
return self._sharding_policies[0].number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
"""Sets the number of shards to use for the InfeedQueue.
|
||||
|
||||
Args:
|
||||
number_of_shards: number of ways to shard the InfeedQueue.
|
||||
|
||||
Raises:
|
||||
ValueError: if number_of_shards is not > 0; or the policies have
|
||||
been frozen and number_of_shards was already set to something
|
||||
else.
|
||||
"""
|
||||
for policy in self._sharding_policies:
|
||||
policy.set_number_of_shards(number_of_shards)
|
||||
self._validate()
|
||||
|
||||
def set_configuration_from_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of Tensors whose types and shapes are used
|
||||
to set the queue configuration.
|
||||
|
||||
Args:
|
||||
input_tensors: list of Tensors of the same types and shapes as
|
||||
the desired queue Tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: if input_tensors is not a list of length
|
||||
self.number_of_tuple_elements
|
||||
"""
|
||||
if len(input_tensors) != self.number_of_tuple_elements:
|
||||
raise ValueError(
|
||||
"input_tensors is %s, but should be a list of %d Tensors", (
|
||||
str(input_tensors), self.number_of_tuple_elements))
|
||||
self.set_tuple_shapes([t.shape for t in input_tensors])
|
||||
self.set_tuple_types([t.dtype for t in input_tensors])
|
||||
|
||||
def set_configuration_from_sharded_input_tensors(self, input_tensors):
|
||||
"""Sets the shapes and types of the queue tuple elements.
|
||||
|
||||
input_tensors is a list of lists of Tensors whose types and shapes are used
|
||||
to set the queue configuration. The length of the outer list is the number
|
||||
of shards required, and each inner list is the tuple of Tensors to use to
|
||||
determine the types and shapes of the corresponding shard. This method
|
||||
depends on the shard dimension, and calling it freezes the shard policy.
|
||||
|
||||
Args:
|
||||
input_tensors: list of lists of Tensors. The outer list length corresponds
|
||||
to the desired number of shards, and each inner list gives the Tensors whose
types and shapes determine the configuration of the corresponding shard.
|
||||
|
||||
Raises:
|
||||
ValueError: if any inner list is not a list of length
|
||||
self.number_of_tuple_elements; or the inner lists do not combine to
|
||||
form a consistent unsharded shape.
|
||||
TypeError: if the types of the Tensors in the inner lists do not match.
|
||||
"""
|
||||
if not self._frozen:
|
||||
# Unset the tuple shapes in case the configuration becomes
|
||||
# transiently inconsistent.
|
||||
self._tuple_shapes = None
|
||||
number_of_shards = len(input_tensors)
|
||||
self.set_number_of_shards(number_of_shards)
|
||||
for t in input_tensors:
|
||||
if len(t) != self.number_of_tuple_elements:
|
||||
raise ValueError(
|
||||
"input_tensors is %s but must be a list of lists, where each inner"
|
||||
" list has length number_of_tuple_elements=%d" % (
|
||||
str(input_tensors), self.number_of_tuple_elements))
|
||||
# Transpose the inputs to make a list of shard shapes for each tuple
|
||||
# element.
|
||||
sharded_shapes = [[t[i].shape for t in input_tensors]
|
||||
for i in xrange(self.number_of_tuple_elements)]
|
||||
# For each tuple, get the unsharded shape using that tuple's policy.
|
||||
unsharded_shapes = [
|
||||
policy.get_unsharded_shape(s)
|
||||
for (policy, s) in zip(self._sharding_policies, sharded_shapes)
|
||||
]
|
||||
self.set_tuple_shapes(unsharded_shapes)
|
||||
for i in xrange(1, self.number_of_shards):
|
||||
for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
|
||||
if t1.dtype != t2.dtype:
|
||||
raise TypeError(
|
||||
"types of the tuple elements of input_tensors %s are not "
|
||||
"consistent" % str(input_tensors))
|
||||
self.set_tuple_types([t.dtype for t in input_tensors[0]])
|
||||
|
||||
def freeze(self):
|
||||
"""Freezes the InfeedQueue so it can no longer be modified.
|
||||
|
||||
The configuration is implicitly frozen before any host-side or
|
||||
device-side Ops are generated. The configuration cannot be frozen
|
||||
until the types and shapes of the tuple elements have been set.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set.
|
||||
"""
|
||||
self._frozen = True
|
||||
if self._tuple_types is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple types.")
|
||||
if self._tuple_shapes is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for shape in self._tuple_shapes:
|
||||
if shape.dims is None:
|
||||
raise ValueError(
|
||||
"Can't freeze an InfeedQueue without setting all tuple shapes.")
|
||||
for policy in self._sharding_policies:
|
||||
policy.freeze()
|
||||
self._validate()
|
||||
|
||||
def generate_dequeue_op(self):
|
||||
"""Generates the device-side Op to dequeue a tuple from the queue.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen, which will raise errors if the shapes and types have not
|
||||
been fully specified.
|
||||
|
||||
Returns:
|
||||
A list of Outputs corresponding to a shard of infeed dequeued
|
||||
into XLA, suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if the types or shapes of the tuple elements have not been
|
||||
set; or if a dequeue op has already been generated.
|
||||
"""
|
||||
self.freeze()
|
||||
if self._generated_dequeue_op:
|
||||
raise ValueError("Can't generate two dequeue Ops from the same queue")
|
||||
self._generated_dequeue_op = True
|
||||
full_name = "%s/dequeue" % self._name
|
||||
sharded_shapes = [
|
||||
policy.get_sharded_shape(shape)
|
||||
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
|
||||
]
|
||||
return tpu_ops.infeed_dequeue_tuple(
|
||||
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
|
||||
|
||||
def _generate_enqueue_op(self,
|
||||
inputs,
|
||||
name_prefix,
|
||||
index,
|
||||
device=None,
|
||||
tpu_ordinal=-1):
|
||||
"""Generate a host-side Op to enqueue a tuple to the queue.
|
||||
|
||||
If device is None the inputs are all required to have the same
|
||||
device specification, and the enqueue Op is colocated with
|
||||
inputs[0]. Otherwise the enqueue Op is placed on 'device'.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors with the types and shapes of the tuple elements.
|
||||
name_prefix: the base name for the Op.
|
||||
index: the shard index, used to uniquify the Op name.
|
||||
device: device to place the Op on, or None if it should be
|
||||
colocated with the inputs.
|
||||
tpu_ordinal: ordinal of the TPU device on the host to use for
|
||||
infeed if device is a CPU device. Should be set to -1 if device
|
||||
is a TPU device.
|
||||
|
||||
Returns:
|
||||
An Op corresponding to a shard of infeed enqueued at the host,
|
||||
suitable for use within a replicated block.
|
||||
|
||||
Raises:
|
||||
ValueError: if device is None and inputs do not all have the
|
||||
same device specification.
|
||||
"""
|
||||
full_name = "%s/%d" % (name_prefix, index)
|
||||
shapes = [t.shape for t in inputs]
|
||||
if device is None:
|
||||
devices = [t.device for t in inputs]
|
||||
for i in xrange(1, self.number_of_tuple_elements):
|
||||
if devices[0] != devices[i]:
|
||||
raise ValueError(
|
||||
"input devices for shard %d are %s, but should all be the same",
|
||||
index, str(devices))
|
||||
with ops.colocate_with(inputs[0]):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
else:
|
||||
with ops.device(device):
|
||||
return tpu_ops.infeed_enqueue_tuple(
|
||||
inputs=inputs,
|
||||
shapes=shapes,
|
||||
name=full_name,
|
||||
device_ordinal=tpu_ordinal)
|
||||
|
||||
def generate_enqueue_ops(self, sharded_inputs):
|
||||
"""Generates the host-side Ops to enqueue the shards of a tuple.
|
||||
|
||||
sharded_inputs is a list, one for each shard, of lists of
|
||||
Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
|
||||
shard 0 of the queue. Returns the host-side Ops that must be run to
|
||||
enqueue the sharded tuple. The Op for shard i is colocated with the inputs
|
||||
for shard i.
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of sharded_inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
sharded_inputs: a list of lists of Tensors. The length of the outer list
|
||||
determines the number of shards. Each inner list indicates the types
|
||||
and shapes of the tuples in the corresponding shard.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the shapes of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple; or if the elements of a tuple
|
||||
have different device constraints.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of sharded_inputs are not compatible with the
|
||||
frozen configuration; or if the types of the elements of sharded_inputs
|
||||
don't form a consistent unsharded tuple.
|
||||
"""
|
||||
self.set_configuration_from_sharded_input_tensors(sharded_inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(shard, name_prefix, index)
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
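
A compact sketch of the enqueue/dequeue pairing (two shards; the shapes and dtypes are assumptions for illustration):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_feed

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
sharded_inputs = [
    [tf.zeros([4, 8]), tf.zeros([4], dtype=tf.int32)],  # shard 0
    [tf.zeros([4, 8]), tf.zeros([4], dtype=tf.int32)],  # shard 1
]
enqueue_ops = queue.generate_enqueue_ops(sharded_inputs)  # host side
# Normally called inside the replicated TPU computation.
features, labels = queue.generate_dequeue_op()
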
|
||||
|
||||
# TODO(misard) Generalize this to the case of systems that don't
|
||||
# have 8 devices per host, and figure out what to do with
|
||||
# model-parallelism.
|
||||
def _default_placement_function(self, index):
|
||||
return "/task:%d/device:CPU:0" % (index / 8)
|
||||
|
||||
def _default_ordinal_function(self, index):
|
||||
return index % 8
|
||||
|
||||
# TODO(b/36470756) remove this from tutorials once we have a better story
|
||||
# for automatic placement of input pipelines.
|
||||
def split_inputs_and_generate_enqueue_ops(self,
|
||||
inputs,
|
||||
global_tpu_id=None,
|
||||
placement_function=None,
|
||||
tpu_ordinal_function=None):
|
||||
"""POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
|
||||
|
||||
Generates the host-side Ops to enqueue a tuple.
|
||||
|
||||
This method performs poorly because it takes an entire input on a single
|
||||
host, splits it, and distributes it to all of the cores. It is present only
|
||||
to simplify tutorial examples.
|
||||
|
||||
inputs is a list of Tensors to use to feed the queue. Each input is split
|
||||
into self.number_of_shards shards. Returns an Op for each shard to enqueue
|
||||
the shard. The Op for shard i is placed on device placement_function(i).
|
||||
|
||||
Implicitly freezes the queue configuration if it is not already
|
||||
frozen. If the configuration has already been frozen, and is not
|
||||
compatible with the types and shapes of inputs, an error
|
||||
will be raised.
|
||||
|
||||
Args:
|
||||
inputs: a list of Tensors which indicates the types and shapes of the
|
||||
queue tuple.
|
||||
global_tpu_id: if not None, a Numpy 2D array indicating the global
|
||||
id of each TPU device in the system. The outer dimension of the
|
||||
array is host task id, and the inner dimension is device ordinal,
|
||||
so e.g., global_tpu_id[x][y] indicates the global id of device
|
||||
/task:x/device:TPU_NODE:y. If global_tpu_id is not None, but
|
||||
placement_function and tpu_ordinal_function are None, then global_tpu_id
|
||||
will be used to place infeed on the TPUs with the first k global ids,
|
||||
where k is the number of shards in the queue.
|
||||
placement_function: if not None, a function that takes the shard
|
||||
index as input and returns a device string indicating which
|
||||
device the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
tpu_ordinal_function: if not None, a function that takes the
|
||||
shard index as input and returns the ordinal of the TPU device
|
||||
the shard's infeed should be placed on. If placement_function
|
||||
and tpu_ordinal_function are None, inputs are sharded round-robin
|
||||
across the devices in the system.
|
||||
|
||||
Returns:
|
||||
A list of host-side Ops, one for each shard, that when executed together
|
||||
will enqueue a full-size element of infeed.
|
||||
|
||||
Raises:
|
||||
ValueError: if the queue configuration has previously been frozen and the
|
||||
shapes of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
TypeError: if the queue configuration has previously been frozen and the
|
||||
types of the elements of inputs are not compatible with the frozen
|
||||
configuration.
|
||||
"""
|
||||
if global_tpu_id is None:
|
||||
if placement_function is None:
|
||||
placement_function = self._default_placement_function
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = self._default_ordinal_function
|
||||
else:
|
||||
global_id_map = {}
|
||||
for host, devices in enumerate(global_tpu_id):
|
||||
for ordinal, global_id in enumerate(devices):
|
||||
global_id_map[global_id] = (host, ordinal)
|
||||
|
||||
def _placement_function_from_map(index):
|
||||
return "/task:%d/device:CPU:0" % global_id_map[index][0]
|
||||
|
||||
def _ordinal_function_from_map(index):
|
||||
return global_id_map[index][1]
|
||||
|
||||
if placement_function is None:
|
||||
placement_function = _placement_function_from_map
|
||||
if tpu_ordinal_function is None:
|
||||
tpu_ordinal_function = _ordinal_function_from_map
|
||||
self.set_configuration_from_input_tensors(inputs)
|
||||
self.freeze()
|
||||
if self._generated_enqueue_ops:
|
||||
raise ValueError("Can't generate two enqueue Ops from the same queue")
|
||||
self._generated_enqueue_ops = True
|
||||
split_name_prefix = "%s/split" % self._name
|
||||
if self.number_of_shards == 1:
|
||||
transposed_sharded_inputs = [[inp] for inp in inputs]
|
||||
else:
|
||||
transposed_sharded_inputs = [
|
||||
array_ops.split(
|
||||
inp,
|
||||
self.number_of_shards,
|
||||
axis=policy.shard_dimension,
|
||||
name="%s/%d" % (split_name_prefix, index))
|
||||
for (inp, policy, index) in zip(inputs, self._sharding_policies,
|
||||
xrange(self.number_of_tuple_elements))
|
||||
]
|
||||
sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
|
||||
for i in xrange(self.number_of_shards)]
|
||||
name_prefix = "%s/enqueue" % self._name
|
||||
return [
|
||||
self._generate_enqueue_op(
|
||||
shard,
|
||||
name_prefix,
|
||||
index,
|
||||
device=placement_function(index),
|
||||
tpu_ordinal=tpu_ordinal_function(index))
|
||||
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
|
||||
]
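
The tutorial-only path above can be sketched as follows (single host assumed; the placement and ordinal lambdas are illustrative assumptions):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_feed

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=1)
queue.set_number_of_shards(2)
full_batch = tf.zeros([16, 8])
# The queue splits the full-size input into 2 shards of shape [8, 8].
enqueue_ops = queue.split_inputs_and_generate_enqueue_ops(
    [full_batch],
    placement_function=lambda index: '/task:0/device:CPU:0',
    tpu_ordinal_function=lambda index: index)
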
|
106
tensorflow/contrib/tpu/python/tpu/tpu_function.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Helper library for functions used during TPU compilation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class TpuContext(object):
|
||||
"""A context object holding state about the TPU computation being built."""
|
||||
|
||||
def __init__(self):
|
||||
"""Creates a new TpuContext."""
|
||||
self._number_of_shards = None
|
||||
|
||||
@property
|
||||
def number_of_shards(self):
|
||||
return self._number_of_shards
|
||||
|
||||
def set_number_of_shards(self, number_of_shards):
|
||||
self._number_of_shards = number_of_shards
|
||||
|
||||
|
||||
# The Tpu context holds the number of shards when a sharded computation is
|
||||
# being built, or None if no computation is being built.
|
||||
_current_tpu_context = TpuContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tpu_shard_context(number_of_shards):
|
||||
if _current_tpu_context.number_of_shards is not None:
|
||||
raise NotImplementedError("tpu_shard_context cannot be nested.")
|
||||
try:
|
||||
_current_tpu_context.set_number_of_shards(number_of_shards)
|
||||
yield
|
||||
finally:
|
||||
_current_tpu_context.set_number_of_shards(None)
|
||||
|
||||
|
||||
def get_tpu_context():
|
||||
return _current_tpu_context
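
A small usage sketch of the shard context (assumed caller-side code):

from tensorflow.contrib.tpu.python.tpu import tpu_function

with tpu_function.tpu_shard_context(8):
  # Code built inside the context can query the shard count.
  shards = tpu_function.get_tpu_context().number_of_shards  # -> 8
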
|
||||
|
||||
|
||||
def check_function_argument_count(func, input_arity, infeed_queue):
|
||||
"""Validate the number of input arguments to a tpu function.
|
||||
|
||||
Args:
|
||||
func: the Python function that will be called to generate the body
|
||||
of a TPUFunction.
|
||||
input_arity: the number of explicit arguments supplied by the
|
||||
caller.
|
||||
infeed_queue: if not None, the infeed queue that will supply
|
||||
additional arguments to the function.
|
||||
|
||||
Returns:
|
||||
None if function can be called with the supplied number of
|
||||
arguments, or an error string if it cannot.
|
||||
"""
|
||||
def format_error(complaint, quantity):
|
||||
return "%s %d argument%s" % (complaint, quantity, ""
|
||||
if quantity == 1 else "s")
|
||||
|
||||
number_of_arguments_needed = input_arity
|
||||
if infeed_queue is not None:
|
||||
number_of_arguments_needed += infeed_queue.number_of_tuple_elements
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
number_of_args = len(arg_spec.args)
|
||||
if arg_spec.defaults is None:
|
||||
number_of_defaults = 0
|
||||
else:
|
||||
number_of_defaults = len(arg_spec.defaults)
|
||||
min_required_arguments = number_of_args - number_of_defaults
|
||||
if number_of_arguments_needed < min_required_arguments:
|
||||
# The required number of arguments is not enough to call the function.
|
||||
if number_of_defaults == 0 and arg_spec.varargs is None:
|
||||
return format_error("exactly", number_of_args)
|
||||
else:
|
||||
return format_error("at least", min_required_arguments)
|
||||
if arg_spec.varargs is None and number_of_arguments_needed > number_of_args:
|
||||
# The required number of arguments is too many to call the function.
|
||||
if number_of_defaults == 0:
|
||||
return format_error("exactly", number_of_args)
|
||||
else:
|
||||
return format_error("at most", number_of_args)
|
||||
# Since there are varargs, func can accept any number of arguments
|
||||
# greater than the minimum.
|
||||
return None
|
||||
|
125
tensorflow/contrib/tpu/python/tpu/tpu_function_test.py
Normal file
@ -0,0 +1,125 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Tests for tpu_function helpers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionArgCheckTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
"""Tests that arg checker works for functions with no varargs or defaults.
|
||||
"""
|
||||
|
||||
def func(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual("exactly 3 arguments",
|
||||
tpu_function.check_function_argument_count(func, 2, None))
|
||||
queue = tpu_feed.InfeedQueue(2)
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual("exactly 3 arguments",
|
||||
tpu_function.check_function_argument_count(func, 2, queue))
|
||||
|
||||
def testDefaultArgs(self):
|
||||
"""Tests that arg checker works for a function with no varargs."""
|
||||
|
||||
def func(x, y, z=17):
|
||||
return x + y + z
|
||||
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 1, None))
|
||||
self.assertEqual("at most 3 arguments",
|
||||
tpu_function.check_function_argument_count(func, 4, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 0, queue))
|
||||
self.assertEqual("at most 3 arguments",
|
||||
tpu_function.check_function_argument_count(func, 4, queue))
|
||||
|
||||
def testVarArgs(self):
|
||||
"""Tests that arg checker works for a function with varargs."""
|
||||
|
||||
def func(x, y, *z):
|
||||
return x + y + len(z)
|
||||
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 0, queue))
|
||||
|
||||
def testVarArgsAndDefaults(self):
|
||||
"""Tests that arg checker works for a function with varargs and defaults."""
|
||||
|
||||
def func(x, y, z=17, *q):
|
||||
return x + y + z + len(q)
|
||||
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 4, None))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 5, None))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 1, None))
|
||||
queue = tpu_feed.InfeedQueue(1)
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 1, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 2, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 3, queue))
|
||||
self.assertEqual(None,
|
||||
tpu_function.check_function_argument_count(func, 4, queue))
|
||||
self.assertEqual("at least 2 arguments",
|
||||
tpu_function.check_function_argument_count(func, 0, queue))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
tensorflow/contrib/tpu/python/tpu/tpu_infeed_test.py (new file, 130 lines)
@@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for TPU InfeedQueue methods."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.tpu import tpu_feed

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test


class InfeedTest(test.TestCase):

  def testConstructor(self):
    """Tests that the constructor can be called with different arguments."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertEqual(i.tuple_types, None)
    self.assertEqual(i.tuple_shapes, None)
    self.assertEqual(i.number_of_shards, None)
    i = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertEqual(i.tuple_types,
                     [dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertEqual(i.tuple_shapes, None)
    self.assertEqual(i.number_of_shards, None)
    i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertEqual(i.tuple_types, None)
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    self.assertEqual(i.number_of_shards, None)
    i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertEqual(i.tuple_types, None)
    self.assertEqual(i.tuple_shapes, None)
    self.assertEqual([p.shard_dimension
                      for p in i.sharding_policies], [1, 0, 7])
    with self.assertRaises(ValueError):
      i = tpu_feed.InfeedQueue()
    with self.assertRaises(ValueError):
      i = tpu_feed.InfeedQueue(
          number_of_tuple_elements=2, tuple_types=[dtypes.float32])
    with self.assertRaises(ValueError):
      i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
    with self.assertRaises(ValueError):
      i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
    with self.assertRaises(ValueError):
      i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])

  def testModification(self):
    """Tests modification of the queue post-construction."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    i.set_tuple_types([dtypes.float32, dtypes.float32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.float32])
    i.set_tuple_shapes([[1], [2, 3]])
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    i.set_tuple_shapes([[1, 2], [3, 4]])
    self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[1, 2]])
    i.set_number_of_shards(2)
    self.assertEqual(i.number_of_shards, 2)
    i.set_number_of_shards(3)
    self.assertEqual(i.number_of_shards, 3)
    t1 = constant_op.constant(1, dtypes.int32, shape=[6])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
    i.set_configuration_from_input_tensors([t1, t2])
    self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
    self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    i.set_shard_dimensions([1, 0])
    i.set_number_of_shards(3)
    with self.assertRaises(ValueError):
      i.set_number_of_shards(4)

  def testFreezing(self):
    """Tests freezing the queue."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    t1 = constant_op.constant(1, dtypes.int32, shape=[2])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])
    i.freeze()
    i.set_number_of_shards(2)
    i.set_tuple_shapes([[4, 4], [4]])
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    i.set_shard_dimensions([0, 0])
    with self.assertRaises(ValueError):
      i.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[8, 8], [8]])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.int32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_shard_dimensions([1, 0])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])

if __name__ == '__main__':
  test.main()
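As a minimal usage sketch (not part of the commit; the example tensors and their shapes are assumptions), an InfeedQueue can infer its tuple types and shapes directly from the tensors that will feed it, mirroring the pattern exercised in testModification above.

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tpu_feed

    features = tf.constant(0.0, shape=[8, 128])        # assumed example input
    labels = tf.constant(0, dtype=tf.int32, shape=[8])  # assumed example input

    queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    queue.set_configuration_from_input_tensors([features, labels])
    # tuple_shapes is now [[8, 128], [8]] and tuple_types is [tf.float32, tf.int32].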
tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Optimizer that implements cross-shard gradient reduction for TPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.python.training import optimizer


class CrossShardOptimizer(optimizer.Optimizer):
  """An optimizer that sums gradients across TPU shards."""

  def __init__(self, opt, name="CrossShardOptimizer"):
    super(CrossShardOptimizer, self).__init__(False, name)
    self._opt = opt

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in apply_gradients() so that the user can
    modify the gradients, e.g. clip them with a per-replica global norm, if
    needed. The global norm with aggregated gradients can be bad as one
    replica's huge gradients can hurt the gradients from other replicas.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
    """
    return self._opt.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If grads_and_vars is malformed.
    """
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        summed_grads_and_vars.append((grad, var))
      else:
        summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)

  def get_slot(self, *args, **kwargs):
    """Return a slot named "name" created for "var" by the Optimizer.

    This simply wraps the get_slot() from the actual optimizer.

    Args:
      *args: Arguments for get_slot().
      **kwargs: Keyword arguments for get_slot().

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._opt.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """Return a list of the names of slots created by the `Optimizer`.

    This simply wraps the get_slot_names() from the actual optimizer.

    Args:
      *args: Arguments for get_slot_names().
      **kwargs: Keyword arguments for get_slot_names().

    Returns:
      A list of strings.
    """
    return self._opt.get_slot_names(*args, **kwargs)
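A minimal sketch of how this wrapper is intended to be used (not part of the commit; it assumes a standard tf.train optimizer and a per-shard loss defined elsewhere):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

    base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    opt = tpu_optimizer.CrossShardOptimizer(base_opt)

    # Inside the per-shard model function, where `loss` is assumed to be defined:
    #   train_op = opt.minimize(loss, global_step=tf.train.get_global_step())
    # minimize() goes through the compute_gradients()/apply_gradients() pair above,
    # so each gradient is summed across shards before the inner optimizer applies it.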
tensorflow/contrib/tpu/python/tpu/tpu_sharding.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for sharding during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            "Can't set sharding policy to use %d shards since it has been "
            "frozen to use %d." % (number_of_shards, self._number_of_shards))
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            "Can't set sharding policy to use %s shards; value must be >0",
            str(number_of_shards))

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

  def merge(self, other):
    """Merges the policy of another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Freezes the policy if it
    has not yet been frozen.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and
        !(0<=shard_index<number_of_shards); or shape does not have at
        least self.shard_dimension+1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError("shard_index %d, but must be in [0,%d)." %
                         (shard_index, self._number_of_shards))
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape not Unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                       (shape.as_list(), self._number_of_shards,
                        self._shard_dimension))
    dims[self._shard_dimension] /= self._number_of_shards
    return tensor_shape.as_shape(dims)

  def _unshard_shape(self, shape):
    """Return the unsharded shape that would generate a given sharded shape.

    Args:
      shape: the sharded shape to unshard

    Returns:
      The unsharded shape.

    Raises:
      ValueError: if shape is unknown or does not contain
        self.shard_dimension
      TypeError: if shape is not convertible to a TensorShape
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape not Unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.as_shape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: if shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: if an element of shapes is not convertible to a
        TensorShape
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          "shapes is %s but must be a list of length number_of_shards=%d" % (
              str(shapes), self.number_of_shards))
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if unsharded_shapes[i] != unsharded_shapes[self.number_of_shards - 1]:
        raise ValueError(
            "sharded shapes %s are not consistent shards of a full shape "
            "sharded %d ways along dimension %d" % (
                str(shapes), self.number_of_shards, self.shard_dimension))
    return unsharded_shapes[0]
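A minimal sketch of the shard/unshard round trip this class implements (not part of the commit; the 3-way split along dimension 1 is an assumed example):

    from tensorflow.contrib.tpu.python.tpu import tpu_sharding

    policy = tpu_sharding.ShardingPolicy()
    policy.set_number_of_shards(3)
    policy.set_shard_dimension(1)

    sharded = policy.get_sharded_shape([4, 9])        # TensorShape equal to [4, 3]
    full = policy.get_unsharded_shape([[4, 3]] * 3)   # TensorShape equal to [4, 9]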
tensorflow/contrib/tpu/python/tpu/tpu_sharding_test.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_sharding helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.tpu import tpu_sharding

from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test


class ShardingTest(test.TestCase):

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10], shard_index=-1)

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])


if __name__ == "__main__":
  test.main()
tensorflow/contrib/tpu/python/tpu/training_loop.py (new file, 213 lines)
@@ -0,0 +1,213 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tpu.python.tpu import tpu_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops


def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  infeed_queue if infeed_queue is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to condition.
    name: an optional name for the loop.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """

  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = tpu_function.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))
  condition_arg_error = tpu_function.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      return control_flow_ops.tuple(output_tensors,
                                    control_inputs=output_operations)
    else:
      return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs,
                                     name=name)


def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors corresponds to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to condition.
    name: an optional name for the loop.
  Returns:
    The final values of the loop-carried tensors.
  Raises:
    ValueError: if there is a type error.
  """
  def _convert_to_list(xs):
    if not isinstance(xs, (list, tuple)):
      return [xs]
    else:
      return list(xs)

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    return [i + 1] + _convert_to_list(body(*args))

  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    return outputs[1:]
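A minimal sketch of driving repeat() with a single loop-carried scalar (not part of the commit; the names `body` and `final_total` are invented for the illustration). In a real TPU program the loop would normally be built from inside tpu.rewrite/tpu.shard/tpu.replicate, as the infeed error message above notes; without an infeed queue it simply builds an ordinary while_loop:

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import training_loop

    def body(total):
      # One "step" for the sketch: just accumulate a constant.
      return total + 1.0

    # final_total is a list holding the final value of the loop-carried tensor.
    final_total = training_loop.repeat(10, body, inputs=[tf.constant(0.0)])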
@@ -144,6 +144,9 @@ sh_binary(
        "//tensorflow/contrib/slim:slim",
        "//tensorflow/contrib/slim/python/slim/data:data_pip",
        "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
        "//tensorflow/contrib/tpu:tpu_estimator",
        "//tensorflow/contrib/tpu:tpu_helper_library",
        "//tensorflow/contrib/tpu:tpu_py",
        "//tensorflow/contrib/specs:specs",
        "//tensorflow/contrib/tensor_forest:init_py",
        "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",