From a498546ddf66f1c96b2fc9fd1ec5c57cd76e4aa0 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 18 Apr 2017 08:25:04 -0800 Subject: [PATCH] [TF contrib seq2seq] Beam search tree decoder op "GatherTree". Based heavily on Denny Britz's initial implementation. Change: 153470580 --- .../contrib/cmake/tf_core_kernels.cmake | 3 + tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 24 +++ tensorflow/contrib/seq2seq/BUILD | 70 ++++++- tensorflow/contrib/seq2seq/__init__.py | 3 + .../seq2seq/kernels/beam_search_ops.cc | 174 ++++++++++++++++++ .../contrib/seq2seq/kernels/beam_search_ops.h | 41 +++++ .../seq2seq/kernels/beam_search_ops_gpu.cu.cc | 88 +++++++++ .../contrib/seq2seq/ops/beam_search_ops.cc | 65 +++++++ .../kernel_tests/attention_wrapper_test.py | 2 +- .../python/kernel_tests/basic_decoder_test.py | 8 +- .../kernel_tests/beam_search_ops_test.py | 150 +++++++++++++++ .../python/kernel_tests/decoder_test.py | 4 +- .../seq2seq/python/kernel_tests/loss_test.py | 2 +- .../seq2seq/python/ops/beam_search_ops.py | 27 +++ 15 files changed, 653 insertions(+), 9 deletions(-) create mode 100644 tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc create mode 100644 tensorflow/contrib/seq2seq/kernels/beam_search_ops.h create mode 100644 tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc create mode 100644 tensorflow/contrib/seq2seq/ops/beam_search_ops.cc create mode 100644 tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py create mode 100644 tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 1b2096b6458..0c420a02534 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -59,6 +59,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/best_splits_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc" @@ -127,6 +129,7 @@ endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc" ) if(WIN32 AND tensorflow_ENABLE_GPU) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 6127f674f24..2beb264a54e 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -76,6 +76,7 @@ GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/co GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 983f976c95c..033d54cc1ea 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -447,6 +447,8 @@ add_python_module("tensorflow/contrib/saved_model") add_python_module("tensorflow/contrib/saved_model/python") add_python_module("tensorflow/contrib/saved_model/python/saved_model") add_python_module("tensorflow/contrib/seq2seq") +add_python_module("tensorflow/contrib/seq2seq/kernels") +add_python_module("tensorflow/contrib/seq2seq/ops") add_python_module("tensorflow/contrib/seq2seq/python") add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") add_python_module("tensorflow/contrib/seq2seq/python/ops") @@ -629,6 +631,8 @@ GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_lstm_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_seq2seq_beam_search_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/ops/gen_beam_search_ops.py) GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py) GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops" @@ -820,6 +824,26 @@ if(WIN32) DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/python/ops/) endif(WIN32) +if(WIN32) + # include contrib/seq2seq as .so + # + set(tf_beam_search_srcs + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc" + ) + + set(tf_beam_search_gpu_srcs + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc" + ) + + AddUserOps(TARGET _beam_search_ops + SOURCES "${tf_beam_search_srcs}" + GPUSOURCES ${tf_beam_search_gpu_srcs} + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/) +endif(WIN32) + ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 652bbba85ef..011c3ba4271 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -8,12 +8,28 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_kernel_library", + "tf_gen_op_wrapper_py", +) -py_library( +tf_custom_op_py_library( name = "seq2seq_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_beam_search_ops.so", + ], + kernels = [ + ":beam_search_ops_kernels", + ":beam_search_ops_op_lib", + ], srcs_version = "PY2AND3", deps = [ + ":beam_search_ops", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", @@ -29,6 +45,44 @@ py_library( ], ) +tf_custom_op_library( + name = "python/ops/_beam_search_ops.so", + srcs = [ + "kernels/beam_search_ops.cc", + "kernels/beam_search_ops.h", + "ops/beam_search_ops.cc", + ], + gpu_srcs = [ + "kernels/beam_search_ops_gpu.cu.cc", + "kernels/beam_search_ops.h", + ], + deps = [ + "//tensorflow/core/kernels:eigen_helpers", + ], +) + +tf_gen_op_wrapper_py( + name = "beam_search_ops", + deps = [":beam_search_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "beam_search_ops", + ], +) + +tf_kernel_library( + name = "beam_search_ops_kernels", + prefix = "kernels/beam_search_ops", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ], +) + cuda_py_test( name = "loss_test", size = "medium", @@ -67,6 +121,20 @@ cuda_py_test( ], ) +cuda_py_test( + name = "beam_search_ops_test", + size = "medium", + srcs = ["python/kernel_tests/beam_search_ops_test.py"], + additional_deps = [ + ":seq2seq_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "decoder_test", size = "medium", diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py index 277434c1606..933f5ebe146 100644 --- a/tensorflow/contrib/seq2seq/__init__.py +++ b/tensorflow/contrib/seq2seq/__init__.py @@ -37,6 +37,8 @@ See the @{$python/contrib.seq2seq} guide. @@AttentionWrapperState @@AttentionWrapper + +@@gather_tree """ from __future__ import absolute_import @@ -46,6 +48,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import * from tensorflow.contrib.seq2seq.python.ops.basic_decoder import * +from tensorflow.contrib.seq2seq.python.ops.beam_search_ops import * from tensorflow.contrib.seq2seq.python.ops.decoder import * from tensorflow.contrib.seq2seq.python.ops.helper import * from tensorflow.contrib.seq2seq.python.ops.loss import * diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc new file mode 100644 index 00000000000..3b0568794dc --- /dev/null +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class GatherTreeOp : public OpKernel { + public: + explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Device& device = ctx->eigen_device(); + const Tensor& step_ids = ctx->input(0); + const Tensor& parent_ids = ctx->input(1); + const Tensor& sequence_length = ctx->input(2); + const TensorShape& step_ids_shape = step_ids.shape(); + OP_REQUIRES( + ctx, step_ids_shape.dims() == 3, + errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ", + step_ids_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(sequence_length.shape()), + errors::InvalidArgument("sequence_length must be a vector, saw shape: ", + sequence_length.shape().DebugString())); + OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1), + errors::InvalidArgument( + "Inconsistent batch sizes: sequence_length.shape[1] (", + sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (", + step_ids_shape.dim_size(0), ")")); + OP_REQUIRES( + ctx, step_ids_shape == parent_ids.shape(), + errors::InvalidArgument( + "step_ids.shape must match parent_ids.shape. but shapes are: ", + step_ids_shape.DebugString(), " and ", + parent_ids.shape().DebugString())); + Tensor* beams; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); + typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); + typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); + typename TTypes::ConstVec seq_len_t = sequence_length.vec(); + typename TTypes::Tensor beams_t = beams->tensor(); + functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, + seq_len_t, beams_t); + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("GatherTree").Device(DEVICE_CPU).TypeConstraint("T"), \ + GatherTreeOp); +REGISTER_KERNEL(int32); +#undef REGISTER_KERNEL + +namespace functor { + +// CPU specialization +template <> +struct GatherTree { + void operator()(OpKernelContext* ctx, const CPUDevice& d, + typename TTypes::ConstTensor step_ids, + typename TTypes::ConstTensor parent_ids, + typename TTypes::ConstVec sequence_length, + typename TTypes::Tensor beams) { + const int64 max_time = parent_ids.dimension(0); + const int64 batch_size = parent_ids.dimension(1); + const int64 beam_width = parent_ids.dimension(2); + beams.setConstant(-1); + + auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) { + int32 seq_len_b = -1; + int32 old_batch = -1; + for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) { + const int32 batch = i / beam_width; + const int32 beam = i % beam_width; + if (batch != old_batch) { + seq_len_b = sequence_length(batch); + old_batch = batch; + } + if (seq_len_b == 0) { + continue; + } + beams(seq_len_b - 1, batch, beam) = + step_ids(seq_len_b - 1, batch, beam); + int32 parent = parent_ids(seq_len_b - 1, batch, beam); + for (int32 level = seq_len_b - 2; level >= 0; --level) { + if (parent < 0 || parent > beam_width) { + ctx->SetStatus( + errors::InvalidArgument("Saw invalid parent id ", parent, + " at (batch, time, beam) == (", batch, + ", ", level, ", ", beam, ")")); + return; + } + beams(level, batch, beam) = step_ids(level, batch, parent); + parent = parent_ids(level, batch, parent); + } + } + }; + // Guesstimate of cost; ~5 lookup/store/compare per inner beam + // traversal time step. + const int64 batch_beam_cost = + Eigen::TensorOpCost::DivCost() + + 6 * Eigen::TensorOpCost::AddCost() + + max_time * (5 * Eigen::TensorOpCost::AddCost()); + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, + batch_size * beam_width, batch_beam_cost, DoWork); + } +}; + +} // namespace functor + +#if GOOGLE_CUDA +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void GatherTree::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes::ConstTensor step_ids, \ + typename TTypes::ConstTensor parent_ids, \ + typename TTypes::ConstVec sequence_length, \ + typename TTypes::Tensor beams); \ + extern template struct GatherTree; + +DECLARE_GPU_SPEC(int32); +#undef DECLARE_GPU_SPEC +} // end namespace functor + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("GatherTree").Device(DEVICE_GPU).TypeConstraint("T"), \ + GatherTreeOp); + +REGISTER_GPU_KERNEL(int32); +#undef REGISTER_GPU_KERNEL +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h new file mode 100644 index 00000000000..501a2eae848 --- /dev/null +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_activations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class OpKernelContext; + +namespace functor { + +template +struct GatherTree { + void operator()(OpKernelContext* ctx, const Device& d, + typename TTypes::ConstTensor step_ids, + typename TTypes::ConstTensor parent_ids, + typename TTypes::ConstVec sequence_length, + typename TTypes::Tensor beams); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc new file mode 100644 index 00000000000..8d8fc810015 --- /dev/null +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +template +__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, + const int32 beam_width, const T* step_ids, + const T* parent_ids, + const T* sequence_length, T* beams) { + CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) { + const int32 batch = i / beam_width; + const int32 beam = i % beam_width; + const int32 seq_len_b = ldg(sequence_length + batch); +#define GET_IX(time_ix, beam_ix) \ + (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) + const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam); + beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); + int32 parent = ldg(parent_ids + initial_beam_ix); + for (int32 level = seq_len_b - 2; level >= 0; --level) { + const int32 level_beam_ix = GET_IX(level, beam); + const int32 level_parent_ix = GET_IX(level, parent); + if (parent < 0 || parent > beam_width) { + beams[level_beam_ix] = -1; + parent = -1; + } else { + beams[level_beam_ix] = ldg(step_ids + level_parent_ix); + parent = ldg(parent_ids + level_parent_ix); + } + } +#undef GET_IX + } +} + +template +struct GatherTree { + void operator()(OpKernelContext* ctx, const GPUDevice& d, + typename TTypes::ConstTensor step_ids, + typename TTypes::ConstTensor parent_ids, + typename TTypes::ConstVec sequence_length, + typename TTypes::Tensor beams) { + const int32 max_time = parent_ids.dimension(0); + const int32 batch_size = parent_ids.dimension(1); + const int32 beam_width = parent_ids.dimension(2); + // First kernel launch to zero things out + beams.device(d) = beams.constant(T(-1)); + + CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); + // clang-format off + GatherTreeOpKernel + <<>>( + batch_size, max_time, beam_width, + step_ids.data(), parent_ids.data(), sequence_length.data(), + beams.data()); + // clang-format on + } +}; + +#define DEFINE_GPU_SPECS(T) template struct GatherTree; + +DEFINE_GPU_SPECS(int32); +#undef DEFINE_GPU_SPECS + +} // end namespace functor +} // end namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc new file mode 100644 index 00000000000..c167736d882 --- /dev/null +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("GatherTree") + .Input("step_ids: T") + .Input("parent_ids: T") + .Input("sequence_length: T") + .Output("beams: T") + .Attr("T: {int32}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle step_ids, parent_ids, sequence_length; + + // step_ids, parent_ids, and output are all shaped: + // [batch_size, max_time, beam_width]. + // sequence_length is shaped [batch_size]. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length)); + + DimensionHandle batch_size = c->Dim(step_ids, 1); + + TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids)); + TF_RETURN_IF_ERROR( + c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); + + c->set_output(0, step_ids); + return tensorflow::Status::OK(); + }) + .Doc(R"doc( +Calculates the full beams from the per-step ids and parent beam ids. + +This op implements the following mathematical equations: + +```python +TODO(ebrevdo): fill in +``` + +step_ids: `[max_time, batch_size, beam_width]`. +parent_ids: `[max_time, batch_size, beam_width]`. +sequence_length: `[batch_size]`. +beams: `[max_time, batch_size, beam_width]`. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 067af818bb8..aa84ae060c9 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -91,7 +91,7 @@ class AttentionWrapperTest(test.TestCase): memory=encoder_outputs, memory_sequence_length=encoder_sequence_length) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: with vs.variable_scope( "root", initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)): diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index 276801ba7c8..6b57293c6f7 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -43,7 +43,7 @@ class BasicDecoderTest(test.TestCase): cell_depth = 10 output_layer_depth = 3 - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(cell_depth) @@ -127,7 +127,7 @@ class BasicDecoderTest(test.TestCase): start_tokens = [0] * batch_size end_token = 1 - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: embeddings = np.random.randn(vocabulary_size, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(vocabulary_size) @@ -196,7 +196,7 @@ class BasicDecoderTest(test.TestCase): input_depth = 7 vocabulary_size = 10 - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: inputs = np.random.randn( batch_size, max_time, input_depth).astype(np.float32) embeddings = np.random.randn( @@ -290,7 +290,7 @@ class BasicDecoderTest(test.TestCase): else: auxiliary_inputs = None - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(cell_depth) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py new file mode 100644 index 00000000000..542254854a4 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -0,0 +1,150 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.beam_search_ops.""" +# pylint: disable=unused-import,g-bad-import-order +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# pylint: enable=unused-import + +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +def _transpose_batch_time(x): + return np.transpose(x, [1, 0, 2]).astype(np.int32) + + +class GatherTreeTest(test.TestCase): + + def testGatherTreeOne(self): + # (max_time = 4, batch_size = 1, beams = 3) + step_ids = _transpose_batch_time( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + parent_ids = _transpose_batch_time( + [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) + sequence_length = [3] + expected_result = _transpose_batch_time( + [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + beams = beam_search_ops.gather_tree( + step_ids=step_ids, parent_ids=parent_ids, + sequence_length=sequence_length) + with self.test_session(use_gpu=True): + self.assertAllEqual(expected_result, beams.eval()) + + def testBadParentValuesOnCPU(self): + # (batch_size = 1, max_time = 4, beams = 3) + # bad parent in beam 1 time 1 + step_ids = _transpose_batch_time( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + parent_ids = _transpose_batch_time( + [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) + sequence_length = [3] + with ops.device("/cpu:0"): + beams = beam_search_ops.gather_tree( + step_ids=step_ids, parent_ids=parent_ids, + sequence_length=sequence_length) + with self.test_session(): + with self.assertRaisesOpError( + r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): + _ = beams.eval() + + def testBadParentValuesOnGPU(self): + if not test.is_gpu_available(): + return + # (max_time = 4, batch_size = 1, beams = 3) + # bad parent in beam 1 time 1; appears as a negative index at time 0 + step_ids = _transpose_batch_time( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + parent_ids = _transpose_batch_time( + [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) + sequence_length = [3] + expected_result = _transpose_batch_time( + [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + with ops.device("/gpu:0"): + beams = beam_search_ops.gather_tree( + step_ids=step_ids, parent_ids=parent_ids, + sequence_length=sequence_length) + with self.test_session(use_gpu=True): + self.assertAllEqual(expected_result, beams.eval()) + + def testGatherTreeBatch(self): + sequence_length = [0, 1, 2, 3] + + with self.test_session(use_gpu=True): + # (max_time = 4, batch_size = 4, beam_width = 5) + step_ids = _transpose_batch_time( + [[[3, 4, 0, 4, 0], + [4, 2, 0, 3, 1], + [1, 1, 3, 2, 2], + [3, 1, 2, 3, 4]], + [[3, 4, 0, 4, 0], + [4, 2, 0, 3, 1], + [1, 1, 3, 2, 2], + [3, 1, 2, 3, 4]], + [[1, 2, 3, 4, 2], + [2, 1, 1, 3, 2], + [3, 0, 1, 0, 0], + [3, 4, 0, 2, 4]], + [[0, 2, 2, 3, 1], + [3, 2, 2, 2, 3], + [3, 4, 3, 0, 3], + [1, 2, 2, 2, 4]]]) + parent_ids = _transpose_batch_time( + [[[4, 2, 4, 3, 4], + [3, 4, 0, 2, 0], + [3, 1, 3, 2, 2], + [0, 2, 1, 4, 2]], + [[4, 2, 4, 3, 4], + [3, 4, 0, 2, 0], + [3, 1, 3, 2, 2], + [0, 2, 1, 4, 2]], + [[3, 0, 0, 4, 0], + [1, 2, 4, 2, 2], + [4, 4, 0, 3, 0], + [2, 4, 4, 3, 0]], + [[3, 1, 4, 1, 3], + [3, 2, 4, 0, 4], + [1, 0, 1, 4, 2], + [0, 3, 2, 0, 1]]]) + expected_beams = _transpose_batch_time( + [[[-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]], + [[3, 4, 0, 4, 0], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]], + [[2, 3, 2, 3, 3], + [2, 1, 1, 3, 2], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]], + [[2, 3, 2, 1, 1], + [2, 3, 2, 3, 2], + [3, 4, 3, 0, 3], + [-1, -1, -1, -1, -1]]]) + + beams = beam_search_ops.gather_tree( + step_ids=step_ids, parent_ids=parent_ids, + sequence_length=sequence_length) + self.assertAllEqual(expected_beams, beams.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 00854ed8b74..340ec9bbb22 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -44,7 +44,7 @@ class DynamicDecodeRNNTest(test.TestCase): cell_depth = 10 max_out = max(sequence_length) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: if time_major: inputs = np.random.randn(max_time, batch_size, input_depth).astype(np.float32) @@ -118,7 +118,7 @@ class DynamicDecodeRNNTest(test.TestCase): cell_depth = 10 max_out = max(sequence_length) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 8c6d85d1061..35c601a4bcf 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import test class LossTest(test.TestCase): def testSequenceLoss(self): - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): batch_size = 2 diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py new file mode 100644 index 00000000000..7d9fcc0c90a --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_ops.py @@ -0,0 +1,27 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Beam Search helper ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.seq2seq.ops import gen_beam_search_ops +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_beam_search_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_beam_search_ops.so")) + +gather_tree = gen_beam_search_ops.gather_tree