From 2c3738db9c4df83adc1aff29f5cb0e9735dd5eac Mon Sep 17 00:00:00 2001
From: Vijay Vasudevan
Date: Tue, 8 Dec 2015 14:55:13 -0800
Subject: [PATCH] TensorFlow: Upstream changes to git.

Change 109730179
	Add support for selecting partition strategy in tf.nn.embedding_lookup
	and related ops, and allow unequally-sized shards to be used as input.
Change 109729548
	TensorFlow: add RELEASE.md notes for 0.6.0.
Change 109728185
	Make seq2seq_test non-flaky by setting python and numpy random seed.
Change 109725913
	Refactor slot creation in optimizers and moving averages to separate file.
Change 109718024
	TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s.
Change 109712251
	More performance improvement for convnet on GPU.
	+ Switch forward convolution format to NCHW.
	+ Allocate scratch space for forward and backward convolutions.
	+ Users can use "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the
	  scratch space limit. The default limit is 1GB.
Change 109710898
	Added extract_sub_graph utility function.

Base CL: 109731609
---
 RELEASE.md                                    |  27 +++
 tensorflow/core/kernels/conv_grad_ops.cc      |  36 ++--
 tensorflow/core/kernels/conv_ops.cc           |  80 +++++++--
 tensorflow/core/kernels/conv_ops_gpu.h        |  84 ++++++++++
 tensorflow/python/client/graph_util.py        |  62 +++++++
 tensorflow/python/client/graph_util_test.py   |  30 ++++
 .../python/kernel_tests/embedding_ops_test.py | 155 +++++++++++++++---
 .../python/kernel_tests/seq2seq_test.py       |  47 +++---
 tensorflow/python/ops/embedding_ops.py        | 100 ++++++++---
 tensorflow/python/ops/nn.py                   |  21 ++-
 tensorflow/python/training/moving_averages.py |  31 ++--
 tensorflow/python/training/optimizer.py       |  41 +++--
 tensorflow/python/training/slot_creator.py    | 108 ++++++++++++
 .../python/training/slot_creator_test.py      |  78 +++++++++
 14 files changed, 763 insertions(+), 137 deletions(-)
 create mode 100644 tensorflow/core/kernels/conv_ops_gpu.h
 create mode 100644 tensorflow/python/training/slot_creator.py
 create mode 100644 tensorflow/python/training/slot_creator_test.py

diff --git a/RELEASE.md b/RELEASE.md
index b475bd99e85..69b4fa20a0c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,30 @@
+# Release 0.6.0
+
+## Major Features and Improvements
+
+* Python 3.3+ support via changes to the Python codebase and the ability
+  to specify the Python version via ./configure.
+
+* Some improvements to GPU performance and memory usage:
+  [convnet benchmarks](https://github.com/soumith/convnet-benchmarks/issues/66)
+  are roughly equivalent to native cuDNN v2 performance. The improvements are
+  mostly due to the move to 32-bit indices and faster shuffling kernels. More
+  improvements to come in later releases.
+
+## Bug Fixes
+
+* Many fixes to documentation and tutorials, a large number of them
+  contributed by the public.
+
+* 271 closed issues on GitHub.
+
+## Backwards-Incompatible Changes
+
+* `tf.nn.fixed_unigram_candidate_sampler` changed its default 'distortion'
+  attribute from 0.0 to 1.0. This was a bug in the original release
+  that is now fixed.
+
 # Release 0.5.0

 Initial release of TensorFlow.
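The workspace-limit convention described in the commit message is worth spelling out: the environment variable is read in megabytes and converted to bytes, falling back to the 1GB default when unset or unparsable (see `GetCudnnWorkspaceLimit` in `conv_ops.cc` below). A minimal Python sketch of the same convention — not part of the patch, function name is illustrative:

```python
import os

def cudnn_workspace_limit_bytes(envvar="TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
                                default_bytes=1 << 30):  # 1GB default
    """Sketch of GetCudnnWorkspaceLimit: env var is in MB, result in bytes."""
    raw = os.environ.get(envvar, "")
    if raw:
        try:
            return int(raw) * (1 << 20)  # MB -> bytes
        except ValueError:
            print("Invalid value for env-var %s: %s" % (envvar, raw))
    return default_bytes
```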
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 8bd13b4be3d..6047ddbe238 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
 #endif  // GOOGLE_CUDA

 namespace tensorflow {
@@ -756,17 +757,6 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
 // GPU definitions of both ops.
 #if GOOGLE_CUDA
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
-                                                    uint64 size) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
-                                                size * sizeof(T));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
-}
-}  // namespace
-
 // The slow version (but compiles for GPU)

 // Backprop for input.
@@ -929,10 +919,15 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                        pre_transformed_in_backprop.template flat<T>().size());

+    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+        );
+    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+                                            context);
     bool cudnn_launch_status =
-        stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
-                                         out_backprop_ptr, conv_desc,
-                                         input_desc, &in_backprop_ptr)
+        stream->ThenConvolveBackwardDataWithScratch(
+                   filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+                   conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
             .ok();

     if (!cudnn_launch_status) {
@@ -1185,7 +1180,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
         context->eigen_device<Device>(),
         const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
         transformed_input.tensor<T, 4>());
-
     auto out_backprop_ptr =
         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                        transformed_out_backprop.template flat<T>().size());
@@ -1196,10 +1190,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
         AsDeviceMemory(transformed_input.template flat<T>().data(),
                        transformed_input.template flat<T>().size());

+    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+        );
+    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+                                            context);
     bool cudnn_launch_status =
-        stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
-                                           out_backprop_ptr, conv_desc,
-                                           filter_desc, &filter_backprop_ptr)
+        stream->ThenConvolveBackwardFilterWithScratch(
+                   input_desc, input_ptr, output_desc, out_backprop_ptr,
+                   conv_desc, filter_desc, &filter_backprop_ptr,
+                   &scratch_allocator)
             .ok();

     if (!cudnn_launch_status) {
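The `conv_ops.cc` hunks below switch the forward convolution to cuDNN's preferred NCHW layout: the NHWC input is transposed before the `ThenConvolveWithScratch` call and the NCHW result is transposed back afterwards. The round trip is a pair of inverse permutations, sketched here in NumPy (not part of the patch):

```python
import numpy as np

x_nhwc = np.random.rand(8, 32, 32, 3).astype(np.float32)  # batch, H, W, C
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))               # batch, C, H, W
assert x_nchw.shape == (8, 3, 32, 32)

# The inverse permutation restores the original layout bit-for-bit.
assert np.array_equal(np.transpose(x_nchw, (0, 2, 3, 1)), x_nhwc)
```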
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 8c0d4d73898..6af13cb6cd3 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/public/tensor.h"
 #include "tensorflow/core/public/tensor_shape.h"
@@ -34,6 +35,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
 #endif  // GOOGLE_CUDA

 namespace tensorflow {
@@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")

 #if GOOGLE_CUDA

-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
-                                                    uint64 size) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
-                                                size * sizeof(T));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
+int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
+                             int64 default_value_in_bytes) {
+  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
+  if (workspace_limit_in_mb_str != nullptr &&
+      strcmp(workspace_limit_in_mb_str, "") != 0) {
+    int64 scratch_limit_in_mb = -1;
+    if (strings::safe_strto64(workspace_limit_in_mb_str,
+                              &scratch_limit_in_mb)) {
+      return scratch_limit_in_mb * (1 << 20);
+    } else {
+      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
+                   << workspace_limit_in_mb_str;
+    }
+  }
+  return default_value_in_bytes;
 }
-}  // namespace

 template <typename T>
 struct LaunchConvOp<GPUDevice, T> {
@@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
       input = transformed_input;
     }

+    {
+      // Convert the input tensor from NHWC to NCHW.
+      Tensor transformed_input;
+      OP_REQUIRES_OK(ctx,
+                     ctx->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         TensorShape({input.dim_size(0), input.dim_size(3),
+                                      input.dim_size(1), input.dim_size(2)}),
+                         &transformed_input));
+      functor::NHWCToNCHW<GPUDevice, T>()(
+          ctx->eigen_device<GPUDevice>(),
+          const_cast<const Tensor&>(input).tensor<T, 4>(),
+          transformed_input.tensor<T, 4>());
+      input = transformed_input;
+    }
+
     perftools::gputools::dnn::BatchDescriptor input_desc;
     input_desc.set_count(input.dim_size(0))
-        .set_height(input.dim_size(1))
-        .set_width(input.dim_size(2))
-        .set_feature_map_count(input.dim_size(3))
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+        .set_feature_map_count(input.dim_size(1))
+        .set_height(input.dim_size(2))
+        .set_width(input.dim_size(3))
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
     perftools::gputools::dnn::BatchDescriptor output_desc;
     output_desc.set_count(output->dim_size(0))
         .set_height(output->dim_size(1))
         .set_width(output->dim_size(2))
         .set_feature_map_count(output->dim_size(3))
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
     perftools::gputools::dnn::FilterDescriptor filter_desc;
     filter_desc.set_input_filter_height(filter.dim_size(0))
         .set_input_filter_width(filter.dim_size(1))
@@ -320,17 +344,31 @@ struct LaunchConvOp<GPUDevice, T> {
         ctx->eigen_device<GPUDevice>(),
         To32Bit(filter.tensor<T, 4>()),
         To32Bit(transformed_filter.tensor<T, 4>()));

+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_temp(
+                 DataTypeToEnum<T>::value,
+                 TensorShape({output->dim_size(0), output->dim_size(3),
+                              output->dim_size(1), output->dim_size(2)}),
+                 &transformed_output));
+
     auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                     input.template flat<T>().size());
     auto filter_ptr =
         AsDeviceMemory(transformed_filter.template flat<T>().data(),
                        transformed_filter.template flat<T>().size());
-    auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
-                                     output->template flat<T>().size());
+    auto output_ptr =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());

+    static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+        );
+    CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
     bool cudnn_launch_status =
-        stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
-                             conv_desc, output_desc, &output_ptr)
+        stream->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
+                                        filter_ptr, conv_desc, output_desc,
+                                        &output_ptr, &scratch_allocator)
             .ok();

     if (!cudnn_launch_status) {
@@ -338,6 +376,12 @@ struct LaunchConvOp<GPUDevice, T> {
           "cuDNN launch failure : input shape(", input.shape().DebugString(),
           ") filter shape(", filter.shape().DebugString(), ")"));
     }
+
+    // Convert the output tensor from NCHW back to NHWC.
+    functor::NCHWToNHWC<GPUDevice, T>()(
+        ctx->eigen_device<GPUDevice>(),
+        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+        output->tensor<T, 4>());
   } else {
     LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
                                         padding, output);
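`conv_ops_gpu.h` below introduces `CudnnScratchAllocator`, which hands StreamExecutor temporary buffers backed by `OpKernelContext::allocate_temp` and reports the workspace cap so cuDNN can pick an algorithm that fits. A toy Python analogue of that contract — names are hypothetical, and the real class reports the limit via `GetMemoryLimitInBytes` rather than enforcing it at allocation time as this sketch does:

```python
class CappedScratchAllocator(object):
    """Toy analogue of CudnnScratchAllocator."""

    def __init__(self, memory_limit_bytes):
        self._limit = memory_limit_bytes
        self._buffers = []  # kept alive until the op finishes

    def memory_limit_in_bytes(self):
        return self._limit

    def allocate_bytes(self, byte_size):
        if byte_size > self._limit:
            raise MemoryError("scratch request of %d bytes exceeds the "
                              "%d-byte limit" % (byte_size, self._limit))
        buf = bytearray(byte_size)
        self._buffers.append(buf)
        return buf
```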
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
new file mode 100644
index 00000000000..bbe06cb6a16
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -0,0 +1,84 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/stream_executor/scratch_allocator.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+
+namespace tensorflow {
+
+// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is
+// widespread.
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+                                                    uint64 size) {
+  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+                                                size * sizeof(T));
+  perftools::gputools::DeviceMemory<T> typed(wrapped);
+  return typed;
+}
+
+// Get the cuDNN workspace limit from the environment variable, which is in
+// MB. Return the workspace memory limit in bytes. If no value is set, return
+// the default value.
+int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
+                             int64 default_value_in_bytes);
+
+// A class that provides a scratch-space allocator for StreamExecutor cuDNN
+// callbacks. TensorFlow is responsible for releasing the temporary buffers
+// after the kernel finishes.
+class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+  virtual ~CudnnScratchAllocator() {}
+  CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
+      : memory_limit_(memory_limit), context_(context) {}
+  virtual int64 GetMemoryLimitInBytes(
+      perftools::gputools::Stream* stream) override {
+    return memory_limit_;
+  }
+  virtual perftools::gputools::port::StatusOr<
+      perftools::gputools::DeviceMemory<uint8>>
+  AllocateBytes(perftools::gputools::Stream* stream,
+                int64 byte_size) override {
+    Tensor temporary_memory;
+
+    Status allocation_status(context_->allocate_temp(
+        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
+    if (!allocation_status.ok()) {
+      LOG(WARNING) << allocation_status;
+      context_->SetStatus(allocation_status);
+      return perftools::gputools::port::StatusOr<
+          perftools::gputools::DeviceMemory<uint8>>();
+    }
+
+    return perftools::gputools::port::StatusOr<
+        perftools::gputools::DeviceMemory<uint8>>(
+        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
+                       temporary_memory.flat<uint8>().size()));
+  }
+
+ private:
+  int64 memory_limit_;
+  OpKernelContext* context_;
+};
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
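`graph_util.py` below gains `extract_sub_graph`, which walks node inputs breadth-first from the requested destinations and emits the reachable nodes in their original order. A usage sketch against the new API (not part of the patch; node names are made up for illustration, 0.6-era API assumed):

```python
import tensorflow as tf
from tensorflow.python.client import graph_util

g = tf.Graph()
with g.as_default():
    x = tf.constant(1.0, name="x")
    y = tf.constant(2.0, name="y")
    out = tf.add(x, y, name="out")
    tf.mul(out, y, name="unused")  # not reachable from "out", so dropped

sub = graph_util.extract_sub_graph(g.as_graph_def(), ["out"])
print([n.name for n in sub.node])  # -> ['x', 'y', 'out'], original order
```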
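The `_node_name` helper above normalizes the three spellings TensorFlow uses for `NodeDef.input` entries; the test below exercises all of them. For reference, a plain-Python restatement (not from the patch):

```python
# "n1"   plain input: the single output of node n1
# "n1:0" explicit output slot 0 of node n1
# "^n2"  control dependency on node n2 (ordering only, no data flow)
for raw, expected in [("n1", "n1"), ("n1:0", "n1"), ("^n2", "n2")]:
    name = raw[1:] if raw.startswith("^") else raw.split(":")[0]
    assert name == expected
```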
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
index 6b7dba60bc0..73265361cd3 100644
--- a/tensorflow/python/client/graph_util_test.py
+++ b/tensorflow/python/client/graph_util_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function

 import tensorflow.python.platform

+import tensorflow as tf
+
 from tensorflow.python.client import graph_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -140,6 +142,34 @@ class DeviceFunctionsTest(googletest.TestCase):
     self.assertEqual(const_4.device, "/device:CPU:1")
     self.assertEqual(const_5.device, "/replica:0")

+  def testExtractSubGraph(self):
+    graph_def = tf.GraphDef()
+    n1 = graph_def.node.add()
+    n1.name = "n1"
+    n1.input.extend(["n5"])
+    n2 = graph_def.node.add()
+    n2.name = "n2"
+    # Take the first output of the n1 node as the input.
+    n2.input.extend(["n1:0"])
+    n3 = graph_def.node.add()
+    n3.name = "n3"
+    # Add a control input (which isn't consumed by the kernel; it only
+    # enforces execution order between nodes).
+    n3.input.extend(["^n2"])
+    n4 = graph_def.node.add()
+    n4.name = "n4"
+
+    # It is fine to have loops in the graph as well.
+    n5 = graph_def.node.add()
+    n5.name = "n5"
+    n5.input.extend(["n1"])
+
+    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
+    self.assertEqual("n1", sub_graph.node[0].name)
+    self.assertEqual("n2", sub_graph.node[1].name)
+    self.assertEqual("n3", sub_graph.node[2].name)
+    self.assertEqual("n5", sub_graph.node[3].name)
+
 if __name__ == "__main__":
   googletest.main()
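The `embedding_ops_test.py` changes below let `_EmbeddingParams` build unequally-sized shards: each of the first `vocab_size % num_shards` shards takes one extra row. A quick check of the shard shapes the tests rely on (13 rows over 5 shards), written as a standalone sketch with a hypothetical helper name:

```python
def shard_sizes(vocab_size, num_shards):
    # Excess rows go, one each, to the first (vocab_size % num_shards) shards.
    base, extras = divmod(vocab_size, num_shards)
    return [base + 1 if i < extras else base for i in range(num_shards)]

print(shard_sizes(13, 5))  # -> [3, 3, 3, 2, 2]
```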
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 5f54f02bf06..9a4cd14eb1a 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -115,26 +115,34 @@ def _PName(param_id):

 def _EmbeddingParams(num_shards, vocab_size,
                      dtype=tf.float32,
-                     shape=None):
+                     shape=None,
+                     use_shapeless_placeholder=False):
   p = []
   params = {}
   feed_dict = {}
   if not shape: shape = [10]
-  assert not vocab_size % num_shards
-  shape = [vocab_size // num_shards] + shape
   for i in range(num_shards):
+    shard_shape = [vocab_size // num_shards] + shape
+    if i < vocab_size % num_shards:  # Excess goes evenly to the first shards.
+      shard_shape[0] += 1
+
     param_name = _PName(i)
-    constant_t = tf.constant(1.0, shape=shape, dtype=dtype,
-                             name=param_name)
-    p.append(constant_t)
+
+    if use_shapeless_placeholder:
+      param = tf.placeholder(dtype, shape=None, name=param_name)
+    else:
+      param = tf.constant(1.0, shape=shard_shape, dtype=dtype,
+                          name=param_name)
+    p.append(param)
     np_type = "f" if dtype == tf.float32 else "d"
-    val = (np.random.rand(*shape).astype(np_type)) + 1
+    val = (np.random.rand(*shard_shape).astype(np_type)) + 1
     params[param_name + ":0"] = val
-    feed_dict[constant_t.name] = val
+    feed_dict[param.name] = val
   return p, params, feed_dict

-def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
+def _EmbeddingResult(params, id_vals, num_shards, vocab_size,
+                     partition_strategy="mod",
+                     weight_vals=None):
   if weight_vals is None:
     weight_vals = np.copy(id_vals)
     weight_vals.fill(1)
@@ -147,8 +155,22 @@ def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
       ids = [ids]
       wts = [wts]
     for i, wt_val in zip(ids, wts):
-      val = np.copy(params[_PName(i % num_shards) + ":0"][
-          i // num_shards, :]) * wt_val
+      if partition_strategy == "mod":
+        val = np.copy(params[_PName(i % num_shards) + ":0"][
+            i // num_shards, :]) * wt_val
+      elif partition_strategy == "div":
+        ids_per_partition, extras = divmod(vocab_size, num_shards)
+        threshold = extras * (ids_per_partition + 1)
+        if i < threshold:
+          partition = i // (ids_per_partition + 1)
+          offset = i % (ids_per_partition + 1)
+        else:
+          partition = extras + (i - threshold) // ids_per_partition
+          offset = (i - threshold) % ids_per_partition
+        val = np.copy(
+            params[_PName(partition) + ":0"][offset, :]) * wt_val
+      else:
+        assert False, "Unknown partition_strategy: %s" % partition_strategy
       if val_aggr is None:
         assert wt_aggr is None
         val_aggr = val
@@ -182,17 +204,17 @@ class EmbeddingLookupTest(tf.test.TestCase):
       embedding = tf.nn.embedding_lookup(p, ids)

       tf_result = embedding.eval(feed_dict=feed_dict)
-    np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
+    np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
     self.assertAllEqual(np_result, tf_result)
     self.assertShapeEqual(np_result, embedding)

-  def testSharded(self):
+  def testShardedModPartitioningInt32Ids(self):
     with self.test_session():
       num_shards = 5
-      vocab_size = 25
-      # Embedding dimensions is 10. The 10 x vocab_size embedding
-      # parameters are spread in num_shards matrices, so each
-      # matrix is 10 x (vocab_size / num_shards)
+      vocab_size = 13
+      # Embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so the first
+      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
       p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

       num_vals = 30
@@ -204,10 +226,103 @@ class EmbeddingLookupTest(tf.test.TestCase):
       embedding = tf.nn.embedding_lookup(p, ids)

       tf_result = embedding.eval(feed_dict=feed_dict)
-    np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
+    np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
     self.assertAllEqual(np_result, tf_result)
     self.assertShapeEqual(np_result, embedding)

+  def testShardedModPartitioningInt64Ids(self):
+    with self.test_session():
+      num_shards = 5
+      vocab_size = 13
+      # Embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so the first
+      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
+      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
+
+      num_vals = 30
+      # Fetch num_vals embeddings for random word ids. Since
+      # num_vals > vocab_size, this ought to have repetitions, so
+      # this will test that aspect.
+      id_vals = np.random.randint(vocab_size, size=num_vals)
+      ids = tf.constant(list(id_vals), dtype=tf.int64)
+
+      embedding = tf.nn.embedding_lookup(p, ids)
+      tf_result = embedding.eval(feed_dict=feed_dict)
+    np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
+    self.assertAllEqual(np_result, tf_result)
+    self.assertShapeEqual(np_result, embedding)
+
+  def testShardedDivPartitioningInt32Ids(self):
+    with self.test_session():
+      num_shards = 5
+      vocab_size = 13
+      # Embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so the first
+      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
+      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
+
+      num_vals = 30
+      # Fetch num_vals embeddings for random word ids. Since
+      # num_vals > vocab_size, this ought to have repetitions, so
+      # this will test that aspect.
+      id_vals = np.random.randint(vocab_size, size=num_vals)
+      ids = tf.constant(list(id_vals), dtype=tf.int32)
+
+      embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
+      tf_result = embedding.eval(feed_dict=feed_dict)
+    np_result, _ = _EmbeddingResult(
+        params, id_vals, num_shards, vocab_size, partition_strategy="div")
+    self.assertAllEqual(np_result, tf_result)
+    self.assertShapeEqual(np_result, embedding)
+
+  def testShardedDivPartitioningInt64Ids(self):
+    with self.test_session():
+      num_shards = 5
+      vocab_size = 13
+      # Embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so the first
+      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
+      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
+
+      num_vals = 30
+      # Fetch num_vals embeddings for random word ids. Since
+      # num_vals > vocab_size, this ought to have repetitions, so
+      # this will test that aspect.
+      id_vals = np.random.randint(vocab_size, size=num_vals)
+      ids = tf.constant(list(id_vals), dtype=tf.int64)
+
+      embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
+      tf_result = embedding.eval(feed_dict=feed_dict)
+    np_result, _ = _EmbeddingResult(
+        params, id_vals, num_shards, vocab_size, partition_strategy="div")
+    self.assertAllEqual(np_result, tf_result)
+    self.assertShapeEqual(np_result, embedding)
+
+  def testShardedDivPartitioningUnknownParamShape(self):
+    with self.test_session():
+      num_shards = 5
+      vocab_size = 13
+      # Embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so the first
+      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
+
+      # We clear parameter shapes, to test when shape is not statically known.
+      p, params, feed_dict = _EmbeddingParams(
+          num_shards, vocab_size, use_shapeless_placeholder=True)
+
+      num_vals = 30
+      # Fetch num_vals embeddings for random word ids. Since
+      # num_vals > vocab_size, this ought to have repetitions, so
+      # this will test that aspect.
+      id_vals = np.random.randint(vocab_size, size=num_vals)
+      ids = tf.constant(list(id_vals), dtype=tf.int64)
+
+      embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
+      tf_result = embedding.eval(feed_dict=feed_dict)
+    np_result, _ = _EmbeddingResult(
+        params, id_vals, num_shards, vocab_size, partition_strategy="div")
+    self.assertAllEqual(np_result, tf_result)
+
   def testGradientsEmbeddingLookup(self):
     vocab_size = 9
     num_ids = 5
@@ -326,7 +441,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
     return grouped_vals

   def testEmbeddingLookupSparse(self):
-    vocab_size = 25
+    vocab_size = 13
     batch_size = 10
     param_shape = [2, 5]

@@ -354,7 +469,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
         tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)

         np_embedding_sum, np_weight_sum = _EmbeddingResult(
-            params, grouped_ids, num_shards,
+            params, grouped_ids, num_shards, vocab_size,
            weight_vals=grouped_ignored_weights
            if ignore_weights else grouped_weights)
         if combiner == "mean":
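The `"div"` branch added to `_EmbeddingResult` above maps an id to a (partition, offset) pair; ids below the threshold land in the larger leading shards. Worked out in plain Python for the `vocab_size=13`, `num_shards=5` case used throughout the tests (helper name is illustrative):

```python
def div_partition(i, vocab_size, num_shards):
    # Mirrors the "div" branch in _EmbeddingResult above.
    ids_per_partition, extras = divmod(vocab_size, num_shards)
    threshold = extras * (ids_per_partition + 1)
    if i < threshold:
        return i // (ids_per_partition + 1), i % (ids_per_partition + 1)
    return (extras + (i - threshold) // ids_per_partition,
            (i - threshold) % ids_per_partition)

# vocab_size=13, num_shards=5 -> shard sizes [3, 3, 3, 2, 2]
assert div_partition(0, 13, 5) == (0, 0)
assert div_partition(8, 13, 5) == (2, 2)   # last id of the third shard
assert div_partition(9, 13, 5) == (3, 0)   # first id of the fourth shard
assert div_partition(12, 13, 5) == (4, 1)
```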
diff --git a/tensorflow/python/kernel_tests/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py
index 1582d8d2ffe..d4e4b10080a 100644
--- a/tensorflow/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/python/kernel_tests/seq2seq_test.py
@@ -354,6 +354,10 @@ class Seq2SeqTest(tf.test.TestCase):
     # We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
     classes = 10
     buckets = [(4, 4), (8, 8)]
+    perplexities = [[], []]  # Results for each bucket.
+    tf.set_random_seed(111)
+    random.seed(111)
+    np.random.seed(111)

     with self.test_session() as sess:
       # We use sampled softmax so we keep output projection separate.
@@ -378,8 +382,7 @@ class Seq2SeqTest(tf.test.TestCase):
                                 softmax_loss_function=SampledLoss)

       # Now we construct the copy model.
-      tf.set_random_seed(111)
-      batch_size = 32
+      batch_size = 8
       inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
       out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
       weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)]
@@ -394,26 +397,26 @@ class Seq2SeqTest(tf.test.TestCase):
         update = optimizer.apply_gradients(zip(grads, params))
         updates.append(update)
       sess.run([tf.initialize_all_variables()])
-      for ep in xrange(3):
-        log_perp = 0.0
-        for _ in xrange(50):
-          bucket = random.choice(np.arange(len(buckets)))
-          length = buckets[bucket][0]
-          i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
-                        dtype=np.int32) for _ in xrange(length)]
-          # 0 is our "GO" symbol here.
-          o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
-          feed = {}
-          for l in xrange(length):
-            feed[inp[l].name] = i[l]
-            feed[out[l].name] = o[l]
-          if length < 8:  # For the 4-bucket, we need the 5th as target.
-            feed[out[length].name] = o[length]
-          res = sess.run([updates[bucket], losses[bucket]], feed)
-          log_perp += float(res[1])
-        perp = math.exp(log_perp / 100)
-        print("step %d avg. perp %f" % ((ep + 1) * 50, perp))
-        self.assertLess(perp, 2.5)
+      steps = 6
+      for _ in xrange(steps):
+        bucket = random.choice(np.arange(len(buckets)))
+        length = buckets[bucket][0]
+        i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
+                      dtype=np.int32) for _ in xrange(length)]
+        # 0 is our "GO" symbol here.
+        o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
+        feed = {}
+        for l in xrange(length):
+          feed[inp[l].name] = i[l]
+          feed[out[l].name] = o[l]
+        if length < 8:  # For the 4-bucket, we need the 5th as target.
+          feed[out[length].name] = o[length]
+        res = sess.run([updates[bucket], losses[bucket]], feed)
+        perplexities[bucket].append(math.exp(float(res[1])))
+      for bucket in xrange(len(buckets)):
+        if len(perplexities[bucket]) > 1:  # Assert that perplexity went down.
+          self.assertLess(perplexities[bucket][1], perplexities[bucket][0])

 if __name__ == "__main__":
   tf.test.main()
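The rewritten seq2seq test above records `math.exp(loss)` per step and merely asserts that perplexity drops between the first two steps of each bucket, instead of training to a fixed threshold. Perplexity is just the exponential of the average per-symbol cross-entropy; with hypothetical loss values:

```python
import math

losses = [2.30, 1.95, 1.61]                # hypothetical per-step losses
perplexities = [math.exp(l) for l in losses]
print(["%.2f" % p for p in perplexities])  # ['9.97', '7.03', '5.00']
assert perplexities[1] < perplexities[0]   # the property the test checks
```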
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index ff39c3c7513..c409b1e6c83 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -22,11 +22,12 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops

-def embedding_lookup(params, ids, name=None):
+def embedding_lookup(params, ids, partition_strategy="mod", name=None):
   """Looks up `ids` in a list of embedding tensors.

   This function is used to perform parallel lookups on the list of
@@ -35,16 +36,32 @@ def embedding_lookup(params, ids, name=None):
   interpreted as a partition of a larger embedding tensor.

   If `len(params) > 1`, each element `id` of `ids` is partitioned between
-  the elements of `params` by computing `p = id % len(params)`, and is
-  then used to look up the slice `params[p][id // len(params), ...]`.
+  the elements of `params` according to the `partition_strategy`.
+  In all strategies, if the number of partitions does not evenly divide the
+  id space, each of the first `(max_id + 1) % len(params)` partitions is
+  assigned one more id.

-  The results of the lookup are then concatenated into a dense
+  If `partition_strategy` is `"mod"`, we assign each id to partition
+  `p = id % len(params)`. For instance,
+  13 ids are split across 5 partitions as:
+  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
+
+  If `partition_strategy` is `"div"`, we assign ids to partitions in a
+  contiguous manner. In this case, 13 ids are split across 5 partitions as:
+  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
+
+  The results of the lookup are concatenated into a dense
   tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

   Args:
-    params: A list of tensors with the same shape and type.
+    params: A list of tensors with the same type, which can be concatenated
+      along dimension 0. Each `Tensor` must be appropriately sized for the
+      given `partition_strategy`.
     ids: A `Tensor` with type `int32` or `int64` containing the ids to be
       looked up in `params`.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+      is `"mod"`.
     name: A name for the operation (optional).

   Returns:
@@ -67,23 +84,59 @@ def embedding_lookup(params, ids, name=None):
       ids = ops.convert_to_tensor(ids, name="ids")
       flat_ids = array_ops.reshape(ids, [-1])
       original_indices = math_ops.range(array_ops.size(flat_ids))
-      # Compute flat_ids % partitions for each id
-      ids_mod_p = flat_ids % np
-      if ids_mod_p.dtype != dtypes.int32:
-        ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32)
-      # Partition single list of ids based on ids % np into np separate lists
-      plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
+
+      # Create p_assignments and set new_ids depending on the strategy.
+      if partition_strategy == "mod":
+        p_assignments = flat_ids % np
+        new_ids = flat_ids // np
+      elif partition_strategy == "div":
+        # Compute num_total_ids as the sum of dim-0 of params, then assign to
+        # partitions based on a constant number of ids per partition. Optimize
+        # if we already know the full shape statically.
+        dim_0_size = params[0].get_shape()[0]
+        for p in xrange(1, np):
+          dim_0_size += params[p].get_shape()[0]
+        if dim_0_size.value:
+          num_total_ids = constant_op.constant(dim_0_size.value,
+                                               flat_ids.dtype)
+        else:
+          dim_0_sizes = []
+          for p in xrange(np):
+            with ops.device(params[p].device):
+              dim_0_sizes.append(array_ops.shape(params[p])[0])
+          num_total_ids = math_ops.reduce_sum(
+              math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
+        ids_per_partition = num_total_ids // np
+        extras = num_total_ids % np
+
+        p_assignments = math_ops.maximum(
+            flat_ids // (ids_per_partition + 1),
+            (flat_ids - extras) // ids_per_partition)
+
+        # Emulate a conditional using a boolean indicator tensor.
+        is_in_first_extras_partitions = math_ops.cast(
+            p_assignments < extras, flat_ids.dtype)
+        new_ids = (
+            is_in_first_extras_partitions * (
+                flat_ids % (ids_per_partition + 1)) +
+            (1 - is_in_first_extras_partitions) * (
+                (flat_ids - extras) % ids_per_partition))
+      else:
+        raise ValueError("Unrecognized partition strategy: " +
+                         partition_strategy)
+
+      # Cast partition assignments to int32 for use in dynamic_partition.
+      # There really should not be more than 2^32 partitions.
+      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
+      # Partition the list of ids based on assignments into np separate lists.
+      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
       # Similarly, partition the original indices.
-      pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
-                                                 np)
+      pindices = data_flow_ops.dynamic_partition(original_indices,
+                                                 p_assignments, np)
       # Do np separate lookups, finding embeddings for plist[p] in params[p]
       partitioned_result = []
       for p in xrange(np):
-        # TODO(agarwal): handle device allocations here and later in the
-        # colocate code.
-        gather_ids = plist[p] // np
         with ops.device(params[p].device):
-          partitioned_result.append(array_ops.gather(params[p], gather_ids))
+          partitioned_result.append(array_ops.gather(params[p],
+                                                     gather_ids[p]))
       # Stitch these back together
       ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                          name=name)
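With the change above, callers pick the strategy that matches how their shards were laid out. A usage sketch of the new parameter — not part of the patch, shard sizes follow the `"div"` layout, variable names are illustrative:

```python
import tensorflow as tf

# Five unequal shards of a 13 x 10 embedding matrix, laid out contiguously
# for "div": rows 0-2, 3-5, 6-8, 9-10, 11-12.
shards = [tf.Variable(tf.random_uniform([rows, 10]))
          for rows in (3, 3, 3, 2, 2)]
ids = tf.constant([0, 5, 12], dtype=tf.int32)
emb = tf.nn.embedding_lookup(shards, ids, partition_strategy="div")
# emb has shape [3, 10]: one embedding row per id.
```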
@@ -106,6 +159,7 @@ def embedding_lookup(params, ids, name=None):

 # TODO(lif): Add support for higher-rank SparseTensors
 def embedding_lookup_sparse(params, sp_ids, sp_weights,
+                            partition_strategy="mod",
                             name=None,
                             combiner="mean"):
   """Computes embeddings for the given ids and weights.
@@ -120,16 +174,15 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
   Args:
     params: A single tensor representing the complete embedding tensor,
       or a list of P tensors all of same shape except for the first dimension,
-      representing sharded embedding tensors. In the latter case, the ids are
-      partitioned by id % P, and we do separate lookups in params[p] for
-      0 <= p < P, and then stitch the results back together into a single
-      result tensor. The first dimension is allowed to vary as the vocab
-      size is not necessarily a multiple of P.
+      representing sharded embedding tensors.
     sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
       where N is typically batch size and M is arbitrary.
     sp_weights: either a SparseTensor of float / double weights, or None to
       indicate all weights should be taken to be 1. If specified, sp_weights
       must have exactly the same shape and indices as sp_ids.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
     name: Optional name for the op.
     combiner: A string specifying the reduction op. Currently "mean" and
       "sum" are supported.
@@ -187,7 +240,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
     else:
       idx = None

-    embeddings = embedding_lookup(params, ids)
+    embeddings = embedding_lookup(
+        params, ids, partition_strategy=partition_strategy)
     if not ignore_weights:
       weights = sp_weights.values
       if weights.dtype != embeddings.dtype:
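`embedding_lookup_sparse` simply forwards the strategy to `embedding_lookup`. A sketch of a sparse lookup, assuming the 0.6-era `tf.SparseTensor` constructor; shapes and names are illustrative:

```python
import tensorflow as tf

# Batch of 2 examples: example 0 has ids {1, 3}, example 1 has id {0}.
sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=tf.constant([1, 3, 0], dtype=tf.int64),
                         shape=[2, 2])
params = tf.Variable(tf.random_uniform([13, 10]))
# sp_weights=None means every id gets weight 1.0; "mean" averages per example.
emb = tf.nn.embedding_lookup_sparse(params, sp_ids, None, combiner="mean")
# emb has shape [2, 10]: one combined embedding per example.
```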
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index e8880d59a8b..72adf9e4986 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -553,6 +553,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
                             sampled_values=None,
                             subtract_log_q=True,
                             remove_accidental_hits=False,
+                            partition_strategy="mod",
                             name=None):
   """Helper function for nce_loss and sampled_softmax_loss functions.

@@ -567,7 +568,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
   Args:
     weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
       objects whose concatenation along dimension 0 has shape
-      `[num_classes, dim]`. The (possibly-sharded) class embeddings.
+      `[num_classes, dim]`. The (possibly-partitioned) class embeddings.
     biases: A `Tensor` of shape `[num_classes]`. The class biases.
     inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
       activations of the input network.
@@ -586,6 +587,9 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
     remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
       where a sampled class equals one of the target classes. Default is
       False.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
     name: A name for the operation (optional).
   Returns:
     out_logits, out_labels: `Tensor` objects each with shape
@@ -624,7 +628,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
     all_ids = array_ops.concat(0, [labels_flat, sampled])

     # weights shape is [num_classes, dim]
-    all_w = embedding_ops.embedding_lookup(weights, all_ids)
+    all_w = embedding_ops.embedding_lookup(
+        weights, all_ids, partition_strategy=partition_strategy)
     all_b = embedding_ops.embedding_lookup(biases, all_ids)
     # true_w shape is [batch_size * num_true, dim]
     # true_b is a [batch_size * num_true] tensor
@@ -704,6 +709,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
              num_true=1,
              sampled_values=None,
              remove_accidental_hits=False,
+             partition_strategy="mod",
              name="nce_loss"):
   """Computes and returns the noise-contrastive estimation training loss.

@@ -726,7 +732,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
   Args:
     weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
       objects whose concatenation along dimension 0 has shape
-      [num_classes, dim]. The (possibly-sharded) class embeddings.
+      [num_classes, dim]. The (possibly-partitioned) class embeddings.
     biases: A `Tensor` of shape `[num_classes]`. The class biases.
     inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
       activations of the input network.
@@ -745,6 +751,9 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
       our [Candidate Sampling Algorithms Reference]
       (../../extras/candidate_sampling.pdf).
       Default is False.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
     name: A name for the operation (optional).

   Returns:
@@ -756,6 +765,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
       sampled_values=sampled_values,
       subtract_log_q=True,
       remove_accidental_hits=remove_accidental_hits,
+      partition_strategy=partition_strategy,
       name=name)
   sampled_losses = sigmoid_cross_entropy_with_logits(logits,
                                                      labels,
@@ -769,6 +779,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
                          num_classes, num_true=1,
                          sampled_values=None,
                          remove_accidental_hits=True,
+                         partition_strategy="mod",
                          name="sampled_softmax_loss"):
   """Computes and returns the sampled softmax training loss.

@@ -805,6 +816,9 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
     remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
       where a sampled class equals one of the target classes. Default is
       True.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
     name: A name for the operation (optional).

   Returns:
@@ -817,6 +831,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
       sampled_values=sampled_values,
       subtract_log_q=True,
       remove_accidental_hits=remove_accidental_hits,
+      partition_strategy=partition_strategy,
       name=name)
   sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
   # sampled_losses is a [batch_size] tensor.
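Each sampling loss above now threads `partition_strategy` down to the embedding lookup on the sharded class weights. A sketch of `nce_loss` with partitioned weights — not part of the patch, sizes and names are illustrative, 0.6-era API assumed:

```python
import tensorflow as tf

num_classes, dim, num_sampled = 13, 10, 4
# Sharded class embeddings: concatenation along dim 0 is [num_classes, dim].
weights = [tf.Variable(tf.random_uniform([rows, dim]))
           for rows in (3, 3, 3, 2, 2)]
biases = tf.Variable(tf.zeros([num_classes]))
inputs = tf.placeholder(tf.float32, shape=[8, dim])
labels = tf.placeholder(tf.int64, shape=[8, 1])

# The strategy must match how the shards were laid out ("div" here).
loss = tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled,
                      num_classes, partition_strategy="div")
```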
Currently `"div"` and `"mod"` are supported. + Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. name: A name for the operation (optional). Returns: @@ -756,6 +765,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, sampled_values=sampled_values, subtract_log_q=True, remove_accidental_hits=remove_accidental_hits, + partition_strategy=partition_strategy, name=name) sampled_losses = sigmoid_cross_entropy_with_logits(logits, labels, @@ -769,6 +779,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, + partition_strategy="mod", name="sampled_softmax_loss"): """Computes and returns the sampled softmax training loss. @@ -805,6 +816,9 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, remove_accidental_hits: A `bool`. whether to remove "accidental hits" where a sampled class equals one of the target classes. Default is True. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. + Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. name: A name for the operation (optional). Returns: @@ -817,6 +831,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, sampled_values=sampled_values, subtract_log_q=True, remove_accidental_hits=remove_accidental_hits, + partition_strategy=partition_strategy, name=name) sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels) # sampled_losses is a [batch_size] tensor. diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 1b9b6401862..6cd60310441 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -20,12 +20,12 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables +from tensorflow.python.training import slot_creator # TODO(touts): switch to variables.Variable. @@ -209,22 +209,19 @@ class ExponentialMovingAverage(object): raise TypeError("The variables must be float or double: %s" % var) if var in self._averages: raise ValueError("Moving average already computed for: %s" % var) - with ops.name_scope(var.op.name + "/" + self._name) as scope: - # For variables: to lower communication bandwidth across devices we keep - # the moving averages on the same device as the variables. For other - # tensors, we rely on the existing device allocation mechanism. - if isinstance(var, variables.Variable): - with ops.device(var.device): - avg = variables.Variable(var.initialized_value(), - name=scope, trainable=False) - elif var.op.type == "Variable": - with ops.device(var.device): - avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()), - name=scope, trainable=False) - else: - avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()), - name=scope, trainable=False) - self._averages[var] = avg + + # For variables: to lower communication bandwidth across devices we keep + # the moving averages on the same device as the variables. 
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index efe2fa13c26..0c9c0fda29f 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -22,11 +22,11 @@ from __future__ import print_function

 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.training import slot_creator

 class Optimizer(object):
@@ -418,6 +418,22 @@ class Optimizer(object):
   # Utility methods for subclasses.
   # --------------

+  def _slot_dict(self, slot_name):
+    """Returns a dict for caching slots created under the given name.
+
+    Args:
+      slot_name: Name for the slot.
+
+    Returns:
+      A dict that maps primary `Variable` objects to the slot created
+      for that variable, under the given slot name.
+    """
+    named_slots = self._slots.get(slot_name, None)
+    if named_slots is None:
+      named_slots = {}
+      self._slots[slot_name] = named_slots
+    return named_slots
+
   def _get_or_make_slot(self, var, val, slot_name, op_name):
     """Find or create a slot for a variable.

@@ -431,19 +447,10 @@ class Optimizer(object):
     Returns:
       A `Variable` object.
     """
-    named_slots = self._slots.get(slot_name, None)
-    if named_slots is None:
-      named_slots = {}
-      self._slots[slot_name] = named_slots
-    slot = named_slots.get(var, None)
-    if slot is None:
-      # Scope the slot name in the namespace of the Variable and
-      # create the slot on the same device as the variable.
-      with ops.name_scope(var.op.name + "/" + op_name) as scope:
-        with ops.device(var.device):
-          slot = variables.Variable(val, name=scope, trainable=False)
-      named_slots[var] = slot
-    return slot
+    named_slots = self._slot_dict(slot_name)
+    if var not in named_slots:
+      named_slots[var] = slot_creator.create_slot(var, val, op_name)
+    return named_slots[var]

   def _zeros_slot(self, var, slot_name, op_name):
     """Find or create a slot initialized with 0.0.

@@ -457,5 +464,7 @@ class Optimizer(object):
     Returns:
       A `Variable` object.
     """
-    val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype)
-    return self._get_or_make_slot(var, val, slot_name, op_name)
+    named_slots = self._slot_dict(slot_name)
+    if var not in named_slots:
+      named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
+    return named_slots[var]
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
new file mode 100644
index 00000000000..dcd02f2fafb
--- /dev/null
+++ b/tensorflow/python/training/slot_creator.py
@@ -0,0 +1,108 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Standard functions for creating slots.
+
+A slot is a `Variable` created with the same shape as a primary variable or
+`Tensor`. A slot is always scoped in the namespace of the primary object and
+typically has the same device and type.
+
+Slots are typically used as accumulators to track values associated with
+the primary object:
+
+```python
+# Optimizers can create a slot for each variable to track accumulators
+accumulators = {var: create_zeros_slot(var, "momentum") for var in vs}
+for var in vs:
+  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
+
+# Slots can also be used for moving averages
+mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
+update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
+```
+"""
+# pylint: disable=g-bad-name
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+
+
+def _create_slot_var(primary, val, scope):
+  """Helper function for creating a slot variable."""
+
+  slot = variables.Variable(val, name=scope, trainable=False)
+  # pylint: disable=protected-access
+  if isinstance(primary, variables.Variable) and primary._save_slice_info:
+    # Primary is a partitioned variable, so we need to also indicate that
+    # the slot is a partitioned variable. Slots have the same partitioning
+    # as their primaries.
+    real_slot_name = scope[len(primary.op.name + "/"):-1]
+    slice_info = primary._save_slice_info
+    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
+        slice_info.full_name + "/" + real_slot_name,
+        slice_info.full_shape[:],
+        slice_info.var_offset[:],
+        slice_info.var_shape[:]))
+  # pylint: enable=protected-access
+  return slot
+
+
+def create_slot(primary, val, name, colocate_with_primary=True):
+  """Create a slot initialized to the given value.
+
+  The type of the slot is determined by the given value.
+
+  Args:
+    primary: The primary `Variable` or `Tensor`.
+    val: A `Tensor` specifying the initial value of the slot.
+    name: Name to use for the slot variable.
+    colocate_with_primary: Boolean. If True the slot is located
+      on the same device as `primary`.
+
+  Returns:
+    A `Variable` object.
+  """
+  # Scope the slot name in the namespace of the primary variable.
+  with ops.name_scope(primary.op.name + "/" + name) as scope:
+    if colocate_with_primary:
+      with ops.device(primary.device):
+        return _create_slot_var(primary, val, scope)
+    else:
+      return _create_slot_var(primary, val, scope)
+
+
+def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
+  """Create a slot initialized to 0 with same shape as the primary object.
+
+  Args:
+    primary: The primary `Variable` or `Tensor`.
+    name: Name to use for the slot variable.
+    dtype: Type of the slot variable. Defaults to the type of `primary`.
+    colocate_with_primary: Boolean. If True the slot is located
+      on the same device as `primary`.
+
+  Returns:
+    A `Variable` object.
+  """
+  if dtype is None:
+    dtype = primary.dtype
+  val = array_ops.zeros(primary.get_shape().as_list(), dtype=dtype)
+  return create_slot(primary, val, name,
+                     colocate_with_primary=colocate_with_primary)
diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py
new file mode 100644
index 00000000000..d0c57c40ed4
--- /dev/null
+++ b/tensorflow/python/training/slot_creator_test.py
@@ -0,0 +1,78 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for slot_creator."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+from tensorflow.python.training import slot_creator
+
+
+class SlotCreatorTest(tf.test.TestCase):
+
+  def testCreateSlotFromVariable(self):
+    with self.test_session():
+      v = tf.Variable([1.0, 2.5], name="var")
+      slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
+
+      tf.initialize_all_variables().run()
+
+      self.assertEqual(slot.op.name, "var/slot")
+      self.assertEqual(slot.get_shape().as_list(), [2])
+      self.assertEqual(slot.dtype.base_dtype, tf.float32)
+      self.assertAllEqual(slot.eval(), [1.0, 2.5])
+
+  def testCreateSlotFromTensor(self):
+    with self.test_session():
+      v = tf.constant([1.0, 2.5], name="const")
+      slot = slot_creator.create_slot(v, v * 2, name="slot")
+
+      tf.initialize_all_variables().run()
+
+      self.assertEqual(slot.op.name, "const/slot")
+      self.assertEqual(slot.get_shape().as_list(), [2])
+      self.assertEqual(slot.dtype.base_dtype, tf.float32)
+      self.assertAllEqual(slot.eval(), [2.0, 5.0])
+
+  def testCreateZerosSlotFromVariable(self):
+    with self.test_session():
+      v = tf.Variable([1.0, 2.5], name="var")
+      slot = slot_creator.create_zeros_slot(v, name="slot", dtype=tf.float64)
+
+      tf.initialize_all_variables().run()
+
+      self.assertEqual(slot.op.name, "var/slot")
+      self.assertEqual(slot.get_shape().as_list(), [2])
+      self.assertEqual(slot.dtype.base_dtype, tf.float64)
+      self.assertAllEqual(slot.eval(), [0.0, 0.0])
+
+  def testCreateZerosSlotFromTensor(self):
+    with self.test_session():
+      v = tf.constant([1.0, 2.5], name="const")
+
+      slot = slot_creator.create_zeros_slot(v, name="slot")
+
+      tf.initialize_all_variables().run()
+
+      self.assertEqual(slot.op.name, "const/slot")
+      self.assertEqual(slot.get_shape().as_list(), [2])
+      self.assertEqual(slot.dtype.base_dtype, tf.float32)
+      self.assertAllEqual(slot.eval(), [0.0, 0.0])
+
+if __name__ == "__main__":
+  tf.test.main()
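Slots created through `slot_creator` are what `Optimizer.get_slot()` hands back to user code; the naming the tests assert above (`var/slot`) is the same `primary_name/slot_name` scoping an optimizer produces. A sketch, assuming the 0.6-era training API:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.5], name="var")
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
train_op = opt.minimize(tf.reduce_sum(v * v), var_list=[v])

m = opt.get_slot(v, "momentum")  # the accumulator slot created for v
print(m.op.name)                 # "var/Momentum": scoped under the primary
```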