TensorFlow: Upstream changes to git.
Change 109730179 Add support for selecting partition strategy in tf.nn.embedding_lookup and related ops, and allow unequally-sized shards to be used as input. Change 109729548 TensorFlow: add RELEASE.md notes for 0.6.0. Change 109728185 Make seq2seq_test non-flaky by setting python and numpy random seed. Change 109725913 Refactor slot creation in optimizers and moving averages to separate file. Change 109718024 TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s. Change 109712251 More performance improvement for convnet on GPU. + Switch forward convolution format to NCHW. + Allocate scratch space for forward- and backward- convolutions. + Users can use "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the scratch space limit. The default limit is 1GB. Change 109710898 Added extract_sub_graph utility function. Base CL: 109731609
This commit is contained in:
parent
ddd4aaf528
commit
2c3738db9c
RELEASE.md
tensorflow
27
RELEASE.md
27
RELEASE.md
@ -1,3 +1,30 @@
|
|||||||
|
# Release 0.6.0
|
||||||
|
|
||||||
|
## Major Features and Improvements
|
||||||
|
|
||||||
|
* Python 3.3+ support via changes to python codebase and ability
|
||||||
|
to specify python version via ./configure.
|
||||||
|
|
||||||
|
* Some improvements to GPU performance and memory usage:
|
||||||
|
[convnet benchmarks](https://github.com/soumith/convnet-benchmarks/issues/66)
|
||||||
|
roughly equivalent with native cudnn v2 performance. Improvements mostly due
|
||||||
|
to moving to 32-bit indices, faster shuffling kernels. More improvements to
|
||||||
|
come in later releases.
|
||||||
|
|
||||||
|
|
||||||
|
## Bug fixes
|
||||||
|
|
||||||
|
* Lots of fixes to documentation and tutorials, many contributed
|
||||||
|
by the public.
|
||||||
|
|
||||||
|
* 271 closed issues on GitHub.
|
||||||
|
|
||||||
|
## Backwards-incompatible changes
|
||||||
|
|
||||||
|
* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion'
|
||||||
|
attribute from 0.0 to 1.0. This was a bug in the original release
|
||||||
|
that is now fixed.
|
||||||
|
|
||||||
# Release 0.5.0
|
# Release 0.5.0
|
||||||
|
|
||||||
Initial release of TensorFlow.
|
Initial release of TensorFlow.
|
||||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/stream_executor/stream.h"
|
#include "tensorflow/stream_executor/stream.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -756,17 +757,6 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
|
|||||||
|
|
||||||
// GPU definitions of both ops.
|
// GPU definitions of both ops.
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
namespace {
|
|
||||||
template <typename T>
|
|
||||||
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
|
||||||
uint64 size) {
|
|
||||||
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
|
|
||||||
size * sizeof(T));
|
|
||||||
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
|
||||||
return typed;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// The slow version (but compiles for GPU)
|
// The slow version (but compiles for GPU)
|
||||||
|
|
||||||
// Backprop for input.
|
// Backprop for input.
|
||||||
@ -929,10 +919,15 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
|||||||
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
|
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
|
||||||
pre_transformed_in_backprop.template flat<T>().size());
|
pre_transformed_in_backprop.template flat<T>().size());
|
||||||
|
|
||||||
|
static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
|
||||||
|
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
|
||||||
|
);
|
||||||
|
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
|
||||||
|
context);
|
||||||
bool cudnn_launch_status =
|
bool cudnn_launch_status =
|
||||||
stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
|
stream->ThenConvolveBackwardDataWithScratch(
|
||||||
out_backprop_ptr, conv_desc,
|
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
|
||||||
input_desc, &in_backprop_ptr)
|
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if (!cudnn_launch_status) {
|
if (!cudnn_launch_status) {
|
||||||
@ -1185,7 +1180,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
|||||||
context->eigen_device<Device>(),
|
context->eigen_device<Device>(),
|
||||||
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
|
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
|
||||||
transformed_input.tensor<T, 4>());
|
transformed_input.tensor<T, 4>());
|
||||||
|
|
||||||
auto out_backprop_ptr =
|
auto out_backprop_ptr =
|
||||||
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
|
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
|
||||||
transformed_out_backprop.template flat<T>().size());
|
transformed_out_backprop.template flat<T>().size());
|
||||||
@ -1196,10 +1190,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
|||||||
AsDeviceMemory(transformed_input.template flat<T>().data(),
|
AsDeviceMemory(transformed_input.template flat<T>().data(),
|
||||||
transformed_input.template flat<T>().size());
|
transformed_input.template flat<T>().size());
|
||||||
|
|
||||||
|
static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
|
||||||
|
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
|
||||||
|
);
|
||||||
|
CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
|
||||||
|
context);
|
||||||
bool cudnn_launch_status =
|
bool cudnn_launch_status =
|
||||||
stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
|
stream->ThenConvolveBackwardFilterWithScratch(
|
||||||
out_backprop_ptr, conv_desc,
|
input_desc, input_ptr, output_desc, out_backprop_ptr,
|
||||||
filter_desc, &filter_backprop_ptr)
|
conv_desc, filter_desc, &filter_backprop_ptr,
|
||||||
|
&scratch_allocator)
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if (!cudnn_launch_status) {
|
if (!cudnn_launch_status) {
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/public/tensor.h"
|
#include "tensorflow/core/public/tensor.h"
|
||||||
#include "tensorflow/core/public/tensor_shape.h"
|
#include "tensorflow/core/public/tensor_shape.h"
|
||||||
@ -34,6 +35,7 @@ limitations under the License.
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/stream_executor/stream.h"
|
#include "tensorflow/stream_executor/stream.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
namespace {
|
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
|
||||||
template <typename T>
|
int64 default_value_in_bytes) {
|
||||||
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
|
||||||
uint64 size) {
|
if (workspace_limit_in_mb_str != nullptr &&
|
||||||
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
|
strcmp(workspace_limit_in_mb_str, "") != 0) {
|
||||||
size * sizeof(T));
|
int64 scratch_limit_in_mb = -1;
|
||||||
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
if (strings::safe_strto64(workspace_limit_in_mb_str,
|
||||||
return typed;
|
&scratch_limit_in_mb)) {
|
||||||
|
return scratch_limit_in_mb * (1 << 20);
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
|
||||||
|
<< workspace_limit_in_mb_str;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return default_value_in_bytes;
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LaunchConvOp<GPUDevice, T> {
|
struct LaunchConvOp<GPUDevice, T> {
|
||||||
@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
|
|||||||
input = transformed_input;
|
input = transformed_input;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Convert the input tensor from NHWC to NCHW.
|
||||||
|
Tensor transformed_input;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->allocate_temp(
|
||||||
|
DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({input.dim_size(0), input.dim_size(3),
|
||||||
|
input.dim_size(1), input.dim_size(2)}),
|
||||||
|
&transformed_input));
|
||||||
|
functor::NHWCToNCHW<GPUDevice, T>()(
|
||||||
|
ctx->eigen_device<GPUDevice>(),
|
||||||
|
const_cast<const Tensor&>(input).tensor<T, 4>(),
|
||||||
|
transformed_input.tensor<T, 4>());
|
||||||
|
input = transformed_input;
|
||||||
|
}
|
||||||
|
|
||||||
perftools::gputools::dnn::BatchDescriptor input_desc;
|
perftools::gputools::dnn::BatchDescriptor input_desc;
|
||||||
input_desc.set_count(input.dim_size(0))
|
input_desc.set_count(input.dim_size(0))
|
||||||
.set_height(input.dim_size(1))
|
.set_feature_map_count(input.dim_size(1))
|
||||||
.set_width(input.dim_size(2))
|
.set_height(input.dim_size(2))
|
||||||
.set_feature_map_count(input.dim_size(3))
|
.set_width(input.dim_size(3))
|
||||||
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
|
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
|
||||||
perftools::gputools::dnn::BatchDescriptor output_desc;
|
perftools::gputools::dnn::BatchDescriptor output_desc;
|
||||||
output_desc.set_count(output->dim_size(0))
|
output_desc.set_count(output->dim_size(0))
|
||||||
.set_height(output->dim_size(1))
|
.set_height(output->dim_size(1))
|
||||||
.set_width(output->dim_size(2))
|
.set_width(output->dim_size(2))
|
||||||
.set_feature_map_count(output->dim_size(3))
|
.set_feature_map_count(output->dim_size(3))
|
||||||
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
|
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
|
||||||
perftools::gputools::dnn::FilterDescriptor filter_desc;
|
perftools::gputools::dnn::FilterDescriptor filter_desc;
|
||||||
filter_desc.set_input_filter_height(filter.dim_size(0))
|
filter_desc.set_input_filter_height(filter.dim_size(0))
|
||||||
.set_input_filter_width(filter.dim_size(1))
|
.set_input_filter_width(filter.dim_size(1))
|
||||||
@ -320,17 +344,31 @@ struct LaunchConvOp<GPUDevice, T> {
|
|||||||
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
|
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
|
||||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||||
|
|
||||||
|
Tensor transformed_output;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_temp(
|
||||||
|
DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({output->dim_size(0), output->dim_size(3),
|
||||||
|
output->dim_size(1), output->dim_size(2)}),
|
||||||
|
&transformed_output));
|
||||||
|
|
||||||
auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
|
auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
|
||||||
input.template flat<T>().size());
|
input.template flat<T>().size());
|
||||||
auto filter_ptr =
|
auto filter_ptr =
|
||||||
AsDeviceMemory(transformed_filter.template flat<T>().data(),
|
AsDeviceMemory(transformed_filter.template flat<T>().data(),
|
||||||
transformed_filter.template flat<T>().size());
|
transformed_filter.template flat<T>().size());
|
||||||
auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
|
auto output_ptr =
|
||||||
output->template flat<T>().size());
|
AsDeviceMemory(transformed_output.template flat<T>().data(),
|
||||||
|
transformed_output.template flat<T>().size());
|
||||||
|
|
||||||
|
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
|
||||||
|
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
|
||||||
|
);
|
||||||
|
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||||
bool cudnn_launch_status =
|
bool cudnn_launch_status =
|
||||||
stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
|
stream->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
|
||||||
conv_desc, output_desc, &output_ptr)
|
filter_ptr, conv_desc, output_desc,
|
||||||
|
&output_ptr, &scratch_allocator)
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if (!cudnn_launch_status) {
|
if (!cudnn_launch_status) {
|
||||||
@ -338,6 +376,12 @@ struct LaunchConvOp<GPUDevice, T> {
|
|||||||
"cuDNN launch failure : input shape(", input.shape().DebugString(),
|
"cuDNN launch failure : input shape(", input.shape().DebugString(),
|
||||||
") filter shape(", filter.shape().DebugString(), ")"));
|
") filter shape(", filter.shape().DebugString(), ")"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert the output tensor back from NCHW to NHWC.
|
||||||
|
functor::NCHWToNHWC<GPUDevice, T>()(
|
||||||
|
ctx->eigen_device<GPUDevice>(),
|
||||||
|
const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
|
||||||
|
output->tensor<T, 4>());
|
||||||
} else {
|
} else {
|
||||||
LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
|
LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
|
||||||
padding, output);
|
padding, output);
|
||||||
|
84
tensorflow/core/kernels/conv_ops_gpu.h
Normal file
84
tensorflow/core/kernels/conv_ops_gpu.h
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/scratch_allocator.h"
|
||||||
|
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
|
||||||
|
// spread.
|
||||||
|
template <typename T>
|
||||||
|
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
||||||
|
uint64 size) {
|
||||||
|
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
|
||||||
|
size * sizeof(T));
|
||||||
|
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
||||||
|
return typed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the Cudnn workspace limit from the environment variable, which is in MB.
|
||||||
|
// Return the workspace memory limit in bytes. If no value is set, return the
|
||||||
|
// default value.
|
||||||
|
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
|
||||||
|
int64 default_value_in_bytes);
|
||||||
|
|
||||||
|
// A class to provide scratch-space allocator for Stream-Executor Cudnn
|
||||||
|
// callback. TensorFlow is responsible for releasing the temporary buffers after
|
||||||
|
// the kernel finishes.
|
||||||
|
class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
|
||||||
|
public:
|
||||||
|
virtual ~CudnnScratchAllocator() {}
|
||||||
|
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
|
||||||
|
: memory_limit_(memory_limit), context_(context) {}
|
||||||
|
virtual int64 GetMemoryLimitInBytes(
|
||||||
|
perftools::gputools::Stream* stream) override {
|
||||||
|
return memory_limit_;
|
||||||
|
}
|
||||||
|
virtual perftools::gputools::port::StatusOr<
|
||||||
|
perftools::gputools::DeviceMemory<uint8>>
|
||||||
|
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
|
||||||
|
Tensor temporary_memory;
|
||||||
|
|
||||||
|
Status allocation_status(context_->allocate_temp(
|
||||||
|
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
|
||||||
|
if (!allocation_status.ok()) {
|
||||||
|
LOG(WARNING) << allocation_status;
|
||||||
|
context_->SetStatus(allocation_status);
|
||||||
|
return perftools::gputools::port::StatusOr<
|
||||||
|
perftools::gputools::DeviceMemory<uint8>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return perftools::gputools::port::StatusOr<
|
||||||
|
perftools::gputools::DeviceMemory<uint8>>(
|
||||||
|
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
|
||||||
|
temporary_memory.flat<uint8>().size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 memory_limit_;
|
||||||
|
OpKernelContext* context_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
|
@ -19,6 +19,7 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
import copy
|
||||||
|
|
||||||
import tensorflow.python.platform
|
import tensorflow.python.platform
|
||||||
|
|
||||||
@ -155,3 +156,64 @@ def pin_to_cpu(op):
|
|||||||
logging.info("Operation %s has been assigned to a non-CPU (%s), so "
|
logging.info("Operation %s has been assigned to a non-CPU (%s), so "
|
||||||
"it will not be pinned to the CPU.", op.name, dev.device_type)
|
"it will not be pinned to the CPU.", op.name, dev.device_type)
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def _node_name(n):
|
||||||
|
if n.startswith("^"):
|
||||||
|
return n[1:]
|
||||||
|
else:
|
||||||
|
return n.split(":")[0]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_sub_graph(graph_def, dest_nodes):
|
||||||
|
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_def: A graph_pb2.GraphDef proto.
|
||||||
|
dest_nodes: A list of strings specifying the destination node names.
|
||||||
|
Returns:
|
||||||
|
The GraphDef of the sub-graph.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||||
|
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
|
||||||
|
|
||||||
|
edges = {} # Keyed by the dest node name.
|
||||||
|
name_to_node_map = {} # Keyed by node name.
|
||||||
|
|
||||||
|
# Keeps track of node sequences. It is important to still output the
|
||||||
|
# operations in the original order.
|
||||||
|
node_seq = {} # Keyed by node name.
|
||||||
|
seq = 0
|
||||||
|
for node in graph_def.node:
|
||||||
|
n = _node_name(node.name)
|
||||||
|
name_to_node_map[n] = node
|
||||||
|
edges[n] = [_node_name(x) for x in node.input]
|
||||||
|
node_seq[n] = seq
|
||||||
|
seq += 1
|
||||||
|
|
||||||
|
for d in dest_nodes:
|
||||||
|
assert d in name_to_node_map, "%d is not in graph" % d
|
||||||
|
|
||||||
|
nodes_to_keep = set()
|
||||||
|
# Breadth first search to find all the nodes that we should keep.
|
||||||
|
next_to_visit = dest_nodes[:]
|
||||||
|
while next_to_visit:
|
||||||
|
n = next_to_visit[0]
|
||||||
|
del next_to_visit[0]
|
||||||
|
if n in nodes_to_keep:
|
||||||
|
# Already visited this node.
|
||||||
|
continue
|
||||||
|
nodes_to_keep.add(n)
|
||||||
|
next_to_visit += edges[n]
|
||||||
|
|
||||||
|
nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
|
||||||
|
# Now construct the output GraphDef
|
||||||
|
out = graph_pb2.GraphDef()
|
||||||
|
for n in nodes_to_keep_list:
|
||||||
|
out.node.extend([copy.deepcopy(name_to_node_map[n])])
|
||||||
|
|
||||||
|
return out
|
||||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import tensorflow.python.platform
|
import tensorflow.python.platform
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.python.client import graph_util
|
from tensorflow.python.client import graph_util
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -140,6 +142,34 @@ class DeviceFunctionsTest(googletest.TestCase):
|
|||||||
self.assertEqual(const_4.device, "/device:CPU:1")
|
self.assertEqual(const_4.device, "/device:CPU:1")
|
||||||
self.assertEqual(const_5.device, "/replica:0")
|
self.assertEqual(const_5.device, "/replica:0")
|
||||||
|
|
||||||
|
def testExtractSubGraph(self):
|
||||||
|
graph_def = tf.GraphDef()
|
||||||
|
n1 = graph_def.node.add()
|
||||||
|
n1.name = "n1"
|
||||||
|
n1.input.extend(["n5"])
|
||||||
|
n2 = graph_def.node.add()
|
||||||
|
n2.name = "n2"
|
||||||
|
# Take the first output of the n1 node as the input.
|
||||||
|
n2.input.extend(["n1:0"])
|
||||||
|
n3 = graph_def.node.add()
|
||||||
|
n3.name = "n3"
|
||||||
|
# Add a control input (which isn't really needed by the kernel, but
|
||||||
|
# rather to enforce execution order between nodes).
|
||||||
|
n3.input.extend(["^n2"])
|
||||||
|
n4 = graph_def.node.add()
|
||||||
|
n4.name = "n4"
|
||||||
|
|
||||||
|
# It is fine to have loops in the graph as well.
|
||||||
|
n5 = graph_def.node.add()
|
||||||
|
n5.name = "n5"
|
||||||
|
n5.input.extend(["n1"])
|
||||||
|
|
||||||
|
sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
|
||||||
|
self.assertEqual("n1", sub_graph.node[0].name)
|
||||||
|
self.assertEqual("n2", sub_graph.node[1].name)
|
||||||
|
self.assertEqual("n3", sub_graph.node[2].name)
|
||||||
|
self.assertEqual("n5", sub_graph.node[3].name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -115,26 +115,34 @@ def _PName(param_id):
|
|||||||
|
|
||||||
def _EmbeddingParams(num_shards, vocab_size,
|
def _EmbeddingParams(num_shards, vocab_size,
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
shape=None):
|
shape=None,
|
||||||
|
use_shapeless_placeholder=False):
|
||||||
p = []
|
p = []
|
||||||
params = {}
|
params = {}
|
||||||
feed_dict = {}
|
feed_dict = {}
|
||||||
if not shape: shape = [10]
|
if not shape: shape = [10]
|
||||||
assert not vocab_size % num_shards
|
|
||||||
shape = [vocab_size // num_shards] + shape
|
|
||||||
for i in range(num_shards):
|
for i in range(num_shards):
|
||||||
|
shard_shape = [vocab_size // num_shards] + shape
|
||||||
|
if i < vocab_size % num_shards: # Excess goes evenly on the first shards
|
||||||
|
shard_shape[0] += 1
|
||||||
|
|
||||||
param_name = _PName(i)
|
param_name = _PName(i)
|
||||||
constant_t = tf.constant(1.0, shape=shape, dtype=dtype,
|
|
||||||
name=param_name)
|
if use_shapeless_placeholder:
|
||||||
p.append(constant_t)
|
param = tf.placeholder(dtype, shape=None, name=param_name)
|
||||||
|
else:
|
||||||
|
param = tf.constant(1.0, shape=shard_shape, dtype=dtype, name=param_name)
|
||||||
|
p.append(param)
|
||||||
np_type = "f" if dtype == tf.float32 else "d"
|
np_type = "f" if dtype == tf.float32 else "d"
|
||||||
val = (np.random.rand(*shape).astype(np_type)) + 1
|
val = (np.random.rand(*shard_shape).astype(np_type)) + 1
|
||||||
params[param_name + ":0"] = val
|
params[param_name + ":0"] = val
|
||||||
feed_dict[constant_t.name] = val
|
feed_dict[param.name] = val
|
||||||
return p, params, feed_dict
|
return p, params, feed_dict
|
||||||
|
|
||||||
|
|
||||||
def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
|
def _EmbeddingResult(params, id_vals, num_shards, vocab_size,
|
||||||
|
partition_strategy="mod",
|
||||||
|
weight_vals=None):
|
||||||
if weight_vals is None:
|
if weight_vals is None:
|
||||||
weight_vals = np.copy(id_vals)
|
weight_vals = np.copy(id_vals)
|
||||||
weight_vals.fill(1)
|
weight_vals.fill(1)
|
||||||
@ -147,8 +155,22 @@ def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
|
|||||||
ids = [ids]
|
ids = [ids]
|
||||||
wts = [wts]
|
wts = [wts]
|
||||||
for i, wt_val in zip(ids, wts):
|
for i, wt_val in zip(ids, wts):
|
||||||
val = np.copy(params[_PName(i % num_shards) + ":0"][
|
if partition_strategy == "mod":
|
||||||
i // num_shards, :]) * wt_val
|
val = np.copy(params[_PName(i % num_shards) + ":0"][
|
||||||
|
i // num_shards, :]) * wt_val
|
||||||
|
elif partition_strategy == "div":
|
||||||
|
ids_per_partition, extras = divmod(vocab_size, num_shards)
|
||||||
|
threshold = extras * (ids_per_partition + 1)
|
||||||
|
if i < threshold:
|
||||||
|
partition = i // (ids_per_partition + 1)
|
||||||
|
offset = i % (ids_per_partition + 1)
|
||||||
|
else:
|
||||||
|
partition = extras + (i - threshold) // ids_per_partition
|
||||||
|
offset = (i - threshold) % ids_per_partition
|
||||||
|
val = np.copy(
|
||||||
|
params[_PName(partition) + ":0"][offset, :]) * wt_val
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
if val_aggr is None:
|
if val_aggr is None:
|
||||||
assert wt_aggr is None
|
assert wt_aggr is None
|
||||||
val_aggr = val
|
val_aggr = val
|
||||||
@ -182,17 +204,17 @@ class EmbeddingLookupTest(tf.test.TestCase):
|
|||||||
embedding = tf.nn.embedding_lookup(p, ids)
|
embedding = tf.nn.embedding_lookup(p, ids)
|
||||||
|
|
||||||
tf_result = embedding.eval(feed_dict=feed_dict)
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
|
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
|
||||||
self.assertAllEqual(np_result, tf_result)
|
self.assertAllEqual(np_result, tf_result)
|
||||||
self.assertShapeEqual(np_result, embedding)
|
self.assertShapeEqual(np_result, embedding)
|
||||||
|
|
||||||
def testSharded(self):
|
def testShardedModPartitioningInt32Ids(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
num_shards = 5
|
num_shards = 5
|
||||||
vocab_size = 25
|
vocab_size = 13
|
||||||
# Embedding dimensions is 10. The 10 x vocab_size embedding
|
# Embedding dimensions is 10. The vocab_size x 10 embedding
|
||||||
# parameters are spread in num_shards matrices, so each
|
# parameters are spread in num_shards matrices, so the first
|
||||||
# matrix is 10 x (vocab_size / num_shards)
|
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
|
||||||
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
|
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
|
||||||
|
|
||||||
num_vals = 30
|
num_vals = 30
|
||||||
@ -204,10 +226,103 @@ class EmbeddingLookupTest(tf.test.TestCase):
|
|||||||
|
|
||||||
embedding = tf.nn.embedding_lookup(p, ids)
|
embedding = tf.nn.embedding_lookup(p, ids)
|
||||||
tf_result = embedding.eval(feed_dict=feed_dict)
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
|
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
|
||||||
self.assertAllEqual(np_result, tf_result)
|
self.assertAllEqual(np_result, tf_result)
|
||||||
self.assertShapeEqual(np_result, embedding)
|
self.assertShapeEqual(np_result, embedding)
|
||||||
|
|
||||||
|
def testShardedModPartitioningInt64Ids(self):
|
||||||
|
with self.test_session():
|
||||||
|
num_shards = 5
|
||||||
|
vocab_size = 13
|
||||||
|
# Embedding dimensions is 10. The vocab_size x 10 embedding
|
||||||
|
# parameters are spread in num_shards matrices, so the first
|
||||||
|
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
|
||||||
|
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
|
||||||
|
|
||||||
|
num_vals = 30
|
||||||
|
# Fetch num_vals embeddings for random word ids. Since
|
||||||
|
# num_vals > vocab_size, this ought to have repetitions, so
|
||||||
|
# will test that aspect.
|
||||||
|
id_vals = np.random.randint(vocab_size, size=num_vals)
|
||||||
|
ids = tf.constant(list(id_vals), dtype=tf.int64)
|
||||||
|
|
||||||
|
embedding = tf.nn.embedding_lookup(p, ids)
|
||||||
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
|
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
|
||||||
|
self.assertAllEqual(np_result, tf_result)
|
||||||
|
self.assertShapeEqual(np_result, embedding)
|
||||||
|
|
||||||
|
def testShardedDivPartitioningInt32Ids(self):
|
||||||
|
with self.test_session():
|
||||||
|
num_shards = 5
|
||||||
|
vocab_size = 13
|
||||||
|
# Embedding dimensions is 10. The vocab_size x 10 embedding
|
||||||
|
# parameters are spread in num_shards matrices, so the first
|
||||||
|
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
|
||||||
|
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
|
||||||
|
|
||||||
|
num_vals = 30
|
||||||
|
# Fetch num_vals embeddings for random word ids. Since
|
||||||
|
# num_vals > vocab_size, this ought to have repetitions, so
|
||||||
|
# will test that aspect.
|
||||||
|
id_vals = np.random.randint(vocab_size, size=num_vals)
|
||||||
|
ids = tf.constant(list(id_vals), dtype=tf.int32)
|
||||||
|
|
||||||
|
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
|
||||||
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
|
np_result, _ = _EmbeddingResult(
|
||||||
|
params, id_vals, num_shards, vocab_size, partition_strategy="div")
|
||||||
|
self.assertAllEqual(np_result, tf_result)
|
||||||
|
self.assertShapeEqual(np_result, embedding)
|
||||||
|
|
||||||
|
def testShardedDivPartitioningInt64Ids(self):
|
||||||
|
with self.test_session():
|
||||||
|
num_shards = 5
|
||||||
|
vocab_size = 13
|
||||||
|
# Embedding dimensions is 10. The vocab_size x 10 embedding
|
||||||
|
# parameters are spread in num_shards matrices, so the first
|
||||||
|
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
|
||||||
|
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
|
||||||
|
|
||||||
|
num_vals = 30
|
||||||
|
# Fetch num_vals embeddings for random word ids. Since
|
||||||
|
# num_vals > vocab_size, this ought to have repetitions, so
|
||||||
|
# will test that aspect.
|
||||||
|
id_vals = np.random.randint(vocab_size, size=num_vals)
|
||||||
|
ids = tf.constant(list(id_vals), dtype=tf.int64)
|
||||||
|
|
||||||
|
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
|
||||||
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
|
np_result, _ = _EmbeddingResult(
|
||||||
|
params, id_vals, num_shards, vocab_size, partition_strategy="div")
|
||||||
|
self.assertAllEqual(np_result, tf_result)
|
||||||
|
self.assertShapeEqual(np_result, embedding)
|
||||||
|
|
||||||
|
def testShardedDivPartitioningUnknownParamShape(self):
|
||||||
|
with self.test_session():
|
||||||
|
num_shards = 5
|
||||||
|
vocab_size = 13
|
||||||
|
# Embedding dimension is 10. The vocab_size x 10 embedding
|
||||||
|
# parameters are spread in num_shards matrices, so the first
|
||||||
|
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
|
||||||
|
|
||||||
|
# We clear parameter shapes, to test when shape is not statically known.
|
||||||
|
p, params, feed_dict = _EmbeddingParams(
|
||||||
|
num_shards, vocab_size, use_shapeless_placeholder=True)
|
||||||
|
|
||||||
|
num_vals = 30
|
||||||
|
# Fetch num_vals embeddings for random word ids. Since
|
||||||
|
# num_vals > vocab_size, this ought to have repetitions, so
|
||||||
|
# will test that aspect.
|
||||||
|
id_vals = np.random.randint(vocab_size, size=num_vals)
|
||||||
|
ids = tf.constant(list(id_vals), dtype=tf.int64)
|
||||||
|
|
||||||
|
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
|
||||||
|
tf_result = embedding.eval(feed_dict=feed_dict)
|
||||||
|
np_result, _ = _EmbeddingResult(
|
||||||
|
params, id_vals, num_shards, vocab_size, partition_strategy="div")
|
||||||
|
self.assertAllEqual(np_result, tf_result)
|
||||||
|
|
||||||
def testGradientsEmbeddingLookup(self):
|
def testGradientsEmbeddingLookup(self):
|
||||||
vocab_size = 9
|
vocab_size = 9
|
||||||
num_ids = 5
|
num_ids = 5
|
||||||
@ -326,7 +441,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
|
|||||||
return grouped_vals
|
return grouped_vals
|
||||||
|
|
||||||
def testEmbeddingLookupSparse(self):
|
def testEmbeddingLookupSparse(self):
|
||||||
vocab_size = 25
|
vocab_size = 13
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
param_shape = [2, 5]
|
param_shape = [2, 5]
|
||||||
|
|
||||||
@ -354,7 +469,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
|
|||||||
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
|
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
|
||||||
|
|
||||||
np_embedding_sum, np_weight_sum = _EmbeddingResult(
|
np_embedding_sum, np_weight_sum = _EmbeddingResult(
|
||||||
params, grouped_ids, num_shards,
|
params, grouped_ids, num_shards, vocab_size,
|
||||||
weight_vals=grouped_ignored_weights
|
weight_vals=grouped_ignored_weights
|
||||||
if ignore_weights else grouped_weights)
|
if ignore_weights else grouped_weights)
|
||||||
if combiner == "mean":
|
if combiner == "mean":
|
||||||
|
@ -354,6 +354,10 @@ class Seq2SeqTest(tf.test.TestCase):
|
|||||||
# We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
|
# We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
|
||||||
classes = 10
|
classes = 10
|
||||||
buckets = [(4, 4), (8, 8)]
|
buckets = [(4, 4), (8, 8)]
|
||||||
|
perplexities = [[], []] # Results for each bucket.
|
||||||
|
tf.set_random_seed(111)
|
||||||
|
random.seed(111)
|
||||||
|
np.random.seed(111)
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
# We use sampled softmax so we keep output projection separate.
|
# We use sampled softmax so we keep output projection separate.
|
||||||
@ -378,8 +382,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
|||||||
softmax_loss_function=SampledLoss)
|
softmax_loss_function=SampledLoss)
|
||||||
|
|
||||||
# Now we construct the copy model.
|
# Now we construct the copy model.
|
||||||
tf.set_random_seed(111)
|
batch_size = 8
|
||||||
batch_size = 32
|
|
||||||
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
|
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
|
||||||
out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
|
out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
|
||||||
weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)]
|
weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)]
|
||||||
@ -394,26 +397,26 @@ class Seq2SeqTest(tf.test.TestCase):
|
|||||||
update = optimizer.apply_gradients(zip(grads, params))
|
update = optimizer.apply_gradients(zip(grads, params))
|
||||||
updates.append(update)
|
updates.append(update)
|
||||||
sess.run([tf.initialize_all_variables()])
|
sess.run([tf.initialize_all_variables()])
|
||||||
for ep in xrange(3):
|
steps = 6
|
||||||
log_perp = 0.0
|
for _ in xrange(steps):
|
||||||
for _ in xrange(50):
|
bucket = random.choice(np.arange(len(buckets)))
|
||||||
bucket = random.choice(np.arange(len(buckets)))
|
length = buckets[bucket][0]
|
||||||
length = buckets[bucket][0]
|
i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
|
||||||
i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
|
dtype=np.int32) for _ in xrange(length)]
|
||||||
dtype=np.int32) for _ in xrange(length)]
|
# 0 is our "GO" symbol here.
|
||||||
# 0 is our "GO" symbol here.
|
o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
|
||||||
o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
|
feed = {}
|
||||||
feed = {}
|
for l in xrange(length):
|
||||||
for l in xrange(length):
|
feed[inp[l].name] = i[l]
|
||||||
feed[inp[l].name] = i[l]
|
feed[out[l].name] = o[l]
|
||||||
feed[out[l].name] = o[l]
|
if length < 8: # For the 4-bucket, we need the 5th as target.
|
||||||
if length < 8: # For the 4-bucket, we need the 5th as target.
|
feed[out[length].name] = o[length]
|
||||||
feed[out[length].name] = o[length]
|
res = sess.run([updates[bucket], losses[bucket]], feed)
|
||||||
res = sess.run([updates[bucket], losses[bucket]], feed)
|
perplexities[bucket].append(math.exp(float(res[1])))
|
||||||
log_perp += float(res[1])
|
for bucket in xrange(len(buckets)):
|
||||||
perp = math.exp(log_perp / 100)
|
if len(perplexities[bucket]) > 1: # Assert that perplexity went down.
|
||||||
print("step %d avg. perp %f" % ((ep + 1) * 50, perp))
|
self.assertLess(perplexities[bucket][1], perplexities[bucket][0])
|
||||||
self.assertLess(perp, 2.5)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -22,11 +22,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import constant_op
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
def embedding_lookup(params, ids, name=None):
|
def embedding_lookup(params, ids, partition_strategy="mod", name=None):
|
||||||
"""Looks up `ids` in a list of embedding tensors.
|
"""Looks up `ids` in a list of embedding tensors.
|
||||||
|
|
||||||
This function is used to perform parallel lookups on the list of
|
This function is used to perform parallel lookups on the list of
|
||||||
@ -35,16 +36,32 @@ def embedding_lookup(params, ids, name=None):
|
|||||||
interpreted as a partition of a larger embedding tensor.
|
interpreted as a partition of a larger embedding tensor.
|
||||||
|
|
||||||
If `len(params) > 1`, each element `id` of `ids` is partitioned between
|
If `len(params) > 1`, each element `id` of `ids` is partitioned between
|
||||||
the elements of `params` by computing `p = id % len(params)`, and is
|
the elements of `params` according to the `partition_strategy`.
|
||||||
then used to look up the slice `params[p][id // len(params), ...]`.
|
In all strategies, if the id space is not evenly divisible by the number of
|
||||||
|
partitions, each of the first `(max_id + 1) % len(params)` partitions will
|
||||||
|
be assigned one more id.
|
||||||
|
|
||||||
The results of the lookup are then concatenated into a dense
|
If `partition_strategy` is `"mod"`, we assign each id to partition
|
||||||
|
`p = id % len(params)`. For instance,
|
||||||
|
13 ids are split across 5 partitions as:
|
||||||
|
`[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
|
||||||
|
|
||||||
|
If `partition_strategy` is `"div"`, we assign ids to partitions in a
|
||||||
|
contiguous manner. In this case, 13 ids are split across 5 partitions as:
|
||||||
|
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
|
||||||
|
|
||||||
|
The results of the lookup are concatenated into a dense
|
||||||
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: A list of tensors with the same shape and type.
|
params: A list of tensors with the same type and which can be concatenated
|
||||||
|
along dimension 0. Each `Tensor` must be appropriately sized for the given
|
||||||
|
`partition_strategy`.
|
||||||
ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
|
ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
|
||||||
up in `params`.
|
up in `params`.
|
||||||
|
partition_strategy: A string specifying the partitioning strategy, relevant
|
||||||
|
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
|
||||||
|
is `"mod"`.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -67,23 +84,59 @@ def embedding_lookup(params, ids, name=None):
|
|||||||
ids = ops.convert_to_tensor(ids, name="ids")
|
ids = ops.convert_to_tensor(ids, name="ids")
|
||||||
flat_ids = array_ops.reshape(ids, [-1])
|
flat_ids = array_ops.reshape(ids, [-1])
|
||||||
original_indices = math_ops.range(array_ops.size(flat_ids))
|
original_indices = math_ops.range(array_ops.size(flat_ids))
|
||||||
# Compute flat_ids % partitions for each id
|
|
||||||
ids_mod_p = flat_ids % np
|
# Create p_assignments and set new_ids depending on the strategy.
|
||||||
if ids_mod_p.dtype != dtypes.int32:
|
if partition_strategy == "mod":
|
||||||
ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32)
|
p_assignments = flat_ids % np
|
||||||
# Partition single list of ids based on ids % np into np separate lists
|
new_ids = flat_ids // np
|
||||||
plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
|
elif partition_strategy == "div":
|
||||||
|
# Compute num_total_ids as the sum of dim-0 of params, then assign to
|
||||||
|
# partitions based on a constant number of ids per partition. Optimize
|
||||||
|
# if we already know the full shape statically.
|
||||||
|
dim_0_size = params[0].get_shape()[0]
|
||||||
|
for p in xrange(1, np):
|
||||||
|
dim_0_size += params[p].get_shape()[0]
|
||||||
|
if dim_0_size.value:
|
||||||
|
num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
|
||||||
|
else:
|
||||||
|
dim_0_sizes = []
|
||||||
|
for p in xrange(np):
|
||||||
|
with ops.device(params[p].device):
|
||||||
|
dim_0_sizes.append(array_ops.shape(params[p])[0])
|
||||||
|
num_total_ids = math_ops.reduce_sum(
|
||||||
|
math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
|
||||||
|
ids_per_partition = num_total_ids // np
|
||||||
|
extras = num_total_ids % np
|
||||||
|
|
||||||
|
p_assignments = math_ops.maximum(
|
||||||
|
flat_ids // (ids_per_partition + 1),
|
||||||
|
(flat_ids - extras) // ids_per_partition)
|
||||||
|
|
||||||
|
# Emulate a conditional using a boolean indicator tensor
|
||||||
|
is_in_first_extras_partitions = math_ops.cast(
|
||||||
|
p_assignments < extras, flat_ids.dtype)
|
||||||
|
new_ids = (
|
||||||
|
is_in_first_extras_partitions * (
|
||||||
|
flat_ids % (ids_per_partition + 1)) +
|
||||||
|
(1 - is_in_first_extras_partitions) * (
|
||||||
|
(flat_ids - extras) % ids_per_partition))
|
||||||
|
else:
|
||||||
|
raise ValueError("Unrecognized partition strategy: " +
|
||||||
|
partition_strategy)
|
||||||
|
|
||||||
|
# Cast partition assignments to int32 for use in dynamic_partition.
|
||||||
|
# There really should not be more than 2^31 partitions.
|
||||||
|
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
|
||||||
|
# Partition list of ids based on assignments into np separate lists
|
||||||
|
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
|
||||||
# Similarly, partition the original indices.
|
# Similarly, partition the original indices.
|
||||||
pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
|
pindices = data_flow_ops.dynamic_partition(original_indices,
|
||||||
np)
|
p_assignments, np)
|
||||||
# Do np separate lookups, finding embeddings for plist[p] in params[p]
|
# Do np separate lookups, finding embeddings for gather_ids[p] in params[p]
|
||||||
partitioned_result = []
|
partitioned_result = []
|
||||||
for p in xrange(np):
|
for p in xrange(np):
|
||||||
# TODO(agarwal): handle device allocations here and later in the
|
|
||||||
# colocate code.
|
|
||||||
gather_ids = plist[p] // np
|
|
||||||
with ops.device(params[p].device):
|
with ops.device(params[p].device):
|
||||||
partitioned_result.append(array_ops.gather(params[p], gather_ids))
|
partitioned_result.append(array_ops.gather(params[p], gather_ids[p]))
|
||||||
# Stitch these back together
|
# Stitch these back together
|
||||||
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
|
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
|
||||||
name=name)
|
name=name)
|
||||||
@ -106,6 +159,7 @@ def embedding_lookup(params, ids, name=None):
|
|||||||
|
|
||||||
# TODO(lif): Add support for higher-rank SparseTensors
|
# TODO(lif): Add support for higher-rank SparseTensors
|
||||||
def embedding_lookup_sparse(params, sp_ids, sp_weights,
|
def embedding_lookup_sparse(params, sp_ids, sp_weights,
|
||||||
|
partition_strategy="mod",
|
||||||
name=None,
|
name=None,
|
||||||
combiner="mean"):
|
combiner="mean"):
|
||||||
"""Computes embeddings for the given ids and weights.
|
"""Computes embeddings for the given ids and weights.
|
||||||
@ -120,16 +174,15 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
|
|||||||
Args:
|
Args:
|
||||||
params: A single tensor representing the complete embedding tensor,
|
params: A single tensor representing the complete embedding tensor,
|
||||||
or a list of P tensors all of same shape except for the first dimension,
|
or a list of P tensors all of same shape except for the first dimension,
|
||||||
representing sharded embedding tensors. In the latter case, the ids are
|
representing sharded embedding tensors.
|
||||||
partitioned by id % P, and we do separate lookups in params[p] for
|
|
||||||
0 <= p < P, and then stitch the results back together into a single
|
|
||||||
result tensor. The first dimension is allowed to vary as the vocab
|
|
||||||
size is not necessarily a multiple of P.
|
|
||||||
sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
|
sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
|
||||||
where N is typically batch size and M is arbitrary.
|
where N is typically batch size and M is arbitrary.
|
||||||
sp_weights: either a SparseTensor of float / double weights, or None to
|
sp_weights: either a SparseTensor of float / double weights, or None to
|
||||||
indicate all weights should be taken to be 1. If specified, sp_weights
|
indicate all weights should be taken to be 1. If specified, sp_weights
|
||||||
must have exactly the same shape and indices as sp_ids.
|
must have exactly the same shape and indices as sp_ids.
|
||||||
|
partition_strategy: A string specifying the partitioning strategy, relevant
|
||||||
|
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
|
||||||
|
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
|
||||||
name: Optional name for the op.
|
name: Optional name for the op.
|
||||||
combiner: A string specifying the reduction op. Currently "mean" and "sum"
|
combiner: A string specifying the reduction op. Currently "mean" and "sum"
|
||||||
are supported.
|
are supported.
|
||||||
@ -187,7 +240,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
|
|||||||
else:
|
else:
|
||||||
idx = None
|
idx = None
|
||||||
|
|
||||||
embeddings = embedding_lookup(params, ids)
|
embeddings = embedding_lookup(
|
||||||
|
params, ids, partition_strategy=partition_strategy)
|
||||||
if not ignore_weights:
|
if not ignore_weights:
|
||||||
weights = sp_weights.values
|
weights = sp_weights.values
|
||||||
if weights.dtype != embeddings.dtype:
|
if weights.dtype != embeddings.dtype:
|
||||||
|
@ -553,6 +553,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
|
|||||||
sampled_values=None,
|
sampled_values=None,
|
||||||
subtract_log_q=True,
|
subtract_log_q=True,
|
||||||
remove_accidental_hits=False,
|
remove_accidental_hits=False,
|
||||||
|
partition_strategy="mod",
|
||||||
name=None):
|
name=None):
|
||||||
"""Helper function for nce_loss and sampled_softmax_loss functions.
|
"""Helper function for nce_loss and sampled_softmax_loss functions.
|
||||||
|
|
||||||
@ -567,7 +568,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
|
|||||||
Args:
|
Args:
|
||||||
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
||||||
objects whose concatenation along dimension 0 has shape
|
objects whose concatenation along dimension 0 has shape
|
||||||
`[num_classes, dim]`. The (possibly-sharded) class embeddings.
|
`[num_classes, dim]`. The (possibly-partitioned) class embeddings.
|
||||||
biases: A `Tensor` of shape `[num_classes]`. The class biases.
|
biases: A `Tensor` of shape `[num_classes]`. The class biases.
|
||||||
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
|
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
|
||||||
activations of the input network.
|
activations of the input network.
|
||||||
@ -586,6 +587,9 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
|
|||||||
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
|
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
|
||||||
where a sampled class equals one of the target classes. Default is
|
where a sampled class equals one of the target classes. Default is
|
||||||
False.
|
False.
|
||||||
|
partition_strategy: A string specifying the partitioning strategy, relevant
|
||||||
|
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
|
||||||
|
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
Returns:
|
Returns:
|
||||||
out_logits, out_labels: `Tensor` objects each with shape
|
out_logits, out_labels: `Tensor` objects each with shape
|
||||||
@ -624,7 +628,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
|
|||||||
all_ids = array_ops.concat(0, [labels_flat, sampled])
|
all_ids = array_ops.concat(0, [labels_flat, sampled])
|
||||||
|
|
||||||
# weights shape is [num_classes, dim]
|
# weights shape is [num_classes, dim]
|
||||||
all_w = embedding_ops.embedding_lookup(weights, all_ids)
|
all_w = embedding_ops.embedding_lookup(
|
||||||
|
weights, all_ids, partition_strategy=partition_strategy)
|
||||||
all_b = embedding_ops.embedding_lookup(biases, all_ids)
|
all_b = embedding_ops.embedding_lookup(biases, all_ids)
|
||||||
# true_w shape is [batch_size * num_true, dim]
|
# true_w shape is [batch_size * num_true, dim]
|
||||||
# true_b is a [batch_size * num_true] tensor
|
# true_b is a [batch_size * num_true] tensor
|
||||||
@ -704,6 +709,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
|
|||||||
num_true=1,
|
num_true=1,
|
||||||
sampled_values=None,
|
sampled_values=None,
|
||||||
remove_accidental_hits=False,
|
remove_accidental_hits=False,
|
||||||
|
partition_strategy="mod",
|
||||||
name="nce_loss"):
|
name="nce_loss"):
|
||||||
"""Computes and returns the noise-contrastive estimation training loss.
|
"""Computes and returns the noise-contrastive estimation training loss.
|
||||||
|
|
||||||
@ -726,7 +732,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
|
|||||||
Args:
|
Args:
|
||||||
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
||||||
objects whose concatenation along dimension 0 has shape
|
objects whose concatenation along dimension 0 has shape
|
||||||
[num_classes, dim]. The (possibly-sharded) class embeddings.
|
[num_classes, dim]. The (possibly-partitioned) class embeddings.
|
||||||
biases: A `Tensor` of shape `[num_classes]`. The class biases.
|
biases: A `Tensor` of shape `[num_classes]`. The class biases.
|
||||||
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
|
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
|
||||||
activations of the input network.
|
activations of the input network.
|
||||||
@ -745,6 +751,9 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
|
|||||||
our [Candidate Sampling Algorithms Reference]
|
our [Candidate Sampling Algorithms Reference]
|
||||||
(../../extras/candidate_sampling.pdf).
|
(../../extras/candidate_sampling.pdf).
|
||||||
Default is False.
|
Default is False.
|
||||||
|
partition_strategy: A string specifying the partitioning strategy, relevant
|
||||||
|
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
|
||||||
|
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -756,6 +765,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
|
|||||||
sampled_values=sampled_values,
|
sampled_values=sampled_values,
|
||||||
subtract_log_q=True,
|
subtract_log_q=True,
|
||||||
remove_accidental_hits=remove_accidental_hits,
|
remove_accidental_hits=remove_accidental_hits,
|
||||||
|
partition_strategy=partition_strategy,
|
||||||
name=name)
|
name=name)
|
||||||
sampled_losses = sigmoid_cross_entropy_with_logits(logits,
|
sampled_losses = sigmoid_cross_entropy_with_logits(logits,
|
||||||
labels,
|
labels,
|
||||||
@ -769,6 +779,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
|
|||||||
num_classes, num_true=1,
|
num_classes, num_true=1,
|
||||||
sampled_values=None,
|
sampled_values=None,
|
||||||
remove_accidental_hits=True,
|
remove_accidental_hits=True,
|
||||||
|
partition_strategy="mod",
|
||||||
name="sampled_softmax_loss"):
|
name="sampled_softmax_loss"):
|
||||||
"""Computes and returns the sampled softmax training loss.
|
"""Computes and returns the sampled softmax training loss.
|
||||||
|
|
||||||
@ -805,6 +816,9 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
|
|||||||
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
|
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
|
||||||
where a sampled class equals one of the target classes. Default is
|
where a sampled class equals one of the target classes. Default is
|
||||||
True.
|
True.
|
||||||
|
partition_strategy: A string specifying the partitioning strategy, relevant
|
||||||
|
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
|
||||||
|
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -817,6 +831,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
|
|||||||
sampled_values=sampled_values,
|
sampled_values=sampled_values,
|
||||||
subtract_log_q=True,
|
subtract_log_q=True,
|
||||||
remove_accidental_hits=remove_accidental_hits,
|
remove_accidental_hits=remove_accidental_hits,
|
||||||
|
partition_strategy=partition_strategy,
|
||||||
name=name)
|
name=name)
|
||||||
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
|
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
|
||||||
# sampled_losses is a [batch_size] tensor.
|
# sampled_losses is a [batch_size] tensor.
|
||||||
|
@ -20,12 +20,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import constant_op
|
from tensorflow.python.ops import constant_op
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.training import slot_creator
|
||||||
|
|
||||||
|
|
||||||
# TODO(touts): switch to variables.Variable.
|
# TODO(touts): switch to variables.Variable.
|
||||||
@ -209,22 +209,19 @@ class ExponentialMovingAverage(object):
|
|||||||
raise TypeError("The variables must be float or double: %s" % var)
|
raise TypeError("The variables must be float or double: %s" % var)
|
||||||
if var in self._averages:
|
if var in self._averages:
|
||||||
raise ValueError("Moving average already computed for: %s" % var)
|
raise ValueError("Moving average already computed for: %s" % var)
|
||||||
with ops.name_scope(var.op.name + "/" + self._name) as scope:
|
|
||||||
# For variables: to lower communication bandwidth across devices we keep
|
# For variables: to lower communication bandwidth across devices we keep
|
||||||
# the moving averages on the same device as the variables. For other
|
# the moving averages on the same device as the variables. For other
|
||||||
# tensors, we rely on the existing device allocation mechanism.
|
# tensors, we rely on the existing device allocation mechanism.
|
||||||
if isinstance(var, variables.Variable):
|
if isinstance(var, variables.Variable):
|
||||||
with ops.device(var.device):
|
avg = slot_creator.create_slot(
|
||||||
avg = variables.Variable(var.initialized_value(),
|
var, var.initialized_value(), self._name,
|
||||||
name=scope, trainable=False)
|
colocate_with_primary=True)
|
||||||
elif var.op.type == "Variable":
|
else:
|
||||||
with ops.device(var.device):
|
avg = slot_creator.create_zeros_slot(
|
||||||
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
|
var, self._name, colocate_with_primary=(var.op.type == "Variable"))
|
||||||
name=scope, trainable=False)
|
self._averages[var] = avg
|
||||||
else:
|
|
||||||
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
|
|
||||||
name=scope, trainable=False)
|
|
||||||
self._averages[var] = avg
|
|
||||||
with ops.name_scope(self._name) as scope:
|
with ops.name_scope(self._name) as scope:
|
||||||
decay = ops.convert_to_tensor(self._decay, name="decay")
|
decay = ops.convert_to_tensor(self._decay, name="decay")
|
||||||
if self._num_updates is not None:
|
if self._num_updates is not None:
|
||||||
|
@ -22,11 +22,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gradients
|
from tensorflow.python.ops import gradients
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.training import slot_creator
|
||||||
|
|
||||||
|
|
||||||
class Optimizer(object):
|
class Optimizer(object):
|
||||||
@ -418,6 +418,22 @@ class Optimizer(object):
|
|||||||
# Utility methods for subclasses.
|
# Utility methods for subclasses.
|
||||||
# --------------
|
# --------------
|
||||||
|
|
||||||
|
def _slot_dict(self, slot_name):
|
||||||
|
"""Returns a dict for caching slots created under the given name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slot_name: Name for the slot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict that maps primary `Variable` objects to the slot created
|
||||||
|
for that variable, under the given slot name.
|
||||||
|
"""
|
||||||
|
named_slots = self._slots.get(slot_name, None)
|
||||||
|
if named_slots is None:
|
||||||
|
named_slots = {}
|
||||||
|
self._slots[slot_name] = named_slots
|
||||||
|
return named_slots
|
||||||
|
|
||||||
def _get_or_make_slot(self, var, val, slot_name, op_name):
|
def _get_or_make_slot(self, var, val, slot_name, op_name):
|
||||||
"""Find or create a slot for a variable.
|
"""Find or create a slot for a variable.
|
||||||
|
|
||||||
@ -431,19 +447,10 @@ class Optimizer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Variable` object.
|
A `Variable` object.
|
||||||
"""
|
"""
|
||||||
named_slots = self._slots.get(slot_name, None)
|
named_slots = self._slot_dict(slot_name)
|
||||||
if named_slots is None:
|
if var not in named_slots:
|
||||||
named_slots = {}
|
named_slots[var] = slot_creator.create_slot(var, val, op_name)
|
||||||
self._slots[slot_name] = named_slots
|
return named_slots[var]
|
||||||
slot = named_slots.get(var, None)
|
|
||||||
if slot is None:
|
|
||||||
# Scope the slot name in the namespace of the Variable and
|
|
||||||
# create the slot on the same device as the variable.
|
|
||||||
with ops.name_scope(var.op.name + "/" + op_name) as scope:
|
|
||||||
with ops.device(var.device):
|
|
||||||
slot = variables.Variable(val, name=scope, trainable=False)
|
|
||||||
named_slots[var] = slot
|
|
||||||
return slot
|
|
||||||
|
|
||||||
def _zeros_slot(self, var, slot_name, op_name):
|
def _zeros_slot(self, var, slot_name, op_name):
|
||||||
"""Find or create a slot initialized with 0.0.
|
"""Find or create a slot initialized with 0.0.
|
||||||
@ -457,5 +464,7 @@ class Optimizer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Variable` object.
|
A `Variable` object.
|
||||||
"""
|
"""
|
||||||
val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype)
|
named_slots = self._slot_dict(slot_name)
|
||||||
return self._get_or_make_slot(var, val, slot_name, op_name)
|
if var not in named_slots:
|
||||||
|
named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
|
||||||
|
return named_slots[var]
|
||||||
|
108
tensorflow/python/training/slot_creator.py
Normal file
108
tensorflow/python/training/slot_creator.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Standard functions for creating slots.
|
||||||
|
|
||||||
|
A slot is a `Variable` created with the same shape as a primary variable or
|
||||||
|
`Tensor`. A slot is always scoped in the namespace of the primary object and
|
||||||
|
typically has the same device and type.
|
||||||
|
|
||||||
|
Slots are typically used as accumulators to track values associated with
|
||||||
|
the primary object:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Optimizers can create a slot for each variable to track accumulators
|
||||||
|
accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
|
||||||
|
for var in vs:
|
||||||
|
apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
|
||||||
|
|
||||||
|
# Slots can also be used for moving averages
|
||||||
|
mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
|
||||||
|
update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# pylint: disable=g-bad-name
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
|
def _create_slot_var(primary, val, scope):
  """Builds the non-trainable slot `Variable` that accompanies `primary`.

  Args:
    primary: The primary `Variable` or `Tensor` the slot is created for.
    val: Initial value handed to the slot `Variable`.
    scope: Name given to the slot variable. When `primary` carries save-slice
      info, the text between the `<primary.op.name>/` prefix and the final
      character of `scope` is used as the slot's own name.

  Returns:
    The newly created slot `Variable`.
  """
  slot_var = variables.Variable(val, name=scope, trainable=False)
  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # The primary is a partitioned variable, so the slot must be flagged as
    # partitioned too; it reuses the primary's full shape, offset and slice
    # shape, and only the full name differs.
    primary_prefix = primary.op.name + "/"
    slot_suffix = scope[len(primary_prefix):-1]
    primary_info = primary._save_slice_info
    slot_info = variables.Variable.SaveSliceInfo(
        primary_info.full_name + "/" + slot_suffix,
        primary_info.full_shape[:],
        primary_info.var_offset[:],
        primary_info.var_shape[:])
    slot_var._set_save_slice_info(slot_info)
  # pylint: enable=protected-access
  return slot_var
|
||||||
|
|
||||||
|
|
||||||
|
def create_slot(primary, val, name, colocate_with_primary=True):
  """Creates a slot variable whose initial value is `val`.

  The slot takes its type from `val`, and its name is scoped under the name
  of `primary`'s op.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is created under
      `primary`'s device; otherwise no device scope is added here.

  Returns:
    A `Variable` object.
  """
  # Nest the slot's name inside the primary's namespace.
  with ops.name_scope(primary.op.name + "/" + name) as scope:
    if not colocate_with_primary:
      return _create_slot_var(primary, val, scope)
    with ops.device(primary.device):
      return _create_slot_var(primary, val, scope)
|
||||||
|
|
||||||
|
|
||||||
|
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
  """Creates a zero-initialized slot shaped like the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. When None, `primary.dtype` is used.
    colocate_with_primary: Boolean, forwarded unchanged to `create_slot`.

  Returns:
    A `Variable` object.
  """
  slot_dtype = primary.dtype if dtype is None else dtype
  # The zeros tensor is built from the primary's static shape.
  zeros = array_ops.zeros(primary.get_shape().as_list(), dtype=slot_dtype)
  return create_slot(
      primary, zeros, name, colocate_with_primary=colocate_with_primary)
|
78
tensorflow/python/training/slot_creator_test.py
Normal file
78
tensorflow/python/training/slot_creator_test.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Functional test for slot_creator."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow.python.platform
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.training import slot_creator
|
||||||
|
|
||||||
|
|
||||||
|
class SlotCreatorTest(tf.test.TestCase):
  """Checks the name, shape, dtype and value of slots from slot_creator."""

  def _assertSlot(self, slot, expected_name, expected_dtype, expected_value):
    """Asserts the properties every test verifies on a two-element slot.

    Must be called inside an active test session because it evaluates `slot`.
    """
    self.assertEqual(slot.op.name, expected_name)
    self.assertEqual(slot.get_shape().as_list(), [2])
    self.assertEqual(slot.dtype.base_dtype, expected_dtype)
    self.assertAllEqual(slot.eval(), expected_value)

  def testCreateSlotFromVariable(self):
    # A slot seeded from a variable's initialized value mirrors that value.
    with self.test_session():
      primary = tf.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_slot(
          primary, primary.initialized_value(), name="slot")

      tf.initialize_all_variables().run()

      self._assertSlot(slot, "var/slot", tf.float32, [1.0, 2.5])

  def testCreateSlotFromTensor(self):
    # A plain tensor can act as the primary; the slot holds the given value.
    with self.test_session():
      primary = tf.constant([1.0, 2.5], name="const")
      slot = slot_creator.create_slot(primary, primary * 2, name="slot")

      tf.initialize_all_variables().run()

      self._assertSlot(slot, "const/slot", tf.float32, [2.0, 5.0])

  def testCreateZerosSlotFromVariable(self):
    # An explicit dtype overrides the primary variable's dtype.
    with self.test_session():
      primary = tf.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_zeros_slot(
          primary, name="slot", dtype=tf.float64)

      tf.initialize_all_variables().run()

      self._assertSlot(slot, "var/slot", tf.float64, [0.0, 0.0])

  def testCreateZerosSlotFromTensor(self):
    # Without an explicit dtype the slot takes the primary tensor's dtype.
    with self.test_session():
      primary = tf.constant([1.0, 2.5], name="const")

      slot = slot_creator.create_zeros_slot(primary, name="slot")

      tf.initialize_all_variables().run()

      self._assertSlot(slot, "const/slot", tf.float32, [0.0, 0.0])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user