[TF contrib seq2seq] Beam search tree decoder op "GatherTree".

Based heavily on Denny Britz's initial implementation.
Change: 153470580
This commit is contained in:
Eugene Brevdo 2017-04-18 08:25:04 -08:00 committed by TensorFlower Gardener
parent 617e217cfc
commit a498546ddf
15 changed files with 653 additions and 9 deletions

View File

@ -59,6 +59,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/best_splits_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc"
@ -127,6 +129,7 @@ endif(WIN32)
file(GLOB_RECURSE tf_core_gpu_kernels_srcs
"${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc"
)
if(WIN32 AND tensorflow_ENABLE_GPU)

View File

@ -76,6 +76,7 @@ GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/co
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")

View File

@ -447,6 +447,8 @@ add_python_module("tensorflow/contrib/saved_model")
add_python_module("tensorflow/contrib/saved_model/python")
add_python_module("tensorflow/contrib/saved_model/python/saved_model")
add_python_module("tensorflow/contrib/seq2seq")
add_python_module("tensorflow/contrib/seq2seq/kernels")
add_python_module("tensorflow/contrib/seq2seq/ops")
add_python_module("tensorflow/contrib/seq2seq/python")
add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests")
add_python_module("tensorflow/contrib/seq2seq/python/ops")
@ -629,6 +631,8 @@ GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_lstm_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_seq2seq_beam_search_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/ops/gen_beam_search_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops"
@ -820,6 +824,26 @@ if(WIN32)
DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/python/ops/)
endif(WIN32)
if(WIN32)
# Build the contrib/seq2seq beam search op (GatherTree) as a separately
# loadable library on Windows.
# The CPU kernel, its header and the op registration go in SOURCES; the CUDA
# kernel is listed separately in GPUSOURCES.
# NOTE(review): AddUserOps is defined outside this view; DISTCOPY presumably
# names the directory the built library is copied into so that
# python/ops/beam_search_ops.py can find "_beam_search_ops.so" -- confirm.
#
set(tf_beam_search_srcs
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc"
)
set(tf_beam_search_gpu_srcs
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc"
)
AddUserOps(TARGET _beam_search_ops
SOURCES "${tf_beam_search_srcs}"
GPUSOURCES ${tf_beam_search_gpu_srcs}
DEPENDS pywrap_tensorflow_internal tf_python_ops
DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/)
endif(WIN32)
############################################################
# Build a PIP package containing the TensorFlow runtime.
############################################################

View File

