[TF contrib seq2seq] Beam search tree decoder op "GatherTree".
Based heavily on Denny Britz's initial implementation. Change: 153470580
This commit is contained in:
parent
617e217cfc
commit
a498546ddf
@ -59,6 +59,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/best_splits_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc"
|
||||
@ -127,6 +129,7 @@ endif(WIN32)
|
||||
file(GLOB_RECURSE tf_core_gpu_kernels_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc"
|
||||
)
|
||||
|
||||
if(WIN32 AND tensorflow_ENABLE_GPU)
|
||||
|
@ -76,6 +76,7 @@ GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/co
|
||||
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")
|
||||
|
@ -447,6 +447,8 @@ add_python_module("tensorflow/contrib/saved_model")
|
||||
add_python_module("tensorflow/contrib/saved_model/python")
|
||||
add_python_module("tensorflow/contrib/saved_model/python/saved_model")
|
||||
add_python_module("tensorflow/contrib/seq2seq")
|
||||
add_python_module("tensorflow/contrib/seq2seq/kernels")
|
||||
add_python_module("tensorflow/contrib/seq2seq/ops")
|
||||
add_python_module("tensorflow/contrib/seq2seq/python")
|
||||
add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests")
|
||||
add_python_module("tensorflow/contrib/seq2seq/python/ops")
|
||||
@ -629,6 +631,8 @@ GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_lstm_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_seq2seq_beam_search_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/ops/gen_beam_search_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops"
|
||||
@ -820,6 +824,26 @@ if(WIN32)
|
||||
DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/python/ops/)
|
||||
endif(WIN32)
|
||||
|
||||
if(WIN32)
|
||||
# include contrib/seq2seq as .so
|
||||
#
|
||||
set(tf_beam_search_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc"
|
||||
)
|
||||
|
||||
set(tf_beam_search_gpu_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc"
|
||||
)
|
||||
|
||||
AddUserOps(TARGET _beam_search_ops
|
||||
SOURCES "${tf_beam_search_srcs}"
|
||||
GPUSOURCES ${tf_beam_search_gpu_srcs}
|
||||
DEPENDS pywrap_tensorflow_internal tf_python_ops
|
||||
DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/)
|
||||
endif(WIN32)
|
||||
|
||||
############################################################
|
||||
# Build a PIP package containing the TensorFlow runtime.
|
||||
############################################################
|
||||
|
@ -8,12 +8,28 @@ exports_files(["LICENSE"])
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_kernel_library",
|
||||
"tf_gen_op_wrapper_py",
|
||||
)
|
||||
|
||||
py_library(
|
||||
tf_custom_op_py_library(
|
||||
name = "seq2seq_py",
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
dso = [
|
||||
":python/ops/_beam_search_ops.so",
|
||||
],
|
||||
kernels = [
|
||||
":beam_search_ops_kernels",
|
||||
":beam_search_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":beam_search_ops",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/rnn:rnn_py",
|
||||
@ -29,6 +45,44 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_beam_search_ops.so",
|
||||
srcs = [
|
||||
"kernels/beam_search_ops.cc",
|
||||
"kernels/beam_search_ops.h",
|
||||
"ops/beam_search_ops.cc",
|
||||
],
|
||||
gpu_srcs = [
|
||||
"kernels/beam_search_ops_gpu.cu.cc",
|
||||
"kernels/beam_search_ops.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "beam_search_ops",
|
||||
deps = [":beam_search_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"beam_search_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "beam_search_ops_kernels",
|
||||
prefix = "kernels/beam_search_ops",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "loss_test",
|
||||
size = "medium",
|
||||
@ -67,6 +121,20 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "beam_search_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/kernel_tests/beam_search_ops_test.py"],
|
||||
additional_deps = [
|
||||
":seq2seq_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "decoder_test",
|
||||
size = "medium",
|
||||
|
@ -37,6 +37,8 @@ See the @{$python/contrib.seq2seq} guide.
|
||||
|
||||
@@AttentionWrapperState
|
||||
@@AttentionWrapper
|
||||
|
||||
@@gather_tree
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -46,6 +48,7 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import *
|
||||
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
|
||||
from tensorflow.contrib.seq2seq.python.ops.beam_search_ops import *
|
||||
from tensorflow.contrib.seq2seq.python.ops.decoder import *
|
||||
from tensorflow.contrib.seq2seq.python.ops.helper import *
|
||||
from tensorflow.contrib.seq2seq.python.ops.loss import *
|
||||
|
174
tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
Normal file
174
tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class GatherTreeOp : public OpKernel {
|
||||
public:
|
||||
explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
const Tensor& step_ids = ctx->input(0);
|
||||
const Tensor& parent_ids = ctx->input(1);
|
||||
const Tensor& sequence_length = ctx->input(2);
|
||||
const TensorShape& step_ids_shape = step_ids.shape();
|
||||
OP_REQUIRES(
|
||||
ctx, step_ids_shape.dims() == 3,
|
||||
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
|
||||
step_ids_shape.DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsVector(sequence_length.shape()),
|
||||
errors::InvalidArgument("sequence_length must be a vector, saw shape: ",
|
||||
sequence_length.shape().DebugString()));
|
||||
OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
|
||||
errors::InvalidArgument(
|
||||
"Inconsistent batch sizes: sequence_length.shape[1] (",
|
||||
sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
|
||||
step_ids_shape.dim_size(0), ")"));
|
||||
OP_REQUIRES(
|
||||
ctx, step_ids_shape == parent_ids.shape(),
|
||||
errors::InvalidArgument(
|
||||
"step_ids.shape must match parent_ids.shape. but shapes are: ",
|
||||
step_ids_shape.DebugString(), " and ",
|
||||
parent_ids.shape().DebugString()));
|
||||
Tensor* beams;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
|
||||
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
|
||||
typename TTypes<T>::ConstVec seq_len_t = sequence_length.vec<T>();
|
||||
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
|
||||
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
|
||||
seq_len_t, beams_t);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("GatherTree").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
GatherTreeOp<CPUDevice, T>);
|
||||
REGISTER_KERNEL(int32);
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
namespace functor {
|
||||
|
||||
// CPU specialization
|
||||
template <>
|
||||
struct GatherTree<CPUDevice, int32> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d,
|
||||
typename TTypes<int32, 3>::ConstTensor step_ids,
|
||||
typename TTypes<int32, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<int32>::ConstVec sequence_length,
|
||||
typename TTypes<int32, 3>::Tensor beams) {
|
||||
const int64 max_time = parent_ids.dimension(0);
|
||||
const int64 batch_size = parent_ids.dimension(1);
|
||||
const int64 beam_width = parent_ids.dimension(2);
|
||||
beams.setConstant(-1);
|
||||
|
||||
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
|
||||
int32 seq_len_b = -1;
|
||||
int32 old_batch = -1;
|
||||
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
|
||||
const int32 batch = i / beam_width;
|
||||
const int32 beam = i % beam_width;
|
||||
if (batch != old_batch) {
|
||||
seq_len_b = sequence_length(batch);
|
||||
old_batch = batch;
|
||||
}
|
||||
if (seq_len_b == 0) {
|
||||
continue;
|
||||
}
|
||||
beams(seq_len_b - 1, batch, beam) =
|
||||
step_ids(seq_len_b - 1, batch, beam);
|
||||
int32 parent = parent_ids(seq_len_b - 1, batch, beam);
|
||||
for (int32 level = seq_len_b - 2; level >= 0; --level) {
|
||||
if (parent < 0 || parent > beam_width) {
|
||||
ctx->SetStatus(
|
||||
errors::InvalidArgument("Saw invalid parent id ", parent,
|
||||
" at (batch, time, beam) == (", batch,
|
||||
", ", level, ", ", beam, ")"));
|
||||
return;
|
||||
}
|
||||
beams(level, batch, beam) = step_ids(level, batch, parent);
|
||||
parent = parent_ids(level, batch, parent);
|
||||
}
|
||||
}
|
||||
};
|
||||
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
|
||||
// traversal time step.
|
||||
const int64 batch_beam_cost =
|
||||
Eigen::TensorOpCost::DivCost<int32>() +
|
||||
6 * Eigen::TensorOpCost::AddCost<int32>() +
|
||||
max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
Shard(worker_threads.num_threads, worker_threads.workers,
|
||||
batch_size * beam_width, batch_beam_cost, DoWork);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void GatherTree<GPUDevice, T>::operator()( \
|
||||
OpKernelContext* ctx, const GPUDevice& d, \
|
||||
typename TTypes<T, 3>::ConstTensor step_ids, \
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids, \
|
||||
typename TTypes<T>::ConstVec sequence_length, \
|
||||
typename TTypes<T, 3>::Tensor beams); \
|
||||
extern template struct GatherTree<GPUDevice, T>;
|
||||
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
GatherTreeOp<GPUDevice, T>);
|
||||
|
||||
REGISTER_GPU_KERNEL(int32);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // end namespace tensorflow
|
41
tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
Normal file
41
tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/eigen_activations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class OpKernelContext;
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct GatherTree {
|
||||
void operator()(OpKernelContext* ctx, const Device& d,
|
||||
typename TTypes<T, 3>::ConstTensor step_ids,
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<T>::ConstVec sequence_length,
|
||||
typename TTypes<T, 3>::Tensor beams);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
|
88
tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
Normal file
88
tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename T>
|
||||
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
|
||||
const int32 beam_width, const T* step_ids,
|
||||
const T* parent_ids,
|
||||
const T* sequence_length, T* beams) {
|
||||
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
|
||||
const int32 batch = i / beam_width;
|
||||
const int32 beam = i % beam_width;
|
||||
const int32 seq_len_b = ldg(sequence_length + batch);
|
||||
#define GET_IX(time_ix, beam_ix) \
|
||||
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
|
||||
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
|
||||
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
|
||||
int32 parent = ldg(parent_ids + initial_beam_ix);
|
||||
for (int32 level = seq_len_b - 2; level >= 0; --level) {
|
||||
const int32 level_beam_ix = GET_IX(level, beam);
|
||||
const int32 level_parent_ix = GET_IX(level, parent);
|
||||
if (parent < 0 || parent > beam_width) {
|
||||
beams[level_beam_ix] = -1;
|
||||
parent = -1;
|
||||
} else {
|
||||
beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
|
||||
parent = ldg(parent_ids + level_parent_ix);
|
||||
}
|
||||
}
|
||||
#undef GET_IX
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GatherTree<GPUDevice, T> {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d,
|
||||
typename TTypes<T, 3>::ConstTensor step_ids,
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<T>::ConstVec sequence_length,
|
||||
typename TTypes<T, 3>::Tensor beams) {
|
||||
const int32 max_time = parent_ids.dimension(0);
|
||||
const int32 batch_size = parent_ids.dimension(1);
|
||||
const int32 beam_width = parent_ids.dimension(2);
|
||||
// First kernel launch to zero things out
|
||||
beams.device(d) = beams.constant(T(-1));
|
||||
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
|
||||
// clang-format off
|
||||
GatherTreeOpKernel<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
batch_size, max_time, beam_width,
|
||||
step_ids.data(), parent_ids.data(), sequence_length.data(),
|
||||
beams.data());
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
#define DEFINE_GPU_SPECS(T) template struct GatherTree<GPUDevice, T>;
|
||||
|
||||
DEFINE_GPU_SPECS(int32);
|
||||
#undef DEFINE_GPU_SPECS
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
#endif // GOOGLE_CUDA
|
65
tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
Normal file
65
tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("GatherTree")
|
||||
.Input("step_ids: T")
|
||||
.Input("parent_ids: T")
|
||||
.Input("sequence_length: T")
|
||||
.Output("beams: T")
|
||||
.Attr("T: {int32}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle step_ids, parent_ids, sequence_length;
|
||||
|
||||
// step_ids, parent_ids, and output are all shaped:
|
||||
// [batch_size, max_time, beam_width].
|
||||
// sequence_length is shaped [batch_size].
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length));
|
||||
|
||||
DimensionHandle batch_size = c->Dim(step_ids, 1);
|
||||
|
||||
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
|
||||
|
||||
c->set_output(0, step_ids);
|
||||
return tensorflow::Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Calculates the full beams from the per-step ids and parent beam ids.
|
||||
|
||||
This op implements the following mathematical equations:
|
||||
|
||||
```python
|
||||
TODO(ebrevdo): fill in
|
||||
```
|
||||
|
||||
step_ids: `[max_time, batch_size, beam_width]`.
|
||||
parent_ids: `[max_time, batch_size, beam_width]`.
|
||||
sequence_length: `[batch_size]`.
|
||||
beams: `[max_time, batch_size, beam_width]`.
|
||||
)doc");
|
||||
|
||||
} // end namespace tensorflow
|
@ -91,7 +91,7 @@ class AttentionWrapperTest(test.TestCase):
|
||||
memory=encoder_outputs,
|
||||
memory_sequence_length=encoder_sequence_length)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with vs.variable_scope(
|
||||
"root",
|
||||
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
|
||||
|
@ -43,7 +43,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
cell_depth = 10
|
||||
output_layer_depth = 3
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
@ -127,7 +127,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
start_tokens = [0] * batch_size
|
||||
end_token = 1
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
embeddings = np.random.randn(vocabulary_size,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(vocabulary_size)
|
||||
@ -196,7 +196,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
input_depth = 7
|
||||
vocabulary_size = 10
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(
|
||||
batch_size, max_time, input_depth).astype(np.float32)
|
||||
embeddings = np.random.randn(
|
||||
@ -290,7 +290,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
else:
|
||||
auxiliary_inputs = None
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
|
@ -0,0 +1,150 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.seq2seq.beam_search_ops."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _transpose_batch_time(x):
|
||||
return np.transpose(x, [1, 0, 2]).astype(np.int32)
|
||||
|
||||
|
||||
class GatherTreeTest(test.TestCase):
|
||||
|
||||
def testGatherTreeOne(self):
|
||||
# (max_time = 4, batch_size = 1, beams = 3)
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
def testBadParentValuesOnCPU(self):
|
||||
# (batch_size = 1, max_time = 4, beams = 3)
|
||||
# bad parent in beam 1 time 1
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [3]
|
||||
with ops.device("/cpu:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError(
|
||||
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
|
||||
_ = beams.eval()
|
||||
|
||||
def testBadParentValuesOnGPU(self):
|
||||
if not test.is_gpu_available():
|
||||
return
|
||||
# (max_time = 4, batch_size = 1, beams = 3)
|
||||
# bad parent in beam 1 time 1; appears as a negative index at time 0
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
with ops.device("/gpu:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
def testGatherTreeBatch(self):
|
||||
sequence_length = [0, 1, 2, 3]
|
||||
|
||||
with self.test_session(use_gpu=True):
|
||||
# (max_time = 4, batch_size = 4, beam_width = 5)
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[3, 4, 0, 4, 0],
|
||||
[4, 2, 0, 3, 1],
|
||||
[1, 1, 3, 2, 2],
|
||||
[3, 1, 2, 3, 4]],
|
||||
[[3, 4, 0, 4, 0],
|
||||
[4, 2, 0, 3, 1],
|
||||
[1, 1, 3, 2, 2],
|
||||
[3, 1, 2, 3, 4]],
|
||||
[[1, 2, 3, 4, 2],
|
||||
[2, 1, 1, 3, 2],
|
||||
[3, 0, 1, 0, 0],
|
||||
[3, 4, 0, 2, 4]],
|
||||
[[0, 2, 2, 3, 1],
|
||||
[3, 2, 2, 2, 3],
|
||||
[3, 4, 3, 0, 3],
|
||||
[1, 2, 2, 2, 4]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[4, 2, 4, 3, 4],
|
||||
[3, 4, 0, 2, 0],
|
||||
[3, 1, 3, 2, 2],
|
||||
[0, 2, 1, 4, 2]],
|
||||
[[4, 2, 4, 3, 4],
|
||||
[3, 4, 0, 2, 0],
|
||||
[3, 1, 3, 2, 2],
|
||||
[0, 2, 1, 4, 2]],
|
||||
[[3, 0, 0, 4, 0],
|
||||
[1, 2, 4, 2, 2],
|
||||
[4, 4, 0, 3, 0],
|
||||
[2, 4, 4, 3, 0]],
|
||||
[[3, 1, 4, 1, 3],
|
||||
[3, 2, 4, 0, 4],
|
||||
[1, 0, 1, 4, 2],
|
||||
[0, 3, 2, 0, 1]]])
|
||||
expected_beams = _transpose_batch_time(
|
||||
[[[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[3, 4, 0, 4, 0],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[2, 3, 2, 3, 3],
|
||||
[2, 1, 1, 3, 2],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[2, 3, 2, 1, 1],
|
||||
[2, 3, 2, 3, 2],
|
||||
[3, 4, 3, 0, 3],
|
||||
[-1, -1, -1, -1, -1]]])
|
||||
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
self.assertAllEqual(expected_beams, beams.eval())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -44,7 +44,7 @@ class DynamicDecodeRNNTest(test.TestCase):
|
||||
cell_depth = 10
|
||||
max_out = max(sequence_length)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
if time_major:
|
||||
inputs = np.random.randn(max_time, batch_size,
|
||||
input_depth).astype(np.float32)
|
||||
@ -118,7 +118,7 @@ class DynamicDecodeRNNTest(test.TestCase):
|
||||
cell_depth = 10
|
||||
max_out = max(sequence_length)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
|
||||
|
@ -34,7 +34,7 @@ from tensorflow.python.platform import test
|
||||
class LossTest(test.TestCase):
|
||||
|
||||
def testSequenceLoss(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
batch_size = 2
|
||||
|
27
tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py
Normal file
27
tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Beam Search helper ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.seq2seq.ops import gen_beam_search_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_beam_search_ops_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_beam_search_ops.so"))
|
||||
|
||||
gather_tree = gen_beam_search_ops.gather_tree
|
Loading…
Reference in New Issue
Block a user