TensorFlow: Upstream changes to git.

Change 109730179
	Add support for selecting a partition strategy in tf.nn.embedding_lookup and related ops, and allow unequally-sized shards to be used as input (see the usage sketch after this list).
Change 109729548
	TensorFlow: add RELEASE.md notes for 0.6.0.
Change 109728185
	Make seq2seq_test non-flaky by setting python and numpy random seed.
Change 109725913
	Refactor slot creation in optimizers and moving averages into a separate file.
Change 109718024
	TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s.
Change 109712251
	More performance improvements for convnets on GPU.
	+ Switch forward convolution format to NCHW.
	+ Allocate scratch space for forward and backward convolutions.
	+ Users can set "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the scratch-space
	limit. The default limit is 1 GB.
Change 109710898
	Added extract_sub_graph utility function
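
To illustrate change 109730179, here is a minimal usage sketch of the new `partition_strategy` argument with two unequally-sized shards; the shard sizes and values are invented for the example.

```python
import numpy as np
import tensorflow as tf

# Hypothetical 5-row vocabulary of 4-d embeddings, split into unequal shards
# of 3 and 2 rows.
shard0 = tf.constant(np.arange(12, dtype=np.float32).reshape(3, 4))
shard1 = tf.constant(np.arange(12, 24, dtype=np.float32).reshape(2, 4))

# Under "div", ids 0..2 map to shard0 and ids 3..4 map to shard1.
ids = tf.constant([0, 2, 4], dtype=tf.int32)
embedded = tf.nn.embedding_lookup([shard0, shard1], ids,
                                  partition_strategy="div")

with tf.Session() as sess:
  print(sess.run(embedded))  # rows 0 and 2 of shard0, then row 1 of shard1
```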

Base CL: 109731609
Vijay Vasudevan 2015-12-08 14:55:13 -08:00
parent ddd4aaf528
commit 2c3738db9c
14 changed files with 763 additions and 137 deletions


@ -1,3 +1,30 @@
# Release 0.6.0
## Major Features and Improvements
* Python 3.3+ support via changes to python codebase and ability
to specify python version via ./configure.
* Some improvements to GPU performance and memory usage:
[convnet benchmarks](https://github.com/soumith/convnet-benchmarks/issues/66)
roughly equivalent with native cudnn v2 performance. Improvements mostly due
to moving to 32-bit indices, faster shuffling kernels. More improvements to
come in later releases.
## Bug fixes
* Lots of fixes to documentation and tutorials, many contributed
by the public.
* 271 closed issues on github issues.
## Backwards-incompatible changes
* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion'
attribute from 0.0 to 1.0. This was a bug in the original release
that is now fixed.
# Release 0.5.0
Initial release of TensorFlow.


@ -35,6 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@ -756,17 +757,6 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
// GPU definitions of both ops.
#if GOOGLE_CUDA
namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace
// The slow version (but compiles for GPU)
// Backprop for input.
@ -929,10 +919,15 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
pre_transformed_in_backprop.template flat<T>().size());
static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
bool cudnn_launch_status =
stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
out_backprop_ptr, conv_desc,
input_desc, &in_backprop_ptr)
stream->ThenConvolveBackwardDataWithScratch(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
.ok();
if (!cudnn_launch_status) {
@ -1185,7 +1180,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->eigen_device<Device>(),
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
transformed_out_backprop.template flat<T>().size());
@ -1196,10 +1190,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
AsDeviceMemory(transformed_input.template flat<T>().data(),
transformed_input.template flat<T>().size());
static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
bool cudnn_launch_status =
stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
out_backprop_ptr, conv_desc,
filter_desc, &filter_backprop_ptr)
stream->ThenConvolveBackwardFilterWithScratch(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator)
.ok();
if (!cudnn_launch_status) {


@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
@ -34,6 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")
#if GOOGLE_CUDA
namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}
} // namespace
template <typename T>
struct LaunchConvOp<GPUDevice, T> {
@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
input = transformed_input;
}
{
// Convert the input tensor from NHWC to NCHW.
Tensor transformed_input;
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({input.dim_size(0), input.dim_size(3),
input.dim_size(1), input.dim_size(2)}),
&transformed_input));
functor::NHWCToNCHW<GPUDevice, T>()(
ctx->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
input = transformed_input;
}
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(input.dim_size(0))
.set_height(input.dim_size(1))
.set_width(input.dim_size(2))
.set_feature_map_count(input.dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
.set_feature_map_count(input.dim_size(1))
.set_height(input.dim_size(2))
.set_width(input.dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::BatchDescriptor output_desc;
output_desc.set_count(output->dim_size(0))
.set_height(output->dim_size(1))
.set_width(output->dim_size(2))
.set_feature_map_count(output->dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(filter.dim_size(0))
.set_input_filter_width(filter.dim_size(1))
@ -320,17 +344,31 @@ struct LaunchConvOp<GPUDevice, T> {
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({output->dim_size(0), output->dim_size(3),
output->dim_size(1), output->dim_size(2)}),
&transformed_output));
auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
auto filter_ptr =
AsDeviceMemory(transformed_filter.template flat<T>().data(),
transformed_filter.template flat<T>().size());
auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());
auto output_ptr =
AsDeviceMemory(transformed_output.template flat<T>().data(),
transformed_output.template flat<T>().size());
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
conv_desc, output_desc, &output_ptr)
stream->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
filter_ptr, conv_desc, output_desc,
&output_ptr, &scratch_allocator)
.ok();
if (!cudnn_launch_status) {
@ -338,6 +376,12 @@ struct LaunchConvOp<GPUDevice, T> {
"cuDNN launch failure : input shape(", input.shape().DebugString(),
") filter shape(", filter.shape().DebugString(), ")"));
}
// Convert the output tensor back from NHWC to NCHW.
functor::NCHWToNHWC<GPUDevice, T>()(
ctx->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
output->tensor<T, 4>());
} else {
LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
padding, output);
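
The NCHW switch in this kernel boils down to a (0, 3, 1, 2) transpose of the input before the cuDNN call (NHWCToNCHW) and the inverse transpose of the result afterwards (NCHWToNHWC). A NumPy sketch of the same permutation, with an illustrative shape:

```python
import numpy as np

# NHWC batch: 32 images, 28x28, 3 channels (shape is illustrative).
x_nhwc = np.random.rand(32, 28, 28, 3).astype(np.float32)

# The permutation NHWCToNCHW performs on-device before calling cuDNN.
x_nchw = x_nhwc.transpose(0, 3, 1, 2)
assert x_nchw.shape == (32, 3, 28, 28)

# The inverse permutation NCHWToNHWC applies to the convolution output.
assert np.array_equal(x_nchw.transpose(0, 2, 3, 1), x_nhwc)
```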


@ -0,0 +1,84 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
namespace tensorflow {
// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is
// widespread.
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
// Get the Cudnn workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
// default value.
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes);
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
public:
virtual ~CudnnScratchAllocator() {}
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), context_(context) {}
virtual int64 GetMemoryLimitInBytes(
perftools::gputools::Stream* stream) override {
return memory_limit_;
}
virtual perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
LOG(WARNING) << allocation_status;
context_->SetStatus(allocation_status);
return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>();
}
return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
private:
int64 memory_limit_;
OpKernelContext* context_;
};
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
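
A minimal sketch of configuring the limit that `GetCudnnWorkspaceLimit` reads; the 256 MB value is only an example.

```python
import os

# Cap cuDNN scratch allocations at 256 MB; GetCudnnWorkspaceLimit converts the
# value to bytes, and an unset or empty variable falls back to the 1 GB default.
# Set this before the first convolution kernel executes, since the result is
# cached in a function-local static.
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "256"
```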


@ -19,6 +19,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow.python.platform
@ -155,3 +156,64 @@ def pin_to_cpu(op):
logging.info("Operation %s has been assigned to a non-CPU (%s), so "
"it will not be pinned to the CPU.", op.name, dev.device_type)
return device
def _node_name(n):
if n.startswith("^"):
return n[1:]
else:
return n.split(":")[0]
def extract_sub_graph(graph_def, dest_nodes):
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
Args:
graph_def: A graph_pb2.GraphDef proto.
dest_nodes: A list of strings specifying the destination node names.
Returns:
The GraphDef of the sub-graph.
Raises:
TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
"""
if not isinstance(graph_def, graph_pb2.GraphDef):
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
edges = {} # Keyed by the dest node name.
name_to_node_map = {} # Keyed by node name.
# Keeps track of node sequences. It is important to still output the
# operations in the original order.
node_seq = {} # Keyed by node name.
seq = 0
for node in graph_def.node:
n = _node_name(node.name)
name_to_node_map[n] = node
edges[n] = [_node_name(x) for x in node.input]
node_seq[n] = seq
seq += 1
for d in dest_nodes:
assert d in name_to_node_map, "%s is not in graph" % d
nodes_to_keep = set()
# Breadth first search to find all the nodes that we should keep.
next_to_visit = dest_nodes[:]
while next_to_visit:
n = next_to_visit[0]
del next_to_visit[0]
if n in nodes_to_keep:
# Already visited this node.
continue
nodes_to_keep.add(n)
next_to_visit += edges[n]
nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
# Now construct the output GraphDef
out = graph_pb2.GraphDef()
for n in nodes_to_keep_list:
out.node.extend([copy.deepcopy(name_to_node_map[n])])
return out
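
A small usage sketch for extract_sub_graph, complementing the unit test below; the three-node graph is hypothetical.

```python
import tensorflow as tf
from tensorflow.python.client import graph_util

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  tf.add(a, b, name="c")
  tf.constant(3.0, name="unused")  # not an ancestor of "c", so it is dropped

sub = graph_util.extract_sub_graph(g.as_graph_def(), ["c"])
print([n.name for n in sub.node])  # ["a", "b", "c"], original order preserved
```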


@ -20,6 +20,8 @@ from __future__ import print_function
import tensorflow.python.platform
import tensorflow as tf
from tensorflow.python.client import graph_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -140,6 +142,34 @@ class DeviceFunctionsTest(googletest.TestCase):
self.assertEqual(const_4.device, "/device:CPU:1")
self.assertEqual(const_5.device, "/replica:0")
def testExtractSubGraph(self):
graph_def = tf.GraphDef()
n1 = graph_def.node.add()
n1.name = "n1"
n1.input.extend(["n5"])
n2 = graph_def.node.add()
n2.name = "n2"
# Take the first output of the n1 node as the input.
n2.input.extend(["n1:0"])
n3 = graph_def.node.add()
n3.name = "n3"
# Add a control input (which isn't really needed by the kernel, but
# rather to enforce execution order between nodes).
n3.input.extend(["^n2"])
n4 = graph_def.node.add()
n4.name = "n4"
# It is fine to have loops in the graph as well.
n5 = graph_def.node.add()
n5.name = "n5"
n5.input.extend(["n1"])
sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
self.assertEqual("n1", sub_graph.node[0].name)
self.assertEqual("n2", sub_graph.node[1].name)
self.assertEqual("n3", sub_graph.node[2].name)
self.assertEqual("n5", sub_graph.node[3].name)
if __name__ == "__main__":
googletest.main()


@ -115,26 +115,34 @@ def _PName(param_id):
def _EmbeddingParams(num_shards, vocab_size,
dtype=tf.float32,
shape=None):
shape=None,
use_shapeless_placeholder=False):
p = []
params = {}
feed_dict = {}
if not shape: shape = [10]
assert not vocab_size % num_shards
shape = [vocab_size // num_shards] + shape
for i in range(num_shards):
shard_shape = [vocab_size // num_shards] + shape
if i < vocab_size % num_shards: # Excess goes evenly on the first shards
shard_shape[0] += 1
param_name = _PName(i)
constant_t = tf.constant(1.0, shape=shape, dtype=dtype,
name=param_name)
p.append(constant_t)
if use_shapeless_placeholder:
param = tf.placeholder(dtype, shape=None, name=param_name)
else:
param = tf.constant(1.0, shape=shard_shape, dtype=dtype, name=param_name)
p.append(param)
np_type = "f" if dtype == tf.float32 else "d"
val = (np.random.rand(*shape).astype(np_type)) + 1
val = (np.random.rand(*shard_shape).astype(np_type)) + 1
params[param_name + ":0"] = val
feed_dict[constant_t.name] = val
feed_dict[param.name] = val
return p, params, feed_dict
def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
def _EmbeddingResult(params, id_vals, num_shards, vocab_size,
partition_strategy="mod",
weight_vals=None):
if weight_vals is None:
weight_vals = np.copy(id_vals)
weight_vals.fill(1)
@ -147,8 +155,22 @@ def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
ids = [ids]
wts = [wts]
for i, wt_val in zip(ids, wts):
val = np.copy(params[_PName(i % num_shards) + ":0"][
i // num_shards, :]) * wt_val
if partition_strategy == "mod":
val = np.copy(params[_PName(i % num_shards) + ":0"][
i // num_shards, :]) * wt_val
elif partition_strategy == "div":
ids_per_partition, extras = divmod(vocab_size, num_shards)
threshold = extras * (ids_per_partition + 1)
if i < threshold:
partition = i // (ids_per_partition + 1)
offset = i % (ids_per_partition + 1)
else:
partition = extras + (i - threshold) // ids_per_partition
offset = (i - threshold) % ids_per_partition
val = np.copy(
params[_PName(partition) + ":0"][offset, :]) * wt_val
else:
assert False
if val_aggr is None:
assert wt_aggr is None
val_aggr = val
@ -182,17 +204,17 @@ class EmbeddingLookupTest(tf.test.TestCase):
embedding = tf.nn.embedding_lookup(p, ids)
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
def testSharded(self):
def testShardedModPartitioningInt32Ids(self):
with self.test_session():
num_shards = 5
vocab_size = 25
# Embedding dimensions is 10. The 10 x vocab_size embedding
# parameters are spread in num_shards matrices, so each
# matrix is 10 x (vocab_size / num_shards)
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
# parameters are spread in num_shards matrices, so the first
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
num_vals = 30
@ -204,10 +226,103 @@ class EmbeddingLookupTest(tf.test.TestCase):
embedding = tf.nn.embedding_lookup(p, ids)
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
def testShardedModPartitioningInt64Ids(self):
with self.test_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
# parameters are spread in num_shards matrices, so the first
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
num_vals = 30
# Fetch num_vals embeddings for random word ids. Since
# num_vals > vocab_size, this ought to have repetitions, so
# will test that aspect.
id_vals = np.random.randint(vocab_size, size=num_vals)
ids = tf.constant(list(id_vals), dtype=tf.int64)
embedding = tf.nn.embedding_lookup(p, ids)
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt32Ids(self):
with self.test_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
# parameters are spread in num_shards matrices, so the first
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
num_vals = 30
# Fetch num_vals embeddings for random word ids. Since
# num_vals > vocab_size, this ought to have repetitions, so
# will test that aspect.
id_vals = np.random.randint(vocab_size, size=num_vals)
ids = tf.constant(list(id_vals), dtype=tf.int32)
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(
params, id_vals, num_shards, vocab_size, partition_strategy="div")
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt64Ids(self):
with self.test_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
# parameters are spread in num_shards matrices, so the first
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
num_vals = 30
# Fetch num_vals embeddings for random word ids. Since
# num_vals > vocab_size, this ought to have repetitions, so
# will test that aspect.
id_vals = np.random.randint(vocab_size, size=num_vals)
ids = tf.constant(list(id_vals), dtype=tf.int64)
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(
params, id_vals, num_shards, vocab_size, partition_strategy="div")
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningUnknownParamShape(self):
with self.test_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
# parameters are spread in num_shards matrices, so the first
# 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
# We clear parameter shapes, to test when shape is not statically known.
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, use_shapeless_placeholder=True)
num_vals = 30
# Fetch num_vals embeddings for random word ids. Since
# num_vals > vocab_size, this ought to have repetitions, so
# will test that aspect.
id_vals = np.random.randint(vocab_size, size=num_vals)
ids = tf.constant(list(id_vals), dtype=tf.int64)
embedding = tf.nn.embedding_lookup(p, ids, partition_strategy="div")
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _ = _EmbeddingResult(
params, id_vals, num_shards, vocab_size, partition_strategy="div")
self.assertAllEqual(np_result, tf_result)
def testGradientsEmbeddingLookup(self):
vocab_size = 9
num_ids = 5
@ -326,7 +441,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
return grouped_vals
def testEmbeddingLookupSparse(self):
vocab_size = 25
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
@ -354,7 +469,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
np_embedding_sum, np_weight_sum = _EmbeddingResult(
params, grouped_ids, num_shards,
params, grouped_ids, num_shards, vocab_size,
weight_vals=grouped_ignored_weights
if ignore_weights else grouped_weights)
if combiner == "mean":


@ -354,6 +354,10 @@ class Seq2SeqTest(tf.test.TestCase):
# We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
classes = 10
buckets = [(4, 4), (8, 8)]
perplexities = [[], []] # Results for each bucket.
tf.set_random_seed(111)
random.seed(111)
np.random.seed(111)
with self.test_session() as sess:
# We use sampled softmax so we keep output projection separate.
@ -378,8 +382,7 @@ class Seq2SeqTest(tf.test.TestCase):
softmax_loss_function=SampledLoss)
# Now we construct the copy model.
tf.set_random_seed(111)
batch_size = 32
batch_size = 8
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)]
@ -394,26 +397,26 @@ class Seq2SeqTest(tf.test.TestCase):
update = optimizer.apply_gradients(zip(grads, params))
updates.append(update)
sess.run([tf.initialize_all_variables()])
for ep in xrange(3):
log_perp = 0.0
for _ in xrange(50):
bucket = random.choice(np.arange(len(buckets)))
length = buckets[bucket][0]
i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
dtype=np.int32) for _ in xrange(length)]
# 0 is our "GO" symbol here.
o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
feed = {}
for l in xrange(length):
feed[inp[l].name] = i[l]
feed[out[l].name] = o[l]
if length < 8: # For the 4-bucket, we need the 5th as target.
feed[out[length].name] = o[length]
res = sess.run([updates[bucket], losses[bucket]], feed)
log_perp += float(res[1])
perp = math.exp(log_perp / 100)
print("step %d avg. perp %f" % ((ep + 1) * 50, perp))
self.assertLess(perp, 2.5)
steps = 6
for _ in xrange(steps):
bucket = random.choice(np.arange(len(buckets)))
length = buckets[bucket][0]
i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
dtype=np.int32) for _ in xrange(length)]
# 0 is our "GO" symbol here.
o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
feed = {}
for l in xrange(length):
feed[inp[l].name] = i[l]
feed[out[l].name] = o[l]
if length < 8: # For the 4-bucket, we need the 5th as target.
feed[out[length].name] = o[length]
res = sess.run([updates[bucket], losses[bucket]], feed)
perplexities[bucket].append(math.exp(float(res[1])))
for bucket in xrange(len(buckets)):
if len(perplexities[bucket]) > 1: # Assert that perplexity went down.
self.assertLess(perplexities[bucket][1], perplexities[bucket][0])
if __name__ == "__main__":
tf.test.main()


@ -22,11 +22,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
def embedding_lookup(params, ids, name=None):
def embedding_lookup(params, ids, partition_strategy="mod", name=None):
"""Looks up `ids` in a list of embedding tensors.
This function is used to perform parallel lookups on the list of
@ -35,16 +36,32 @@ def embedding_lookup(params, ids, name=None):
interpreted as a partition of a larger embedding tensor.
If `len(params) > 1`, each element `id` of `ids` is partitioned between
the elements of `params` by computing `p = id % len(params)`, and is
then used to look up the slice `params[p][id // len(params), ...]`.
the elements of `params` according to the `partition_strategy`.
In all strategies, if the id space does not evenly divide the number of
partitions, each of the first `(max_id + 1) % len(params)` partitions will
be assigned one more id.
The results of the lookup are then concatenated into a dense
If `partition_strategy` is `"mod"`, we assign each id to partition
`p = id % len(params)`. For instance,
13 ids are split across 5 partitions as:
`[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
If `partition_strategy` is `"div"`, we assign ids to partitions in a
contiguous manner. In this case, 13 ids are split across 5 partitions as:
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
The results of the lookup are concatenated into a dense
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
Args:
params: A list of tensors with the same shape and type.
params: A list of tensors with the same type and which can be concatenated
along dimension 0. Each `Tensor` must be appropriately sized for the given
`partition_strategy`.
ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
up in `params`.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`.
name: A name for the operation (optional).
Returns:
@ -67,23 +84,59 @@ def embedding_lookup(params, ids, name=None):
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
original_indices = math_ops.range(array_ops.size(flat_ids))
# Compute flat_ids % partitions for each id
ids_mod_p = flat_ids % np
if ids_mod_p.dtype != dtypes.int32:
ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32)
# Partition single list of ids based on ids % np into np separate lists
plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
# Create p_assignments and set new_ids depending on the strategy.
if partition_strategy == "mod":
p_assignments = flat_ids % np
new_ids = flat_ids // np
elif partition_strategy == "div":
# Compute num_total_ids as the sum of dim-0 of params, then assign to
# partitions based on a constant number of ids per partition. Optimize
# if we already know the full shape statically.
dim_0_size = params[0].get_shape()[0]
for p in xrange(1, np):
dim_0_size += params[p].get_shape()[0]
if dim_0_size.value:
num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
else:
dim_0_sizes = []
for p in xrange(np):
with ops.device(params[p].device):
dim_0_sizes.append(array_ops.shape(params[p])[0])
num_total_ids = math_ops.reduce_sum(
math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
ids_per_partition = num_total_ids // np
extras = num_total_ids % np
p_assignments = math_ops.maximum(
flat_ids // (ids_per_partition + 1),
(flat_ids - extras) // ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
is_in_first_extras_partitions = math_ops.cast(
p_assignments < extras, flat_ids.dtype)
new_ids = (
is_in_first_extras_partitions * (
flat_ids % (ids_per_partition + 1)) +
(1 - is_in_first_extras_partitions) * (
(flat_ids - extras) % ids_per_partition))
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
np)
pindices = data_flow_ops.dynamic_partition(original_indices,
p_assignments, np)
# Do np separate lookups, finding embeddings for plist[p] in params[p]
partitioned_result = []
for p in xrange(np):
# TODO(agarwal): handle device allocations here and later in the
# colocate code.
gather_ids = plist[p] // np
with ops.device(params[p].device):
partitioned_result.append(array_ops.gather(params[p], gather_ids))
partitioned_result.append(array_ops.gather(params[p], gather_ids[p]))
# Stitch these back together
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
name=name)
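
The "div" assignment arithmetic above can be restated in plain Python, using the docstring's example of 13 ids across 5 partitions; the graph code computes the same assignment with tensor ops rather than a Python branch.

```python
vocab_size, num_partitions = 13, 5
ids_per_partition, extras = divmod(vocab_size, num_partitions)  # 2, 3
threshold = extras * (ids_per_partition + 1)                    # 9

def div_assign(i):
  """Return (partition, offset) for id i under the "div" strategy."""
  if i < threshold:
    return i // (ids_per_partition + 1), i % (ids_per_partition + 1)
  return (extras + (i - threshold) // ids_per_partition,
          (i - threshold) % ids_per_partition)

# Reproduces [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
assert div_assign(8) == (2, 2)
assert div_assign(10) == (3, 1)
```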
@ -106,6 +159,7 @@ def embedding_lookup(params, ids, name=None):
# TODO(lif): Add support for higher-rank SparseTensors
def embedding_lookup_sparse(params, sp_ids, sp_weights,
partition_strategy="mod",
name=None,
combiner="mean"):
"""Computes embeddings for the given ids and weights.
@ -120,16 +174,15 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
Args:
params: A single tensor representing the complete embedding tensor,
or a list of P tensors all of same shape except for the first dimension,
representing sharded embedding tensors. In the latter case, the ids are
partitioned by id % P, and we do separate lookups in params[p] for
0 <= p < P, and then stitch the results back together into a single
result tensor. The first dimension is allowed to vary as the vocab
size is not necessarily a multiple of P.
representing sharded embedding tensors.
sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
where N is typically batch size and M is arbitrary.
sp_weights: either a SparseTensor of float / double weights, or None to
indicate all weights should be taken to be 1. If specified, sp_weights
must have exactly the same shape and indices as sp_ids.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: Optional name for the op.
combiner: A string specifying the reduction op. Currently "mean" and "sum"
are supported.
@ -187,7 +240,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
else:
idx = None
embeddings = embedding_lookup(params, ids)
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy)
if not ignore_weights:
weights = sp_weights.values
if weights.dtype != embeddings.dtype:


@ -553,6 +553,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
sampled_values=None,
subtract_log_q=True,
remove_accidental_hits=False,
partition_strategy="mod",
name=None):
"""Helper function for nce_loss and sampled_softmax_loss functions.
@ -567,7 +568,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
Args:
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
objects whose concatenation along dimension 0 has shape
`[num_classes, dim]`. The (possibly-sharded) class embeddings.
`[num_classes, dim]`. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
@ -586,6 +587,9 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is
False.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
Returns:
out_logits, out_labels: `Tensor` objects each with shape
@ -624,7 +628,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
all_ids = array_ops.concat(0, [labels_flat, sampled])
# weights shape is [num_classes, dim]
all_w = embedding_ops.embedding_lookup(weights, all_ids)
all_w = embedding_ops.embedding_lookup(
weights, all_ids, partition_strategy=partition_strategy)
all_b = embedding_ops.embedding_lookup(biases, all_ids)
# true_w shape is [batch_size * num_true, dim]
# true_b is a [batch_size * num_true] tensor
@ -704,6 +709,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
num_true=1,
sampled_values=None,
remove_accidental_hits=False,
partition_strategy="mod",
name="nce_loss"):
"""Computes and returns the noise-contrastive estimation training loss.
@ -726,7 +732,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
Args:
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
objects whose concatenation along dimension 0 has shape
[num_classes, dim]. The (possibly-sharded) class embeddings.
[num_classes, dim]. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
@ -745,6 +751,9 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
our [Candidate Sampling Algorithms Reference]
(../../extras/candidate_sampling.pdf).
Default is False.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
Returns:
@ -756,6 +765,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
sampled_losses = sigmoid_cross_entropy_with_logits(logits,
labels,
@ -769,6 +779,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
num_classes, num_true=1,
sampled_values=None,
remove_accidental_hits=True,
partition_strategy="mod",
name="sampled_softmax_loss"):
"""Computes and returns the sampled softmax training loss.
@ -805,6 +816,9 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is
True.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
Returns:
@ -817,6 +831,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
# sampled_losses is a [batch_size] tensor.


@ -20,12 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
# TODO(touts): switch to variables.Variable.
@ -209,22 +209,19 @@ class ExponentialMovingAverage(object):
raise TypeError("The variables must be float or double: %s" % var)
if var in self._averages:
raise ValueError("Moving average already computed for: %s" % var)
with ops.name_scope(var.op.name + "/" + self._name) as scope:
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
if isinstance(var, variables.Variable):
with ops.device(var.device):
avg = variables.Variable(var.initialized_value(),
name=scope, trainable=False)
elif var.op.type == "Variable":
with ops.device(var.device):
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
name=scope, trainable=False)
else:
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
name=scope, trainable=False)
self._averages[var] = avg
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(
var, var.initialized_value(), self._name,
colocate_with_primary=True)
else:
avg = slot_creator.create_zeros_slot(
var, self._name, colocate_with_primary=(var.op.type == "Variable"))
self._averages[var] = avg
with ops.name_scope(self._name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None:


@ -22,11 +22,11 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
class Optimizer(object):
@ -418,6 +418,22 @@ class Optimizer(object):
# Utility methods for subclasses.
# --------------
def _slot_dict(self, slot_name):
"""Returns a dict for caching slots created under the given name.
Args:
slot_name: Name for the slot.
Returns:
A dict that maps primary `Variable` objects to the slot created
for that variable, under the given slot name.
"""
named_slots = self._slots.get(slot_name, None)
if named_slots is None:
named_slots = {}
self._slots[slot_name] = named_slots
return named_slots
def _get_or_make_slot(self, var, val, slot_name, op_name):
"""Find or create a slot for a variable.
@ -431,19 +447,10 @@ class Optimizer(object):
Returns:
A `Variable` object.
"""
named_slots = self._slots.get(slot_name, None)
if named_slots is None:
named_slots = {}
self._slots[slot_name] = named_slots
slot = named_slots.get(var, None)
if slot is None:
# Scope the slot name in the namespace of the Variable and
# create the slot on the same device as the variable.
with ops.name_scope(var.op.name + "/" + op_name) as scope:
with ops.device(var.device):
slot = variables.Variable(val, name=scope, trainable=False)
named_slots[var] = slot
return slot
named_slots = self._slot_dict(slot_name)
if var not in named_slots:
named_slots[var] = slot_creator.create_slot(var, val, op_name)
return named_slots[var]
def _zeros_slot(self, var, slot_name, op_name):
"""Find or create a slot initialized with 0.0.
@ -457,5 +464,7 @@ class Optimizer(object):
Returns:
A `Variable` object.
"""
val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype)
return self._get_or_make_slot(var, val, slot_name, op_name)
named_slots = self._slot_dict(slot_name)
if var not in named_slots:
named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
return named_slots[var]


@ -0,0 +1,108 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standard functions for creating slots.
A slot is a `Variable` created with the same shape as a primary variable or
`Tensor`. A slot is always scoped in the namespace of the primary object and
typically has the same device and type.
Slots are typically used as accumulators to track values associated with
the primary object:
```python
# Optimizers can create a slot for each variable to track accumulators
accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
for var in vs:
apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
# Slots can also be used for moving averages
mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
```
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
def _create_slot_var(primary, val, scope):
"""Helper function for creating a slot variable."""
slot = variables.Variable(val, name=scope, trainable=False)
# pylint: disable=protected-access
if isinstance(primary, variables.Variable) and primary._save_slice_info:
# Primary is a partitioned variable, so we need to also indicate that
# the slot is a partitioned variable. Slots have the same partitioning
# as their primaries.
real_slot_name = scope[len(primary.op.name + "/"):-1]
slice_info = primary._save_slice_info
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
slice_info.full_shape[:],
slice_info.var_offset[:],
slice_info.var_shape[:]))
# pylint: enable=protected-access
return slot
def create_slot(primary, val, name, colocate_with_primary=True):
"""Create a slot initialized to the given value.
The type of the slot is determined by the given value.
Args:
primary: The primary `Variable` or `Tensor`.
val: A `Tensor` specifying the initial value of the slot.
name: Name to use for the slot variable.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
Returns:
A `Variable` object.
"""
# Scope the slot name in the namespace of the primary variable.
with ops.name_scope(primary.op.name + "/" + name) as scope:
if colocate_with_primary:
with ops.device(primary.device):
return _create_slot_var(primary, val, scope)
else:
return _create_slot_var(primary, val, scope)
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
"""Create a slot initialized to 0 with same shape as the primary object.
Args:
primary: The primary `Variable` or `Tensor`.
name: Name to use for the slot variable.
dtype: Type of the slot variable. Defaults to the type of `primary`.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
Returns:
A `Variable` object.
"""
if dtype is None:
dtype = primary.dtype
val = array_ops.zeros(primary.get_shape().as_list(), dtype=dtype)
return create_slot(primary, val, name,
colocate_with_primary=colocate_with_primary)


@ -0,0 +1,78 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for slot_creator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
import tensorflow as tf
from tensorflow.python.training import slot_creator
class SlotCreatorTest(tf.test.TestCase):
def testCreateSlotFromVariable(self):
with self.test_session():
v = tf.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
tf.initialize_all_variables().run()
self.assertEqual(slot.op.name, "var/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, tf.float32)
self.assertAllEqual(slot.eval(), [1.0, 2.5])
def testCreateSlotFromTensor(self):
with self.test_session():
v = tf.constant([1.0, 2.5], name="const")
slot = slot_creator.create_slot(v, v * 2, name="slot")
tf.initialize_all_variables().run()
self.assertEqual(slot.op.name, "const/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, tf.float32)
self.assertAllEqual(slot.eval(), [2.0, 5.0])
def testCreateZerosSlotFromVariable(self):
with self.test_session():
v = tf.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_zeros_slot(v, name="slot", dtype=tf.float64)
tf.initialize_all_variables().run()
self.assertEqual(slot.op.name, "var/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, tf.float64)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
def testCreateZerosSlotFromTensor(self):
with self.test_session():
v = tf.constant([1.0, 2.5], name="const")
slot = slot_creator.create_zeros_slot(v, name="slot")
tf.initialize_all_variables().run()
self.assertEqual(slot.op.name, "const/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, tf.float32)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
if __name__ == "__main__":
tf.test.main()