@ -8,12 +8,28 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_kernel_library",
"tf_gen_op_wrapper_py",
)
py_library(
tf_custom_op_py_library(
name = "seq2seq_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
dso = [
":python/ops/_beam_search_ops.so",
],
kernels = [
":beam_search_ops_kernels",
":beam_search_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":beam_search_ops",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/rnn:rnn_py",
@ -29,6 +45,44 @@ py_library(
],
)
# Shared library containing the GatherTree op registration plus its CPU and
# GPU kernels. The target name matches the "_beam_search_ops.so" file that
# python/ops/beam_search_ops.py loads at import time.
# The kernel header is listed in both srcs and gpu_srcs because the CPU and
# CUDA sources each include it.
tf_custom_op_library(
name = "python/ops/_beam_search_ops.so",
srcs = [
"kernels/beam_search_ops.cc",
"kernels/beam_search_ops.h",
"ops/beam_search_ops.cc",
],
gpu_srcs = [
"kernels/beam_search_ops_gpu.cu.cc",
"kernels/beam_search_ops.h",
],
deps = [
"//tensorflow/core/kernels:eigen_helpers",
],
)
# Generated Python wrapper for the ops registered in :beam_search_ops_op_lib.
# NOTE(review): python/ops/beam_search_ops.py imports the result as
# tensorflow.contrib.seq2seq.ops.gen_beam_search_ops -- confirm the macro
# emits that module name.
tf_gen_op_wrapper_py(
name = "beam_search_ops",
deps = [":beam_search_ops_op_lib"],
)
# Op-registration library for ops/beam_search_ops.cc.
# NOTE(review): other targets in this file refer to ":beam_search_ops_op_lib",
# which is presumably the target this macro derives from the name below --
# confirm against tensorflow.bzl.
tf_gen_op_libs(
op_lib_names = [
"beam_search_ops",
],
)
# Kernel library for the GatherTree op, used when the kernels are linked in
# statically (see the `kernels` attribute of :seq2seq_py).
# NOTE(review): `prefix` presumably selects every kernels/beam_search_ops*
# source, including the .cu.cc file -- confirm against tensorflow.bzl.
tf_kernel_library(
name = "beam_search_ops_kernels",
prefix = "kernels/beam_search_ops",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
],
)
cuda_py_test(
name = "loss_test",
size = "medium",
@ -67,6 +121,20 @@ cuda_py_test(
],
)
# Python test for the GatherTree op
# (python/kernel_tests/beam_search_ops_test.py). Declared with cuda_py_test
# because the test requests GPU sessions (use_gpu=True).
cuda_py_test(
name = "beam_search_ops_test",
size = "medium",
srcs = ["python/kernel_tests/beam_search_ops_test.py"],
additional_deps = [
":seq2seq_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "decoder_test",
size = "medium",

View File

@ -37,6 +37,8 @@ See the @{$python/contrib.seq2seq} guide.
@@AttentionWrapperState
@@AttentionWrapper
@@gather_tree
"""
from __future__ import absolute_import
@ -46,6 +48,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import *
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
from tensorflow.contrib.seq2seq.python.ops.beam_search_ops import *
from tensorflow.contrib.seq2seq.python.ops.decoder import *
from tensorflow.contrib.seq2seq.python.ops.helper import *
from tensorflow.contrib.seq2seq.python.ops.loss import *

View File

@ -0,0 +1,174 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// OpKernel for "GatherTree".
//
// Inputs (all of element type T):
//   0: step_ids        -- rank-3 tensor, [max_time, batch_size, beam_width].
//   1: parent_ids      -- same shape as step_ids.
//   2: sequence_length -- vector of length batch_size (step_ids.shape[1]).
// Output:
//   0: beams           -- same shape as step_ids.
//
// Compute() only validates shapes and allocates the output; the traversal
// itself is delegated to functor::GatherTree<Device, T>.
template <typename Device, typename T>
class GatherTreeOp : public OpKernel {
 public:
  explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Device& device = ctx->eigen_device<Device>();
    const Tensor& step_ids = ctx->input(0);
    const Tensor& parent_ids = ctx->input(1);
    const Tensor& sequence_length = ctx->input(2);
    const TensorShape& step_ids_shape = step_ids.shape();
    OP_REQUIRES(
        ctx, step_ids_shape.dims() == 3,
        errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                step_ids_shape.DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(sequence_length.shape()),
        errors::InvalidArgument("sequence_length must be a vector, saw shape: ",
                                sequence_length.shape().DebugString()));
    // The batch dimension is dimension 1 of step_ids (time-major layout).
    // The message reports exactly the two sizes that are being compared.
    OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
                errors::InvalidArgument(
                    "Inconsistent batch sizes: sequence_length.shape[0] (",
                    sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
                    step_ids_shape.dim_size(1), ")"));
    OP_REQUIRES(
        ctx, step_ids_shape == parent_ids.shape(),
        errors::InvalidArgument(
            "step_ids.shape must match parent_ids.shape.  but shapes are: ",
            step_ids_shape.DebugString(), " and ",
            parent_ids.shape().DebugString()));
    Tensor* beams = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
    typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
    typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
    typename TTypes<T>::ConstVec seq_len_t = sequence_length.vec<T>();
    typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
    functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                     seq_len_t, beams_t);
  }
};
// Registers the CPU kernel for "GatherTree" for element type T. Only int32 is
// registered, matching the op's attr constraint "T: {int32}".
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GatherTree").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
GatherTreeOp<CPUDevice, T>);
REGISTER_KERNEL(int32);
#undef REGISTER_KERNEL
namespace functor {
// CPU specialization
template <>
struct GatherTree<CPUDevice, int32> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<int32, 3>::ConstTensor step_ids,
typename TTypes<int32, 3>::ConstTensor parent_ids,
typename TTypes<int32>::ConstVec sequence_length,
typename TTypes<int32, 3>::Tensor beams) {
const int64 max_time = parent_ids.dimension(0);
const int64 batch_size = parent_ids.dimension(1);
const int64 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
int32 seq_len_b = -1;
int32 old_batch = -1;
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
if (batch != old_batch) {
seq_len_b = sequence_length(batch);
old_batch = batch;
}
if (seq_len_b == 0) {
continue;
}
beams(seq_len_b - 1, batch, beam) =
step_ids(seq_len_b - 1, batch, beam);
int32 parent = parent_ids(seq_len_b - 1, batch, beam);
for (int32 level = seq_len_b - 2; level >= 0; --level) {
if (parent < 0 || parent > beam_width) {
ctx->SetStatus(
errors::InvalidArgument("Saw invalid parent id ", parent,
" at (batch, time, beam) == (", batch,
", ", level, ", ", beam, ")"));
return;
}
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
}
};
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
// traversal time step.
const int64 batch_beam_cost =
Eigen::TensorOpCost::DivCost<int32>() +
6 * Eigen::TensorOpCost::AddCost<int32>() +
max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_size * beam_width, batch_beam_cost, DoWork);
}
};
} // namespace functor
#if GOOGLE_CUDA
namespace functor {
// Declares the GPU specialization of GatherTree<GPUDevice, T>::operator() and
// marks the struct as an extern template, so it is not instantiated in this
// translation unit. The definition and explicit instantiation live in the
// CUDA source (beam_search_ops_gpu.cu.cc).
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
typename TTypes<T>::ConstVec sequence_length, \
typename TTypes<T, 3>::Tensor beams); \
extern template struct GatherTree<GPUDevice, T>;
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
} // end namespace functor
// Registers the GPU kernel for "GatherTree" for element type T (int32 only).
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
GatherTreeOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
} // end namespace tensorflow

View File

@ -0,0 +1,41 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class OpKernelContext;
namespace functor {
// Device-dispatched functor behind the "GatherTree" op. This primary template
// only declares operator(); the CPU and GPU sources provide the
// implementations for the device/type combinations they support.
//
//   ctx:             kernel context (the CPU implementation uses it to report
//                    invalid parent ids and to reach the worker threads).
//   d:               Eigen device to run on.
//   step_ids:        [max_time, batch_size, beam_width].
//   parent_ids:      [max_time, batch_size, beam_width].
//   sequence_length: [batch_size].
//   beams (output):  [max_time, batch_size, beam_width].
template <typename Device, typename T>
struct GatherTree {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstVec sequence_length,
typename TTypes<T, 3>::Tensor beams);
};
} // namespace functor
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_

View File

@ -0,0 +1,88 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
// CUDA kernel for GatherTree.
//
// step_ids, parent_ids and beams are flat buffers indexed as
// (time, batch, beam) -> batch_size * beam_width * time + beam_width * batch
// + beam (see GET_IX). sequence_length holds one entry per batch.
//
// Each loop iteration handles one (batch, beam) pair: it copies the step id
// at the last valid time step and then follows parent_ids backwards, writing
// the ancestor's step id at every earlier time step. Unlike the CPU
// implementation, an invalid parent id is not reported as an error: the
// remaining (earlier) entries of that beam are set to -1 instead.
//
// The caller is expected to have filled `beams` with -1 beforehand; entries
// at or beyond the sequence length are never written here.
//
// sequence_length values are clamped to max_time, and values <= 0 are skipped
// entirely, so no index below zero or beyond the buffers is ever formed.
template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
                                   const int32 beam_width, const T* step_ids,
                                   const T* parent_ids,
                                   const T* sequence_length, T* beams) {
  CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
    const int32 batch = i / beam_width;
    const int32 beam = i % beam_width;

    const int32 requested_len = ldg(sequence_length + batch);
    const int32 seq_len_b = requested_len > max_time ? max_time : requested_len;

    // A length of zero (or less) means there is nothing to gather; without
    // this guard GET_IX(seq_len_b - 1, beam) would be a negative offset.
    if (seq_len_b > 0) {
#define GET_IX(time_ix, beam_ix) \
  (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
      const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
      beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
      int32 parent = ldg(parent_ids + initial_beam_ix);
      for (int32 level = seq_len_b - 2; level >= 0; --level) {
        const int32 level_beam_ix = GET_IX(level, beam);
        // Valid beam indices are 0 .. beam_width - 1.
        if (parent < 0 || parent >= beam_width) {
          beams[level_beam_ix] = -1;
          parent = -1;
        } else {
          // Only form the parent offset once the parent is known to be valid.
          const int32 level_parent_ix = GET_IX(level, parent);
          beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
          parent = ldg(parent_ids + level_parent_ix);
        }
      }
#undef GET_IX
    }
  }
}
template <typename T>
struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstVec sequence_length,
typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
// First kernel launch to zero things out
beams.device(d) = beams.constant(T(-1));
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
// clang-format off
GatherTreeOpKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
batch_size, max_time, beam_width,
step_ids.data(), parent_ids.data(), sequence_length.data(),
beams.data());
// clang-format on
}
};
// Explicitly instantiates the GPU functor for each supported element type
// (currently only int32).
#define DEFINE_GPU_SPECS(T) template struct GatherTree<GPUDevice, T>;
DEFINE_GPU_SPECS(int32);
#undef DEFINE_GPU_SPECS
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
// Registers the "GatherTree" op. All inputs and the output share the single
// type attr T, which is restricted to int32.
REGISTER_OP("GatherTree")
.Input("step_ids: T")
.Input("parent_ids: T")
.Input("sequence_length: T")
.Output("beams: T")
.Attr("T: {int32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle step_ids, parent_ids, sequence_length;
// step_ids, parent_ids, and output are all shaped:
// [max_time, batch_size, beam_width].
// sequence_length is shaped [batch_size].
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length));
// The batch dimension is dimension 1 (time-major layout).
DimensionHandle batch_size = c->Dim(step_ids, 1);
// step_ids and parent_ids must have compatible shapes.
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
// sequence_length must have one entry per batch element. The merged
// dimension is only used for validation; it is not written back into the
// output shape.
TF_RETURN_IF_ERROR(
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
c->set_output(0, step_ids);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Calculates the full beams from the per-step ids and parent beam ids.
This op implements the following mathematical equations:
```python
TODO(ebrevdo): fill in
```
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
sequence_length: `[batch_size]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");
} // end namespace tensorflow

View File

@ -91,7 +91,7 @@ class AttentionWrapperTest(test.TestCase):
memory=encoder_outputs,
memory_sequence_length=encoder_sequence_length)
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
with vs.variable_scope(
"root",
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):

View File

@ -43,7 +43,7 @@ class BasicDecoderTest(test.TestCase):
cell_depth = 10
output_layer_depth = 3
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(cell_depth)
@ -127,7 +127,7 @@ class BasicDecoderTest(test.TestCase):
start_tokens = [0] * batch_size
end_token = 1
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
embeddings = np.random.randn(vocabulary_size,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(vocabulary_size)
@ -196,7 +196,7 @@ class BasicDecoderTest(test.TestCase):
input_depth = 7
vocabulary_size = 10
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(
batch_size, max_time, input_depth).astype(np.float32)
embeddings = np.random.randn(
@ -290,7 +290,7 @@ class BasicDecoderTest(test.TestCase):
else:
auxiliary_inputs = None
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(cell_depth)

View File

@ -0,0 +1,150 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.beam_search_ops."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
def _transpose_batch_time(x):
return np.transpose(x, [1, 0, 2]).astype(np.int32)
# Tests for beam_search_ops.gather_tree.
# In every test the literals are written batch-major
# (batch_size, max_time, beams) and converted to the time-major layout the op
# receives by _transpose_batch_time.
class GatherTreeTest(test.TestCase):
def testGatherTreeOne(self):
# Literals: (batch_size = 1, max_time = 4, beams = 3);
# after _transpose_batch_time: (max_time = 4, batch_size = 1, beams = 3).
# Only the first 3 time steps are valid (sequence_length = [3]); the last
# step is expected to stay at -1.
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
def testBadParentValuesOnCPU(self):
# Literals: (batch_size = 1, max_time = 4, beams = 3);
# after _transpose_batch_time: (max_time = 4, batch_size = 1, beams = 3).
# bad parent in beam 1 time 1
# The CPU kernel is expected to fail with an error that names the time step
# at which the bad parent would have been used (time 0).
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
with self.test_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
_ = beams.eval()
def testBadParentValuesOnGPU(self):
# Skipped silently when no GPU is available.
if not test.is_gpu_available():
return
# Literals: (batch_size = 1, max_time = 4, beams = 3);
# after _transpose_batch_time: (max_time = 4, batch_size = 1, beams = 3).
# bad parent in beam 1 time 1; appears as a negative index at time 0
# The GPU kernel does not raise; the affected output entry is -1 instead.
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/gpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
def testGatherTreeBatch(self):
# One sequence length per batch entry; batch 0 has length 0, so its
# expected beams are entirely -1.
sequence_length = [0, 1, 2, 3]
with self.test_session(use_gpu=True):
# Literals: (batch_size = 4, max_time = 4, beam_width = 5);
# after _transpose_batch_time: (max_time = 4, batch_size = 4, beam_width = 5)
step_ids = _transpose_batch_time(
[[[3, 4, 0, 4, 0],
[4, 2, 0, 3, 1],
[1, 1, 3, 2, 2],
[3, 1, 2, 3, 4]],
[[3, 4, 0, 4, 0],
[4, 2, 0, 3, 1],
[1, 1, 3, 2, 2],
[3, 1, 2, 3, 4]],
[[1, 2, 3, 4, 2],
[2, 1, 1, 3, 2],
[3, 0, 1, 0, 0],
[3, 4, 0, 2, 4]],
[[0, 2, 2, 3, 1],
[3, 2, 2, 2, 3],
[3, 4, 3, 0, 3],
[1, 2, 2, 2, 4]]])
parent_ids = _transpose_batch_time(
[[[4, 2, 4, 3, 4],
[3, 4, 0, 2, 0],
[3, 1, 3, 2, 2],
[0, 2, 1, 4, 2]],
[[4, 2, 4, 3, 4],
[3, 4, 0, 2, 0],
[3, 1, 3, 2, 2],
[0, 2, 1, 4, 2]],
[[3, 0, 0, 4, 0],
[1, 2, 4, 2, 2],
[4, 4, 0, 3, 0],
[2, 4, 4, 3, 0]],
[[3, 1, 4, 1, 3],
[3, 2, 4, 0, 4],
[1, 0, 1, 4, 2],
[0, 3, 2, 0, 1]]])
expected_beams = _transpose_batch_time(
[[[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[3, 4, 0, 4, 0],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[2, 3, 2, 3, 3],
[2, 1, 1, 3, 2],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[2, 3, 2, 1, 1],
[2, 3, 2, 3, 2],
[3, 4, 3, 0, 3],
[-1, -1, -1, -1, -1]]])
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
self.assertAllEqual(expected_beams, beams.eval())
if __name__ == "__main__":
test.main()

View File

@ -44,7 +44,7 @@ class DynamicDecodeRNNTest(test.TestCase):
cell_depth = 10
max_out = max(sequence_length)
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
if time_major:
inputs = np.random.randn(max_time, batch_size,
input_depth).astype(np.float32)
@ -118,7 +118,7 @@ class DynamicDecodeRNNTest(test.TestCase):
cell_depth = 10
max_out = max(sequence_length)
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)

View File

@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class LossTest(test.TestCase):
def testSequenceLoss(self):
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
batch_size = 2

View File

@ -0,0 +1,27 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam Search helper ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.seq2seq.ops import gen_beam_search_ops
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# Load the compiled op library at import time. The path is resolved relative
# to this module, so "_beam_search_ops.so" must sit next to this file.
_beam_search_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_beam_search_ops.so"))
# Public alias for the generated GatherTree wrapper.
gather_tree = gen_beam_search_ops.gather_tree